Shatei committed on
Commit 5e373a9 · 1 Parent(s): 773164c

Update space

This view is limited to 50 files because it contains too many changes. See the raw diff for the full changeset.
Files changed (50)
  1. .gitignore +162 -0
  2. CODE_OF_CONDUCT.md +79 -0
  3. CONTRIBUTING.md +47 -0
  4. LICENSE.md +13 -0
  5. README.md +70 -13
  6. app.py +48 -141
  7. configs/collage_composite_train.yaml +114 -0
  8. configs/collage_flow_train.yaml +114 -0
  9. configs/collage_mix_train.yaml +115 -0
  10. data_processing/example_videos/getty-soccer-ball-jordan-video-id473239807_26.mp4 +0 -0
  11. data_processing/example_videos/getty-video-of-american-flags-being-sewn-together-at-flagsource-in-batavia-video-id804937470_87.mp4 +0 -0
  12. data_processing/example_videos/giphy-fgiT2cbsTxl8k_0.mp4 +0 -0
  13. data_processing/example_videos/giphy-gkvCpHRX9IqkM_3.mp4 +0 -0
  14. data_processing/example_videos/yt--4Fx5XUD-9Y_345.mp4 +0 -0
  15. data_processing/example_videos/yt-mNdvtOO7UqY_15.mp4 +0 -0
  16. data_processing/moments_dataset.py +54 -0
  17. data_processing/moments_processing.py +359 -0
  18. data_processing/processing_utils.py +304 -0
  19. environment.yaml +33 -0
  20. examples/dog_beach__edit__003.png +0 -0
  21. examples/dog_beach_og.png +0 -0
  22. examples/fox_drinking__edit__01.png +0 -0
  23. examples/fox_drinking__edit__02.png +0 -0
  24. examples/fox_drinking_og.png +0 -0
  25. examples/kingfisher__edit__001.png +0 -0
  26. examples/kingfisher_og.png +0 -0
  27. examples/log.csv +6 -0
  28. examples/palm_tree__edit__01.png +0 -0
  29. examples/palm_tree_og.png +0 -0
  30. examples/pipes__edit__01.png +0 -0
  31. examples/pipes_og.png +0 -0
  32. ku.py +0 -1
  33. ldm/data/__init__.py +0 -0
  34. ldm/data/collage_dataset.py +230 -0
  35. ldm/lr_scheduler.py +111 -0
  36. ldm/models/autoencoder.py +456 -0
  37. ldm/models/diffusion/__init__.py +0 -0
  38. ldm/models/diffusion/classifier.py +280 -0
  39. ldm/models/diffusion/ddim.py +296 -0
  40. ldm/models/diffusion/ddpm.py +1877 -0
  41. ldm/models/diffusion/plms.py +251 -0
  42. ldm/modules/attention.py +372 -0
  43. ldm/modules/diffusionmodules/__init__.py +0 -0
  44. ldm/modules/diffusionmodules/model.py +848 -0
  45. ldm/modules/diffusionmodules/openaimodel.py +1225 -0
  46. ldm/modules/diffusionmodules/util.py +285 -0
  47. ldm/modules/distributions/__init__.py +0 -0
  48. ldm/modules/distributions/distributions.py +105 -0
  49. ldm/modules/ema.py +89 -0
  50. ldm/modules/encoders/__init__.py +0 -0
.gitignore ADDED
@@ -0,0 +1,162 @@
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+ cover/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ .pybuilder/
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ # For a library or package, you might want to ignore these files since the code is
+ # intended to run in multiple environments; otherwise, check them in:
+ # .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # poetry
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+ #poetry.lock
+
+ # pdm
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+ #pdm.lock
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+ # in version control.
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
+ .pdm.toml
+ .pdm-python
+ .pdm-build/
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # pytype static type analyzer
+ .pytype/
+
+ # Cython debug symbols
+ cython_debug/
+
+ # PyCharm
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
+ #.idea/
CODE_OF_CONDUCT.md ADDED
@@ -0,0 +1,79 @@
+ # Adobe Code of Conduct
+
+ ## Our Pledge
+
+ We as members, contributors, and leaders pledge to make participation in our project and community a harassment-free experience for everyone, regardless of age, body size, visible or invisible disability, ethnicity, gender identity and expression, level of experience, education, socio-economic status, nationality, personal appearance, race, caste, color, religion, or sexual identity and orientation.
+
+ We pledge to act and interact in ways that contribute to an open, welcoming, diverse, inclusive, and healthy community.
+
+ ## Our Standards
+
+ Examples of behavior that contribute to a positive environment for our project and community include:
+
+ * Demonstrating empathy and kindness toward other people
+ * Being respectful of differing opinions, viewpoints, and experiences
+ * Giving and gracefully accepting constructive feedback
+ * Accepting responsibility and apologizing to those affected by our mistakes, and learning from the experience
+ * Focusing on what is best, not just for us as individuals but for the overall community
+
+ Examples of unacceptable behavior include:
+
+ * The use of sexualized language or imagery, and sexual attention or advances of any kind
+ * Trolling, insulting or derogatory comments, and personal or political attacks
+ * Public or private harassment
+ * Publishing others’ private information, such as a physical or email address, without their explicit permission
+ * Other conduct which could reasonably be considered inappropriate in a professional setting
+
+ ## Our Responsibilities
+
+ Project maintainers are responsible for clarifying and enforcing our standards of acceptable behavior and will take appropriate and fair corrective action in response to any instances of unacceptable behavior.
+
+ Project maintainers have the right and responsibility to remove, edit, or reject comments, commits, code, wiki edits, issues, and other contributions that are not aligned to this Code of Conduct, or to ban temporarily or permanently any contributor for behaviors that they deem inappropriate, threatening, offensive, or harmful.
+
+ ## Scope
+
+ This Code of Conduct applies when an individual is representing the project or its community both within project spaces and in public spaces. Examples of representing a project or community include using an official e-mail address, posting via an official social media account, or acting as an appointed representative at an online or offline event. Representation of a project may be further defined and clarified by project maintainers.
+
+ ## Enforcement
+
+ Instances of abusive, harassing, or otherwise unacceptable behavior may be reported by first contacting the project team. Oversight of Adobe projects is handled by the Adobe Open Source Office, which has final say in any violations and enforcement of this Code of Conduct and can be reached at [email protected]. All complaints will be reviewed and investigated promptly and fairly.
+
+ The project team must respect the privacy and security of the reporter of any incident.
+
+ Project maintainers who do not follow or enforce the Code of Conduct may face temporary or permanent repercussions as determined by other members of the project's leadership or the Adobe Open Source Office.
+
+ ## Enforcement Guidelines
+
+ Project maintainers will follow these Community Impact Guidelines in determining the consequences for any action they deem to be in violation of this Code of Conduct:
+
+ **1. Correction**
+
+ Community Impact: Use of inappropriate language or other behavior deemed unprofessional or unwelcome in the community.
+
+ Consequence: A private, written warning from project maintainers describing the violation and why the behavior was unacceptable. A public apology may be requested from the violator before any further involvement in the project by violator.
+
+ **2. Warning**
+
+ Community Impact: A relatively minor violation through a single incident or series of actions.
+
+ Consequence: A written warning from project maintainers that includes stated consequences for continued unacceptable behavior. Violator must refrain from interacting with the people involved for a specified period of time as determined by the project maintainers, including, but not limited to, unsolicited interaction with those enforcing the Code of Conduct through channels such as community spaces and social media. Continued violations may lead to a temporary or permanent ban.
+
+ **3. Temporary Ban**
+
+ Community Impact: A more serious violation of community standards, including sustained unacceptable behavior.
+
+ Consequence: A temporary ban from any interaction or public communication with the community for a specified period of time. No public or private interaction with the people involved, including unsolicited interaction with those enforcing the Code of Conduct, is allowed during this period. Failure to comply with the temporary ban may lead to a permanent ban.
+
+ **4. Permanent Ban**
+
+ Community Impact: Demonstrating a consistent pattern of violation of community standards or an egregious violation of community standards, including, but not limited to, sustained inappropriate behavior, harassment of an individual, or aggression toward or disparagement of classes of individuals.
+
+ Consequence: A permanent ban from any interaction with the community.
+
+ ## Attribution
+
+ This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 2.1,
+ available at [https://contributor-covenant.org/version/2/1][version]
+
+ [homepage]: https://contributor-covenant.org
+ [version]: https://contributor-covenant.org/version/2/1
CONTRIBUTING.md ADDED
@@ -0,0 +1,47 @@
+ # Contributing
+
+ Thanks for choosing to contribute!
+
+ The following are a set of guidelines to follow when contributing to this project.
+
+ ## Code Of Conduct
+
+ This project adheres to the Adobe [code of conduct](./CODE_OF_CONDUCT.md). By participating,
+ you are expected to uphold this code. Please report unacceptable behavior to
+
+ ## Have A Question?
+
+ Start by filing an issue. The existing committers on this project work to reach
+ consensus around project direction and issue solutions within issue threads
+ (when appropriate).
+
+ ## Contributor License Agreement
+
+ All third-party contributions to this project must be accompanied by a signed contributor
+ license agreement. This gives Adobe permission to redistribute your contributions
+ as part of the project. [Sign our CLA](https://opensource.adobe.com/cla.html). You
+ only need to submit an Adobe CLA one time, so if you have submitted one previously,
+ you are good to go!
+
+ ## Code Reviews
+
+ All submissions should come in the form of pull requests and need to be reviewed
+ by project committers. Read [GitHub's pull request documentation](https://help.github.com/articles/about-pull-requests/)
+ for more information on sending pull requests.
+
+ Lastly, please follow the [pull request template](PULL_REQUEST_TEMPLATE.md) when
+ submitting a pull request!
+
+ ## From Contributor To Committer
+
+ We love contributions from our community! If you'd like to go a step beyond contributor
+ and become a committer with full write access and a say in the project, you must
+ be invited to the project. The existing committers employ an internal nomination
+ process that must reach lazy consensus (silence is approval) before invitations
+ are issued. If you feel you are qualified and want to get more deeply involved,
+ feel free to reach out to existing committers to have a conversation about that.
+
+ ## Security Issues
+
+ Security issues shouldn't be reported on this issue tracker. Instead, [file an issue to our security experts](https://helpx.adobe.com/security/alertus.html).
LICENSE.md ADDED
@@ -0,0 +1,13 @@
+ Copyright 2024, Adobe Inc. and its licensors. All rights reserved.
+
+ ADOBE RESEARCH LICENSE
+
+ Adobe grants any person or entity ("you" or "your") obtaining a copy of these certain research materials that are owned by Adobe ("Licensed Materials") a nonexclusive, worldwide, royalty-free, revocable, fully paid license to (A) reproduce, use, modify, and publicly display the Licensed Materials; and (B) redistribute the Licensed Materials, and modifications or derivative works thereof, provided the following conditions are met:
+
+ The rights granted herein may be exercised for noncommercial research purposes (i.e., academic research and teaching) only. Noncommercial research purposes do not include commercial licensing or distribution, development of commercial products, or any other activity that results in commercial gain.
+ You may add your own copyright statement to your modifications and/or provide additional or different license terms for use, reproduction, modification, public display, and redistribution of your modifications and derivative works, provided that such license terms limit the use, reproduction, modification, public display, and redistribution of such modifications and derivative works to noncommercial research purposes only.
+ You acknowledge that Adobe and its licensors own all right, title, and interest in the Licensed Materials.
+ All copies of the Licensed Materials must include the above copyright notice, this list of conditions, and the disclaimer below.
+ Failure to meet any of the above conditions will automatically terminate the rights granted herein.
+
+ THE LICENSED MATERIALS ARE PROVIDED "AS IS" WITHOUT WARRANTY OF ANY KIND. THE ENTIRE RISK AS TO THE USE, RESULTS, AND PERFORMANCE OF THE LICENSED MATERIALS IS ASSUMED BY YOU. ADOBE DISCLAIMS ALL WARRANTIES, EXPRESS, IMPLIED OR STATUTORY, WITH REGARD TO YOUR USE OF THE LICENSED MATERIALS, INCLUDING, BUT NOT LIMITED TO, NONINFRINGEMENT OF THIRD-PARTY RIGHTS. IN NO EVENT WILL ADOBE BE LIABLE FOR ANY ACTUAL, INCIDENTAL, SPECIAL OR CONSEQUENTIAL DAMAGES, INCLUDING WITHOUT LIMITATION, LOSS OF PROFITS OR OTHER COMMERCIAL LOSS, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THE LICENSED MATERIALS, EVEN IF ADOBE HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES.
README.md CHANGED
@@ -1,13 +1,70 @@
- ---
- title: Oilkkkkbb
- emoji: 🖼
- colorFrom: purple
- colorTo: red
- sdk: gradio
- sdk_version: 4.26.0
- app_file: app.py
- pinned: false
- license: apache-2.0
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # MagicFixup
+ This is the repo for the paper [Magic Fixup: Streamlining Photo Editing by Watching Dynamic Videos](https://magic-fixup.github.io)
+ ## Installation
+ We provide an `environment.yaml` file to assist with installation. All you need for setup is to run the following script
+ ```
+ conda env create -f environment.yaml -v
+ ```
+ This will create a conda environment that you can activate using `conda activate MagicFixup`.
+
+ ## Inference
+
+ #### Downloading the Magic Fixup checkpoint
+ You can download the model trained on the Moments in Time dataset using this [Google Drive Link](https://drive.google.com/file/d/1zOcDcJzCijbGr9I9adC0Cv6yzW60U9TQ/view?usp=share_link)
+
+
+ ### Inference script
+ The inference script is `run_magicfu.py`. It takes the path of the reference image (the original image) and of the edited image. Note that it assumes the alpha channel is set appropriately in the edited image PNG, as we use the alpha channel to set the disocclusion mask (see the sketch after the command). You can run the inference script with
+
+ ```
+ python run_magicfu.py --checkpoint <Magic Fixup checkpoint> --reference <path to original image> --edit <path to png user edit>
+ ```
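As a concrete illustration of the expected input format, the sketch below mirrors what `sample()` in `app.py` (added in this commit) does with the two images: the reference is plain RGB, and the edit is an RGBA PNG whose alpha channel supplies the disocclusion mask. The example file names come from the `examples/` folder in this commit; the 512x512 resize follows `app.py` and is an illustration, not a separate documented API.

```python
# Hedged sketch of the expected inputs, following sample() in app.py:
# reference = RGB image, edit = RGBA PNG whose alpha channel is the disocclusion mask.
import torchvision
from PIL import Image

to_tensor = torchvision.transforms.ToTensor()

reference = Image.open("examples/fox_drinking_og.png").convert("RGB")
edit = Image.open("examples/fox_drinking__edit__01.png")  # keep the alpha channel

ref_t = to_tensor(reference.resize((512, 512)))   # 3 x 512 x 512
edit_t = to_tensor(edit.resize((512, 512)))       # 4 x 512 x 512 (RGBA)
edit_rgb_t = edit_t[:-1]                          # RGB part of the user edit
mask_t = edit_t[-1][None, None, ...]              # alpha -> 1 x 1 x 512 x 512 disocclusion mask
```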
+
+ ### Gradio demo
+ We have a gradio demo that allows you to test out your inputs with a friendly user interface. Simply start the demo with
+ ```
+ python magicfu_gradio.py --checkpoint <Magic Fixup checkpoint>
+ ```
+
+
+ ## Training your own model
+ To train your own model, you first need to process a video dataset, then train the model using the processed pairs from your videos. For the weights provided above, we trained on the Moments in Time dataset.
+
+ #### Pretrained SD1.4 diffusion model
+ We start training from the official SD1.4 model (with the first layer modified to take our 9-channel input). Download the official SD1.4 model, modify the first layer using `scripts/modify_checkpoints.py`, and place it under the `pretrained_models` folder.
+
+ ### Data Processing
+ The data processing code can be found under the `data_processing` folder. You can simply put all the videos in a directory and pass that directory as the folder name in `data_processing/moments_processing.py`. If your videos are long (e.g. more than 5 seconds) and contain cut scenes, you will want to use PySceneDetect to detect the cuts and split the videos accordingly; a sketch is given below.
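The repository itself does not ship a scene-splitting script, so the snippet below is only a hedged sketch of pre-splitting long videos with PySceneDetect before handing them to `moments_processing.py`. It assumes the `scenedetect` 0.6+ API (`detect`, `ContentDetector`, `split_video_ffmpeg`) and ffmpeg on your PATH; adapt it to the version you install.

```python
# Hedged sketch: split videos at detected cuts so each clip is a single shot.
# Assumes: pip install scenedetect[opencv], and ffmpeg available on PATH.
import glob
from scenedetect import detect, ContentDetector, split_video_ffmpeg

for video_path in glob.glob("./my_raw_videos/*.mp4"):   # hypothetical input folder
    scenes = detect(video_path, ContentDetector())      # list of (start, end) timecodes
    if len(scenes) > 1:
        split_video_ffmpeg(video_path, scenes)           # writes one clip per scene
```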
+ For data processing, you also need to download the checkpoint for Segment Anything and install softmax-splatting. You can set up softmax-splatting and SAM by following
+ ```
+ cd data_processing
+ git clone https://github.com/sniklaus/softmax-splatting.git
+ pip install segment_anything
+ cd sam_model
+ wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
+ ```
+ For softmax-splatting to run, you need to install cupy with `pip install cupy` (or `pip install cupy-cuda11x` / `pip install cupy-cuda12x` depending on your CUDA version), and load the appropriate CUDA module.
+
+ Then run `python moments_processing.py` to start processing frames from the provided example videos (included under `data_processing/example_videos`). For the version provided, we used the [Moments in Time Dataset](http://moments.csail.mit.edu).
+
+ ### Running the training script
+ Make sure that you have downloaded the pretrained SD1.4 model above. Once you have downloaded the training dataset and the pretrained model, you can simply start training the model with
+ ```
+ ./train.sh
+ ```
+ The training code is in `main.py` and relies mainly on pytorch_lightning for training.
+
+ Note that you need to modify the train and val paths in the chosen config file to the location where you have the processed data.
+
+ Note: we use DeepSpeed to lower the memory requirements, so the saved model weights will be sharded. The script to reconstruct the model weights is created in the checkpoint directory under the name `zero_to_fp32.py`. One bug in that file is that it does not recognize checkpoints saved with DeepSpeed stage 1 (which is the one we use), so simply find and replace the string `== 2` with `<= 2` and it will work; see the snippet below.
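One way to apply that patch without opening an editor (the checkpoint directory below is a placeholder you should replace with your own):

```python
# Hedged sketch: patch the generated zero_to_fp32.py so it accepts DeepSpeed
# stage-1 checkpoints, by replacing the "== 2" stage check with "<= 2".
from pathlib import Path

script = Path("<your checkpoint dir>/zero_to_fp32.py")  # placeholder path
script.write_text(script.read_text().replace("== 2", "<= 2"))
```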
+
+ ### Saving the Full Model Weights
+ To save on storage, we only checkpoint the learnable parameters during training (i.e. the frozen autoencoder params are not saved). To create a checkpoint that contains all the parameters, you can combine the frozen pretrained weights and the learned parameters by running
+
+ ```
+ python combine_model_params.py --pretrained_sd <path to pretrained SD1.4 with modified first layer> --learned_params <path to combined checkpoint learned> --save_path <path to save the >
+ ```
+
+
+ ##### Acknowledgement
+ The diffusion code was built on top of the codebase adapted from [PaintByExample](https://github.com/Fantasy-Studio/Paint-by-Example)
app.py CHANGED
@@ -1,146 +1,53 @@
  import gradio as gr
  import numpy as np
- import random
- from diffusers import DiffusionPipeline
- import torch
-
- device = "cuda" if torch.cuda.is_available() else "cpu"
-
- if torch.cuda.is_available():
-     torch.cuda.max_memory_allocated(device=device)
-     pipe = DiffusionPipeline.from_pretrained("stabilityai/sdxl-turbo", torch_dtype=torch.float16, variant="fp16", use_safetensors=True)
-     pipe.enable_xformers_memory_efficient_attention()
-     pipe = pipe.to(device)
- else:
-     pipe = DiffusionPipeline.from_pretrained("stabilityai/sdxl-turbo", use_safetensors=True)
-     pipe = pipe.to(device)
-
- MAX_SEED = np.iinfo(np.int32).max
- MAX_IMAGE_SIZE = 1024
-
- def infer(prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps):
-
-     if randomize_seed:
-         seed = random.randint(0, MAX_SEED)
-
-     generator = torch.Generator().manual_seed(seed)
-
-     image = pipe(
-         prompt = prompt,
-         negative_prompt = negative_prompt,
-         guidance_scale = guidance_scale,
-         num_inference_steps = num_inference_steps,
-         width = width,
-         height = height,
-         generator = generator
-     ).images[0]
-
-     return image
-
- examples = [
-     "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
-     "An astronaut riding a green horse",
-     "A delicious ceviche cheesecake slice",
- ]
-
- css="""
- #col-container {
-     margin: 0 auto;
-     max-width: 520px;
- }
- """
-
- if torch.cuda.is_available():
-     power_device = "GPU"
- else:
-     power_device = "CPU"
-
- with gr.Blocks(css=css) as demo:
-
-     with gr.Column(elem_id="col-container"):
-         gr.Markdown(f"""
-         # Text-to-Image Gradio Template
-         Currently running on {power_device}.
-         """)
-
-         with gr.Row():
-
-             prompt = gr.Text(
-                 label="Prompt",
-                 show_label=False,
-                 max_lines=1,
-                 placeholder="Enter your prompt",
-                 container=False,
-             )
-
-             run_button = gr.Button("Run", scale=0)
-
-         result = gr.Image(label="Result", show_label=False)
-
-         with gr.Accordion("Advanced Settings", open=False):
-
-             negative_prompt = gr.Text(
-                 label="Negative prompt",
-                 max_lines=1,
-                 placeholder="Enter a negative prompt",
-                 visible=False,
-             )
-
-             seed = gr.Slider(
-                 label="Seed",
-                 minimum=0,
-                 maximum=MAX_SEED,
-                 step=1,
-                 value=0,
-             )
-
-             randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
-
-             with gr.Row():
-
-                 width = gr.Slider(
-                     label="Width",
-                     minimum=256,
-                     maximum=MAX_IMAGE_SIZE,
-                     step=32,
-                     value=512,
-                 )
-
-                 height = gr.Slider(
-                     label="Height",
-                     minimum=256,
-                     maximum=MAX_IMAGE_SIZE,
-                     step=32,
-                     value=512,
-                 )
-
-             with gr.Row():
-
-                 guidance_scale = gr.Slider(
-                     label="Guidance scale",
-                     minimum=0.0,
-                     maximum=10.0,
-                     step=0.1,
-                     value=0.0,
-                 )
-
-                 num_inference_steps = gr.Slider(
-                     label="Number of inference steps",
-                     minimum=1,
-                     maximum=12,
-                     step=1,
-                     value=2,
-                 )
-
-         gr.Examples(
-             examples = examples,
-             inputs = [prompt]
-         )
-
-     run_button.click(
-         fn = infer,
-         inputs = [prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps],
-         outputs = [result]
-     )
-
- demo.queue().launch()
+ # Copyright 2024 Adobe. All rights reserved.
+
+ from run_magicfu import MagicFixup
+ import os
+ import pathlib
+ import torchvision
+ from torch import autocast
+ from PIL import Image
  import gradio as gr
  import numpy as np
+ import argparse
+
+
+ def sample(original_image, coarse_edit):
+     to_tensor = torchvision.transforms.ToTensor()
+     with autocast("cuda"):
+         w, h = coarse_edit.size
+         ref_image_t = to_tensor(original_image.resize((512,512))).half().cuda()
+         coarse_edit_t = to_tensor(coarse_edit.resize((512,512))).half().cuda()
+         # get mask from coarse
+         coarse_edit_mask_t = to_tensor(coarse_edit.resize((512,512))).half().cuda()
+         mask_t = (coarse_edit_mask_t[-1][None, None,...]).half() # do center crop
+         coarse_edit_t_rgb = coarse_edit_t[:-1]
+
+         out_rgb = magic_fixup.edit_image(ref_image_t, coarse_edit_t_rgb, mask_t, start_step=1.0, steps=50)
+         output = out_rgb.squeeze().cpu().detach().moveaxis(0, -1).float().numpy()
+         output = (output * 255.0).astype(np.uint8)
+         output_pil = Image.fromarray(output)
+         output_pil = output_pil.resize((w, h))
+         return output_pil
+
+ def file_exists(path):
+     """ Check if a file exists and is not a directory. """
+     if not os.path.isfile(path):
+         raise argparse.ArgumentTypeError(f"{path} is not a valid file.")
+     return path
+
+ def parse_arguments():
+     """ Parses command-line arguments. """
+     parser = argparse.ArgumentParser(description="Process images based on provided paths.")
+     parser.add_argument("--checkpoint", type=file_exists, required=True, help="Path to the MagicFixup checkpoint file.")
+
+     return parser.parse_args()
+
+ demo = gr.Interface(fn=sample, inputs=[gr.Image(type="pil", image_mode='RGB'), gr.Image(type="pil", image_mode='RGBA')], outputs=gr.Image(),
+                     examples='examples')
+
+ if __name__ == "__main__":
+     args = parse_arguments()
+
+     # create magic fixup model
+     magic_fixup = MagicFixup(model_path=args.checkpoint)
+     demo.launch(share=True)
configs/collage_composite_train.yaml ADDED
@@ -0,0 +1,114 @@
+ # Copyright 2024 Adobe. All rights reserved.
+ model:
+   base_learning_rate: 1.0e-05
+   target: ldm.models.diffusion.ddpm.LatentDiffusion
+   params:
+     linear_start: 0.00085
+     linear_end: 0.0120
+     num_timesteps_cond: 1
+     log_every_t: 200
+     timesteps: 1000
+     first_stage_key: "inpaint"
+     cond_stage_key: "image"
+     image_size: 64
+     channels: 4
+     cond_stage_trainable: true   # Note: different from the one we trained before
+     conditioning_key: "rewarp"
+     monitor: val/loss_simple_ema
+     u_cond_percent: 0.2
+     scale_factor: 0.18215
+     use_ema: False
+     context_embedding_dim: 768 # TODO embedding # 1024 clip, DINO: 'small': 384,'big': 768,'large': 1024,'huge': 1536
+
+
+     scheduler_config: # 10000 warmup steps
+       target: ldm.lr_scheduler.LambdaLinearScheduler
+       params:
+         warm_up_steps: [ 10000 ]
+         cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
+         f_start: [ 1.e-6 ]
+         f_max: [ 1. ]
+         f_min: [ 1. ]
+
+     unet_config:
+       target: ldm.modules.diffusionmodules.openaimodel.UNetModel
+       params:
+         image_size: 32 # unused
+         in_channels: 9
+         out_channels: 4
+         model_channels: 320
+         attention_resolutions: [ 4, 2, 1 ]
+         num_res_blocks: 2
+         channel_mult: [ 1, 2, 4, 4 ]
+         num_heads: 8
+         use_spatial_transformer: True
+         transformer_depth: 1
+         context_dim: 768
+         use_checkpoint: True
+         legacy: False
+         add_conv_in_front_of_unet: False
+
+     first_stage_config:
+       target: ldm.models.autoencoder.AutoencoderKL
+       params:
+         embed_dim: 4
+         monitor: val/rec_loss
+         ddconfig:
+           double_z: true
+           z_channels: 4
+           resolution: 256
+           in_channels: 3
+           out_ch: 3
+           ch: 128
+           ch_mult:
+           - 1
+           - 2
+           - 4
+           - 4
+           num_res_blocks: 2
+           attn_resolutions: []
+           dropout: 0.0
+         lossconfig:
+           target: torch.nn.Identity
+
+     cond_stage_config:
+       target: ldm.modules.encoders.modules.DINOEmbedder # TODO embedding
+       params:
+         dino_version: "big" # [small, big, large, huge]
+
+ data:
+   target: main.DataModuleFromConfig
+   params:
+     batch_size: 2
+     num_workers: 8
+     use_worker_init_fn: False
+     wrap: False
+     train:
+       target: ldm.data.collage_dataset.CollageDataset
+       params:
+         split_files: "<specify value train path>"
+         image_size: 512
+         embedding_type: 'dino' # TODO embedding
+         warping_type: 'collage'
+     validation:
+       target: ldm.data.collage_dataset.CollageDataset
+       params:
+         split_files: "<specify value val path>"
+         image_size: 512
+         embedding_type: 'dino' # TODO embedding
+         warping_type: 'mix'
+     test:
+       target: ldm.data.collage_dataset.CollageDataset
+       params:
+         split_files: "<specify value val path>"
+         image_size: 512
+         embedding_type: 'dino' # TODO embedding
+         warping_type: 'mix'
+
+ lightning:
+   trainer:
+     max_epochs: 500
+     num_nodes: 1
+     num_sanity_val_steps: 0
+     accelerator: 'gpu'
+     gpus: "0,1,2,3,4,5,6,7"
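For reference, configs like this one are read with OmegaConf (already used by `moments_processing.py` in this diff). The sketch below only shows loading and inspecting the file; the commented `instantiate_from_config` call is the usual pattern in LDM-style codebases and is an assumption here, since `main.py` is not shown in this view.

```python
# Hedged sketch: load the training config and check the paths you must edit.
from omegaconf import OmegaConf

config = OmegaConf.load("configs/collage_composite_train.yaml")
print(config.model.target)                          # ldm.models.diffusion.ddpm.LatentDiffusion
print(config.data.params.train.params.split_files)  # point this at your processed training data

# from ldm.util import instantiate_from_config      # assumed LDM-style helper, not shown in this diff
# model = instantiate_from_config(config.model)
```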
configs/collage_flow_train.yaml ADDED
@@ -0,0 +1,114 @@
+ # Copyright 2024 Adobe. All rights reserved.
+ model:
+   base_learning_rate: 1.0e-05
+   target: ldm.models.diffusion.ddpm.LatentDiffusion
+   params:
+     linear_start: 0.00085
+     linear_end: 0.0120
+     num_timesteps_cond: 1
+     log_every_t: 200
+     timesteps: 1000
+     first_stage_key: "inpaint"
+     cond_stage_key: "image"
+     image_size: 64
+     channels: 4
+     cond_stage_trainable: true   # Note: different from the one we trained before
+     conditioning_key: "rewarp"
+     monitor: val/loss_simple_ema
+     u_cond_percent: 0.2
+     scale_factor: 0.18215
+     use_ema: False
+     context_embedding_dim: 768 # TODO embedding # 1024 clip, DINO: 'small': 384,'big': 768,'large': 1024,'huge': 1536
+
+
+     scheduler_config: # 10000 warmup steps
+       target: ldm.lr_scheduler.LambdaLinearScheduler
+       params:
+         warm_up_steps: [ 10000 ]
+         cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
+         f_start: [ 1.e-6 ]
+         f_max: [ 1. ]
+         f_min: [ 1. ]
+
+     unet_config:
+       target: ldm.modules.diffusionmodules.openaimodel.UNetModel
+       params:
+         image_size: 32 # unused
+         in_channels: 9
+         out_channels: 4
+         model_channels: 320
+         attention_resolutions: [ 4, 2, 1 ]
+         num_res_blocks: 2
+         channel_mult: [ 1, 2, 4, 4 ]
+         num_heads: 8
+         use_spatial_transformer: True
+         transformer_depth: 1
+         context_dim: 768
+         use_checkpoint: True
+         legacy: False
+         add_conv_in_front_of_unet: False
+
+     first_stage_config:
+       target: ldm.models.autoencoder.AutoencoderKL
+       params:
+         embed_dim: 4
+         monitor: val/rec_loss
+         ddconfig:
+           double_z: true
+           z_channels: 4
+           resolution: 256
+           in_channels: 3
+           out_ch: 3
+           ch: 128
+           ch_mult:
+           - 1
+           - 2
+           - 4
+           - 4
+           num_res_blocks: 2
+           attn_resolutions: []
+           dropout: 0.0
+         lossconfig:
+           target: torch.nn.Identity
+
+     cond_stage_config:
+       target: ldm.modules.encoders.modules.DINOEmbedder # TODO embedding
+       params:
+         dino_version: "big" # [small, big, large, huge]
+
+ data:
+   target: main.DataModuleFromConfig
+   params:
+     batch_size: 2
+     num_workers: 8
+     use_worker_init_fn: False
+     wrap: False
+     train:
+       target: ldm.data.collage_dataset.CollageDataset
+       params:
+         split_files: /mnt/localssd/new_train
+         image_size: 512
+         embedding_type: 'dino' # TODO embedding
+         warping_type: 'flow'
+     validation:
+       target: ldm.data.collage_dataset.CollageDataset
+       params:
+         split_files: /mnt/localssd/new_val
+         image_size: 512
+         embedding_type: 'dino' # TODO embedding
+         warping_type: 'mix'
+     test:
+       target: ldm.data.collage_dataset.CollageDataset
+       params:
+         split_files: /mnt/localssd/new_val
+         image_size: 512
+         embedding_type: 'dino' # TODO embedding
+         warping_type: 'mix'
+
+ lightning:
+   trainer:
+     max_epochs: 500
+     num_nodes: 1
+     num_sanity_val_steps: 0
+     accelerator: 'gpu'
+     gpus: "0,1,2,3,4,5,6,7"
configs/collage_mix_train.yaml ADDED
@@ -0,0 +1,115 @@
+ # Copyright 2024 Adobe. All rights reserved.
+ model:
+   base_learning_rate: 1.0e-05
+   target: ldm.models.diffusion.ddpm.LatentDiffusion
+   params:
+     linear_start: 0.00085
+     linear_end: 0.0120
+     num_timesteps_cond: 1
+     log_every_t: 200
+     timesteps: 1000
+     first_stage_key: "inpaint"
+     cond_stage_key: "image"
+     image_size: 64
+     channels: 4
+     cond_stage_trainable: true   # Note: different from the one we trained before
+     conditioning_key: "rewarp"
+     monitor: val/loss_simple_ema
+     u_cond_percent: 0.2
+     scale_factor: 0.18215
+     use_ema: False
+     context_embedding_dim: 384 # TODO embedding # 1024 clip, DINO: 'small': 384,'big': 768,'large': 1024,'huge': 1536
+     dropping_warped_latent_prob: 0.2
+
+
+     scheduler_config: # 10000 warmup steps
+       target: ldm.lr_scheduler.LambdaLinearScheduler
+       params:
+         warm_up_steps: [ 10000 ]
+         cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
+         f_start: [ 1.e-6 ]
+         f_max: [ 1. ]
+         f_min: [ 1. ]
+
+     unet_config:
+       target: ldm.modules.diffusionmodules.openaimodel.UNetModel
+       params:
+         image_size: 32 # unused
+         in_channels: 9
+         out_channels: 4
+         model_channels: 320
+         attention_resolutions: [ 4, 2, 1 ]
+         num_res_blocks: 2
+         channel_mult: [ 1, 2, 4, 4 ]
+         num_heads: 8
+         use_spatial_transformer: True
+         transformer_depth: 1
+         context_dim: 768
+         use_checkpoint: True
+         legacy: False
+         add_conv_in_front_of_unet: False
+
+     first_stage_config:
+       target: ldm.models.autoencoder.AutoencoderKL
+       params:
+         embed_dim: 4
+         monitor: val/rec_loss
+         ddconfig:
+           double_z: true
+           z_channels: 4
+           resolution: 256
+           in_channels: 3
+           out_ch: 3
+           ch: 128
+           ch_mult:
+           - 1
+           - 2
+           - 4
+           - 4
+           num_res_blocks: 2
+           attn_resolutions: []
+           dropout: 0.0
+         lossconfig:
+           target: torch.nn.Identity
+
+     cond_stage_config:
+       target: ldm.modules.encoders.modules.DINOEmbedder # TODO embedding
+       params:
+         dino_version: "small" # [small, big, large, huge]
+
+ data:
+   target: main.DataModuleFromConfig
+   params:
+     batch_size: 4
+     num_workers: 8
+     use_worker_init_fn: False
+     wrap: False
+     train:
+       target: ldm.data.collage_dataset.CollageDataset
+       params:
+         split_files: /mnt/localssd/new_train
+         image_size: 512
+         embedding_type: 'dino' # TODO embedding
+         warping_type: 'mix'
+     validation:
+       target: ldm.data.collage_dataset.CollageDataset
+       params:
+         split_files: /mnt/localssd/new_val
+         image_size: 512
+         embedding_type: 'dino' # TODO embedding
+         warping_type: 'mix'
+     test:
+       target: ldm.data.collage_dataset.CollageDataset
+       params:
+         split_files: /mnt/localssd/new_val
+         image_size: 512
+         embedding_type: 'dino' # TODO embedding
+         warping_type: 'mix'
+
+ lightning:
+   trainer:
+     max_epochs: 500
+     num_nodes: 1
+     num_sanity_val_steps: 0
+     accelerator: 'gpu'
+     gpus: "0,1,2,3,4,5,6,7"
data_processing/example_videos/getty-soccer-ball-jordan-video-id473239807_26.mp4 ADDED
Binary file (180 kB).
 
data_processing/example_videos/getty-video-of-american-flags-being-sewn-together-at-flagsource-in-batavia-video-id804937470_87.mp4 ADDED
Binary file (427 kB).
 
data_processing/example_videos/giphy-fgiT2cbsTxl8k_0.mp4 ADDED
Binary file (95.2 kB).
 
data_processing/example_videos/giphy-gkvCpHRX9IqkM_3.mp4 ADDED
Binary file (102 kB).
 
data_processing/example_videos/yt--4Fx5XUD-9Y_345.mp4 ADDED
Binary file (686 kB).
 
data_processing/example_videos/yt-mNdvtOO7UqY_15.mp4 ADDED
Binary file (301 kB).
 
data_processing/moments_dataset.py ADDED
@@ -0,0 +1,54 @@
+ # Copyright 2024 Adobe. All rights reserved.
+
+ #%%
+ import glob
+ import torch
+ import torchvision
+ import matplotlib.pyplot as plt
+ from torch.utils.data import Dataset
+ import numpy as np
+
+
+ # %%
+ class MomentsDataset(Dataset):
+     def __init__(self, videos_folder, num_frames, samples_per_video, frame_size=512) -> None:
+         super().__init__()
+
+         self.videos_paths = glob.glob(f'{videos_folder}/*mp4')
+         self.resize = torchvision.transforms.Resize(size=frame_size)
+         self.center_crop = torchvision.transforms.CenterCrop(size=frame_size)
+         self.num_samples_per_video = samples_per_video
+         self.num_frames = num_frames
+
+     def __len__(self):
+         return len(self.videos_paths) * self.num_samples_per_video
+
+     def __getitem__(self, idx):
+         video_idx = idx // self.num_samples_per_video
+         video_path = self.videos_paths[video_idx]
+
+         try:
+             start_idx = np.random.randint(0, 20)
+
+             unsampled_video_frames, audio_frames, info = torchvision.io.read_video(video_path, output_format="TCHW")
+             sampled_indices = torch.tensor(np.linspace(start_idx, len(unsampled_video_frames)-1, self.num_frames).astype(int))
+             sampled_frames = unsampled_video_frames[sampled_indices]
+             processed_frames = []
+
+             for frame in sampled_frames:
+                 resized_cropped_frame = self.center_crop(self.resize(frame))
+                 processed_frames.append(resized_cropped_frame)
+             frames = torch.stack(processed_frames, dim=0)
+             frames = frames.float() / 255.0
+         except Exception as e:
+             print('oops', e)
+             rand_idx = np.random.randint(0, len(self))
+             return self.__getitem__(rand_idx)
+
+         out_dict = {'frames': frames,
+                     'caption': 'none',
+                     'keywords': 'none'}
+
+         return out_dict
+
+
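For context, this dataset is consumed later in this diff by `process_video_folder()` in `data_processing/moments_processing.py`; a condensed excerpt of that usage:

```python
# Condensed from moments_processing.py: each item is a dict with 'frames'
# (num_frames x 3 x 512 x 512, values in [0, 1]) plus 'caption' and 'keywords'.
from torch.utils.data import DataLoader
from moments_dataset import MomentsDataset

dataset = MomentsDataset(videos_folder='./example_videos', num_frames=20, samples_per_video=5)
dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

for batch in dataloader:
    frames = batch['frames']   # shape: (batch, 20, 3, 512, 512)
    break
```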
data_processing/moments_processing.py ADDED
@@ -0,0 +1,359 @@
1
+ # Copyright 2024 Adobe. All rights reserved.
2
+
3
+ #%%
4
+ from torchvision.transforms import ToPILImage
5
+ import torch
6
+ import pandas as pd
7
+ import matplotlib.pyplot as plt
8
+ import numpy as np
9
+ import torchvision
10
+ import cv2
11
+ import tqdm
12
+ import matplotlib.pyplot as plt
13
+ import torchvision.transforms.functional as F
14
+ from PIL import Image
15
+ from torchvision.utils import save_image
16
+ import time
17
+ import os
18
+ import sys
19
+ import pathlib
20
+ from torchvision.utils import flow_to_image
21
+ from torch.utils.data import DataLoader
22
+ from einops import rearrange
23
+ # %matplotlib inline
24
+ from kornia.filters.median import MedianBlur
25
+ median_filter = MedianBlur(kernel_size=(15,15))
26
+ from moments_dataset import MomentsDataset
27
+
28
+ try:
29
+ from processing_utils import aggregate_frames
30
+ import processing_utils
31
+ except Exception as e:
32
+ print(e)
33
+ print('process failed')
34
+ exit()
35
+
36
+
37
+
38
+
39
+ import pytorch_lightning as pl
40
+ import torch
41
+ from omegaconf import OmegaConf
42
+
43
+ # %%
44
+
45
+ def load_image(img_path, resize_size=None,crop_size=None):
46
+
47
+ img1_pil = Image.open(img_path)
48
+ img1_frames = torchvision.transforms.functional.pil_to_tensor(img1_pil)
49
+
50
+ if resize_size:
51
+ img1_frames = torchvision.transforms.functional.resize(img1_frames, resize_size)
52
+
53
+ if crop_size:
54
+ img1_frames = torchvision.transforms.functional.center_crop(img1_frames, crop_size)
55
+
56
+ img1_batch = torch.unsqueeze(img1_frames, dim=0)
57
+
58
+ return img1_batch
59
+
60
+ def get_grid(size):
61
+ y = np.repeat(np.arange(size)[None, ...], size)
62
+ y = y.reshape(size, size)
63
+ x = y.transpose()
64
+ out = np.stack([y,x], -1)
65
+ return out
66
+
67
+ def collage_from_frames(frames_t):
68
+ # decide forward or backward
69
+ if np.random.randint(0, 2) == 0:
70
+ # flip
71
+ frames_t = frames_t.flip(0)
72
+
73
+ # decide how deep you would go
74
+ tgt_idx_guess = np.random.randint(1, min(len(frames_t), 20))
75
+ tgt_idx = 1
76
+ pairwise_flows = []
77
+ flow = None
78
+ init_time = time.time()
79
+ unsmoothed_agg = None
80
+ for cur_idx in range(1, tgt_idx_guess+1):
81
+ # cur_idx = i+1
82
+ cur_flow, pairwise_flows = aggregate_frames(frames_t[:cur_idx+1] , pairwise_flows, unsmoothed_agg) # passing pairwise flows for efficiency
83
+ unsmoothed_agg = cur_flow.clone()
84
+ agg_cur_flow = median_filter(cur_flow)
85
+
86
+ flow_norm = torch.norm(agg_cur_flow.squeeze(), dim=0).flatten()
87
+ # flow_10 = np.percentile(flow_norm.cpu().numpy(), 10)
88
+ flow_90 = np.percentile(flow_norm.cpu().numpy(), 90)
89
+
90
+ # flow_10 = np.percentile(flow_norm.cpu().numpy(), 10)
91
+ flow_90 = np.percentile(flow_norm.cpu().numpy(), 90)
92
+ flow_95 = np.percentile(flow_norm.cpu().numpy(), 95)
93
+
94
+ if cur_idx == 5: # if still small flow then drop
95
+ if flow_95 < 20.0:
96
+ # no motion in the frame. skip
97
+ print('flow is tiny :(')
98
+ return None
99
+
100
+ if cur_idx == tgt_idx_guess-1: # if still small flow then drop
101
+ if flow_95 < 50.0:
102
+ # no motion in the frame. skip
103
+ print('flow is tiny :(')
104
+ return None
105
+
106
+ if flow is None: # means first iter
107
+ if flow_90 < 1.0:
108
+ # no motion in the frame. skip
109
+ return None
110
+ flow = agg_cur_flow
111
+
112
+ if flow_90 <= 300: # maybe should increase this part
113
+ # update idx
114
+ tgt_idx = cur_idx
115
+ flow = agg_cur_flow
116
+ else:
117
+ break
118
+ final_time = time.time()
119
+ print('time guessing idx', final_time - init_time)
120
+
121
+ _, flow_warping_mask = processing_utils.forward_warp(frames_t[0], frames_t[tgt_idx], flow, grid=None, alpha_mask=None)
122
+ flow_warping_mask = flow_warping_mask.squeeze().numpy() > 0.5
123
+
124
+ if np.mean(flow_warping_mask) < 0.6:
125
+ return
126
+
127
+
128
+ src_array = frames_t[0].moveaxis(0, -1).cpu().numpy() * 1.0
129
+ init_time = time.time()
130
+ depth = get_depth_from_array(frames_t[0])
131
+ finish_time = time.time()
132
+ print('time getting depth', finish_time - init_time)
133
+ # flow, pairwise_flows = aggregate_frames(frames_t)
134
+ # agg_flow = median_filter(flow)
135
+
136
+ src_array_uint = src_array * 255.0
137
+ src_array_uint = src_array_uint.astype(np.uint8)
138
+ segments = processing_utils.mask_generator.generate(src_array_uint)
139
+
140
+ size = src_array.shape[1]
141
+ grid_np = get_grid(size).astype(np.float16) / size # 512 x 512 x 2get
142
+ grid_t = torch.tensor(grid_np).moveaxis(-1, 0) # 512 x 512 x 2
143
+
144
+
145
+ collage, canvas_alpha, lost_alpha = collage_warp(src_array, flow.squeeze(), depth, segments, grid_array=grid_np)
146
+ lost_alpha_t = torch.tensor(lost_alpha).squeeze().unsqueeze(0)
147
+ warping_alpha = (lost_alpha_t < 0.5).float()
148
+
149
+ rgb_grid_splatted, actual_warped_mask = processing_utils.forward_warp(frames_t[0], frames_t[tgt_idx], flow, grid=grid_t, alpha_mask=warping_alpha)
150
+
151
+
152
+ # basic blending now
153
+ # print('rgb grid splatted', rgb_grid_splatted.shape)
154
+ warped_src = (rgb_grid_splatted * actual_warped_mask).moveaxis(0, -1).cpu().numpy()
155
+ canvas_alpha_mask = canvas_alpha == 0.0
156
+ collage_mask = canvas_alpha.squeeze() + actual_warped_mask.squeeze().cpu().numpy()
157
+ collage_mask = collage_mask > 0.5
158
+
159
+ composite_grid = warped_src * canvas_alpha_mask + collage
160
+ rgb_grid_splatted_np = rgb_grid_splatted.moveaxis(0, -1).cpu().numpy()
161
+
162
+ return frames_t[0], frames_t[tgt_idx], rgb_grid_splatted_np, composite_grid, flow_warping_mask, collage_mask
163
+
164
+ def collage_warp(rgb_array, flow, depth, segments, grid_array):
165
+ avg_depths = []
166
+ avg_flows = []
167
+
168
+ # src_array = src_array.moveaxis(-1, 0).cpu().numpy() #np.array(Image.open(src_path).convert('RGB')) / 255.0
169
+ src_array = np.concatenate([rgb_array, grid_array], axis=-1)
170
+ canvas = np.zeros_like(src_array)
171
+ canvas_alpha = np.zeros_like(canvas[...,-1:]).astype(float)
172
+ lost_regions = np.zeros_like(canvas[...,-1:]).astype(float)
173
+ z_buffer = np.ones_like(depth)[..., None] * -1.0
174
+ unsqueezed_depth = depth[..., None]
175
+
176
+ affine_transforms = []
177
+
178
+ filtered_segments = []
179
+ for segment in segments:
180
+ if segment['area'] > 300:
181
+ filtered_segments.append(segment)
182
+
183
+ for segment in filtered_segments:
184
+ seg_mask = segment['segmentation']
185
+ avg_flow = torch.mean(flow[:, seg_mask],dim=1)
186
+ avg_flows.append(avg_flow)
187
+ # median depth (conversion from disparity)
188
+ avg_depth = torch.median(1.0 / (depth[seg_mask] + 1e-6))
189
+ avg_depths.append(avg_depth)
190
+
191
+ all_y, all_x = np.nonzero(segment['segmentation'])
192
+ rand_indices = np.random.randint(0, len(all_y), size=50)
193
+ rand_x = [all_x[i] for i in rand_indices]
194
+ rand_y = [all_y[i] for i in rand_indices]
195
+
196
+ src_pairs = [(x, y) for x, y in zip(rand_x, rand_y)]
197
+ # tgt_pairs = [(x + w, y) for x, y in src_pairs]
198
+ tgt_pairs = []
199
+ # print('estimating affine') # TODO this can be faster
200
+ for i in range(len(src_pairs)):
201
+ x, y = src_pairs[i]
202
+ dx, dy = flow[:, y, x]
203
+ tgt_pairs.append((x+dx, y+dy))
204
+
205
+ # affine_trans, inliers = cv2.estimateAffine2D(np.array(src_pairs).astype(np.float32), np.array(tgt_pairs).astype(np.float32))
206
+ affine_trans, inliers = cv2.estimateAffinePartial2D(np.array(src_pairs).astype(np.float32), np.array(tgt_pairs).astype(np.float32))
207
+ # print('num inliers', np.sum(inliers))
208
+ # # print('num inliers', np.sum(inliers))
209
+ affine_transforms.append(affine_trans)
210
+
211
+ depth_sorted_indices = np.arange(len(avg_depths))
212
+ depth_sorted_indices = sorted(depth_sorted_indices, key=lambda x: avg_depths[x])
213
+ # sorted_masks = []
214
+ # print('warping stuff')
215
+ for idx in depth_sorted_indices:
216
+ # sorted_masks.append(mask[idx])
217
+ alpha_mask = filtered_segments[idx]['segmentation'][..., None] * (lost_regions < 0.5).astype(float)
218
+ src_rgba = np.concatenate([src_array, alpha_mask, unsqueezed_depth], axis=-1)
219
+ warp_dst = cv2.warpAffine(src_rgba, affine_transforms[idx], (src_array.shape[1], src_array.shape[0]))
220
+ warped_mask = warp_dst[..., -2:-1] # this is warped alpha
221
+ warped_depth = warp_dst[..., -1:]
222
+ warped_rgb = warp_dst[...,:-2]
223
+
224
+ good_z_region = warped_depth > z_buffer
225
+
226
+ warped_mask = np.logical_and(warped_mask > 0.5, good_z_region).astype(float)
227
+
228
+ kernel = np.ones((3,3), float)
229
+ # print('og masked shape', warped_mask.shape)
230
+ # warped_mask = cv2.erode(warped_mask,(5,5))[..., None]
231
+ # print('eroded masked shape', warped_mask.shape)
232
+ canvas_alpha += cv2.erode(warped_mask,kernel)[..., None]
233
+
234
+ lost_regions += alpha_mask
235
+ canvas = canvas * (1.0 - warped_mask) + warped_mask * warped_rgb # TODO check if need to dialate here
236
+ z_buffer = z_buffer * (1.0 - warped_mask) + warped_mask * warped_depth # TODO check if need to dialate here # print('max lost region', np.max(lost_regions))
237
+ return canvas, canvas_alpha, lost_regions
238
+
239
+ def get_depth_from_array(img_t):
240
+ img_arr = img_t.moveaxis(0, -1).cpu().numpy() * 1.0
241
+ # print(img_arr.shape)
242
+ img_arr *= 255.0
243
+ img_arr = img_arr.astype(np.uint8)
244
+ input_batch = processing_utils.depth_transform(img_arr).cuda()
245
+
246
+ with torch.no_grad():
247
+ prediction = processing_utils.midas(input_batch)
248
+
249
+ prediction = torch.nn.functional.interpolate(
250
+ prediction.unsqueeze(1),
251
+ size=img_arr.shape[:2],
252
+ mode="bicubic",
253
+ align_corners=False,
254
+ ).squeeze()
255
+
256
+ output = prediction.cpu()
257
+ return output
258
+
259
+
260
+ # %%
261
+
262
+ def main():
263
+ print('starting main')
264
+ video_folder = './example_videos'
265
+ save_dir = pathlib.Path('./processed_data')
266
+ process_video_folder(video_folder, save_dir)
267
+
268
+ def process_video_folder(video_folder, save_dir):
269
+ all_counter = 0
270
+ success_counter = 0
271
+
272
+ # save_folder = pathlib.Path('/dev/shm/processed')
273
+ # save_dir = save_folder / foldername #pathlib.Path('/sensei-fs/users/halzayer/collage2photo/testing_partitioning_dilate_extreme')
274
+ os.makedirs(save_dir, exist_ok=True)
275
+
276
+ dataset = MomentsDataset(videos_folder=video_folder, num_frames=20, samples_per_video=5)
277
+ batch_size = 4
278
+ dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
279
+
280
+ with torch.no_grad():
281
+ for i, batch in tqdm.tqdm(enumerate(dataloader), total=len(dataset)//batch_size):
282
+ frames_to_visualize = batch["frames"]
283
+ bs = frames_to_visualize.shape[0]
284
+
285
+ for j in range(bs):
286
+ frames = frames_to_visualize[j]
287
+ caption = batch["caption"][j]
288
+
289
+ collage_init_time = time.time()
290
+ out = collage_from_frames(frames)
291
+ collage_finish_time = time.time()
292
+ print('collage processing time', collage_finish_time - collage_init_time)
293
+ all_counter += 1
294
+ if out is not None:
295
+ src_image, tgt_image, splatted, collage, flow_mask, collage_mask = out
296
+
297
+ splatted_rgb = splatted[...,:3]
298
+ splatted_grid = splatted[...,3:].astype(np.float16)
299
+
300
+ collage_rgb = collage[...,:3]
301
+ collage_grid = collage[...,3:].astype(np.float16)
302
+ success_counter += 1
303
+ else:
304
+ continue
305
+
306
+ id_str = f'{success_counter:08d}'
307
+
308
+ src_path = str(save_dir / f'src_{id_str}.png')
309
+ tgt_path = str(save_dir / f'tgt_{id_str}.png')
310
+ flow_warped_path = str(save_dir / f'flow_warped_{id_str}.png')
311
+ composite_path = str(save_dir / f'composite_{id_str}.png')
312
+ flow_mask_path = str(save_dir / f'flow_mask_{id_str}.png')
313
+ composite_mask_path = str(save_dir / f'composite_mask_{id_str}.png')
314
+
315
+ flow_grid_path = str(save_dir / f'flow_warped_grid_{id_str}.npy')
316
+ composite_grid_path = str(save_dir / f'composite_grid_{id_str}.npy')
317
+
318
+ save_image(src_image, src_path)
319
+ save_image(tgt_image, tgt_path)
320
+
321
+ collage_pil = Image.fromarray((collage_rgb * 255).astype(np.uint8))
322
+ collage_pil.save(composite_path)
323
+
324
+ splatted_pil = Image.fromarray((splatted_rgb * 255).astype(np.uint8))
325
+ splatted_pil.save(flow_warped_path)
326
+
327
+ flow_mask_pil = Image.fromarray((flow_mask.astype(float) * 255).astype(np.uint8))
328
+ flow_mask_pil.save(flow_mask_path)
329
+
330
+ composite_mask_pil = Image.fromarray((collage_mask.astype(float) * 255).astype(np.uint8))
331
+ composite_mask_pil.save(composite_mask_path)
332
+
333
+ splatted_grid_t = torch.tensor(splatted_grid).moveaxis(-1, 0)
334
+ splatted_grid_resized = torchvision.transforms.functional.resize(splatted_grid_t, (64,64))
335
+
336
+ collage_grid_t = torch.tensor(collage_grid).moveaxis(-1, 0)
337
+ collage_grid_resized = torchvision.transforms.functional.resize(collage_grid_t, (64,64))
338
+ np.save(flow_grid_path, splatted_grid_resized.cpu().numpy())
339
+ np.save(composite_grid_path, collage_grid_resized.cpu().numpy())
340
+
341
+
342
+ del out
343
+ del splatted_grid
344
+ del collage_grid
345
+ del frames
346
+
347
+ del frames_to_visualize
348
+
349
+
350
+
351
+ #%%
352
+
353
+ if __name__ == '__main__':
354
+ try:
355
+ main()
356
+ except Exception as e:
357
+ print(e)
358
+ print('process failed')
359
+
data_processing/processing_utils.py ADDED
@@ -0,0 +1,304 @@
+ import torch
+ import cv2
+ import numpy as np
+ import sys
+ import torchvision
+ from PIL import Image
+ from torchvision.models.optical_flow import Raft_Large_Weights
+ from torchvision.models.optical_flow import raft_large
+ from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
+ import matplotlib.pyplot as plt
+ import torchvision.transforms.functional as F
+ sys.path.append('./softmax-splatting')
+ import softsplat
+
+
+ sam_checkpoint = "./sam_model/sam_vit_h_4b8939.pth"
+ model_type = "vit_h"
+
+ device = "cuda"
+
+ sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
+ sam.to(device=device)
+ # mask_generator = SamAutomaticMaskGenerator(sam,
+ #                                            crop_overlap_ratio=0.05,
+ #                                            box_nms_thresh=0.2,
+ #                                            points_per_side=32,
+ #                                            pred_iou_thresh=0.86,
+ #                                            stability_score_thresh=0.8,
+
+ #                                            min_mask_region_area=100,)
+ # mask_generator = SamAutomaticMaskGenerator(sam)
+ mask_generator = SamAutomaticMaskGenerator(sam,
+                                            # box_nms_thresh=0.5,
+                                            # crop_overlap_ratio=0.75,
+                                            # min_mask_region_area=200,
+                                            )
+
+ def get_mask(img_path):
+     image = cv2.imread(img_path)
+     image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+     masks = mask_generator.generate(image)
+     return masks
+
+ def get_mask_from_array(arr):
+     return mask_generator.generate(arr)
+
+ # depth model
+
+ import cv2
+ import torch
+ import urllib.request
+
+ import matplotlib.pyplot as plt
+
+ # potentially downgrade this. just need rough depths. benchmark this
+ model_type = "DPT_Large"     # MiDaS v3 - Large (highest accuracy, slowest inference speed)
+ #model_type = "DPT_Hybrid"   # MiDaS v3 - Hybrid (medium accuracy, medium inference speed)
+ #model_type = "MiDaS_small"  # MiDaS v2.1 - Small (lowest accuracy, highest inference speed)
+
+ # midas = torch.hub.load("intel-isl/MiDaS", model_type)
+ midas = torch.hub.load("/sensei-fs/users/halzayer/collage2photo/model_cache/intel-isl_MiDaS_master", model_type, source='local')
+
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+ midas.to(device)
+ midas.eval()
+
+ # midas_transforms = torch.hub.load("intel-isl/MiDaS", "transforms")
+ midas_transforms = torch.hub.load("/sensei-fs/users/halzayer/collage2photo/model_cache/intel-isl_MiDaS_master", "transforms", source='local')
+
+ if model_type == "DPT_Large" or model_type == "DPT_Hybrid":
+     depth_transform = midas_transforms.dpt_transform
+ else:
+     depth_transform = midas_transforms.small_transform
+
+ # img_path = '/sensei-fs/users/halzayer/valid/JPEGImages/45597680/00005.jpg'
+ def get_depth(img_path):
+     img = cv2.imread(img_path)
+     img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+
+     input_batch = depth_transform(img).to(device)
+
+     with torch.no_grad():
+         prediction = midas(input_batch)
+
+         prediction = torch.nn.functional.interpolate(
+             prediction.unsqueeze(1),
+             size=img.shape[:2],
+             mode="bicubic",
+             align_corners=False,
+         ).squeeze()
+
+     output = prediction.cpu()
+     return output
+
+ def get_depth_from_array(img):
+     input_batch = depth_transform(img).to(device)
+
+     with torch.no_grad():
+         prediction = midas(input_batch)
+
+         prediction = torch.nn.functional.interpolate(
+             prediction.unsqueeze(1),
+             size=img.shape[:2],
+             mode="bicubic",
+             align_corners=False,
+         ).squeeze()
+
+     output = prediction.cpu()
+     return output
+
+
+ def load_image(img_path):
+     img1_names = [img_path]
+
+     img1_pil = [Image.open(fn) for fn in img1_names]
+     img1_frames = [torchvision.transforms.functional.pil_to_tensor(fn) for fn in img1_pil]
+
+     img1_batch = torch.stack(img1_frames)
+
+     return img1_batch
+
+ weights = Raft_Large_Weights.DEFAULT
+ transforms = weights.transforms()
+
125
+ device = "cuda" if torch.cuda.is_available() else "cpu"
126
+
127
+ model = raft_large(weights=Raft_Large_Weights.DEFAULT, progress=False).to(device)
128
+ model = model.eval()
129
+
130
+ print('created model')
131
+
132
+ def preprocess(img1_batch, img2_batch, size=[520,960], transform_batch=True):
133
+ img1_batch = F.resize(img1_batch, size=size, antialias=False)
134
+ img2_batch = F.resize(img2_batch, size=size, antialias=False)
135
+ if transform_batch:
136
+ return transforms(img1_batch, img2_batch)
137
+ else:
138
+ return img1_batch, img2_batch
139
+
140
+ def compute_flow(img_path_1, img_path_2):
141
+ img1_batch_og, img2_batch_og = load_image(img_path_1), load_image(img_path_2)
142
+ B, C, H, W = img1_batch_og.shape
143
+
144
+ img1_batch, img2_batch = preprocess(img1_batch_og, img2_batch_og, transform_batch=False)
145
+ img1_batch_t, img2_batch_t = transforms(img1_batch, img2_batch)
146
+
147
+ # If you can, run this example on a GPU, it will be a lot faster.
148
+ with torch.no_grad():
149
+ list_of_flows = model(img1_batch_t.to(device), img2_batch_t.to(device))
150
+ predicted_flows = list_of_flows[-1]
151
+ # flows.append(predicted_flows)
152
+
153
+ resized_flow = F.resize(predicted_flows, size=(H, W), antialias=False)
154
+
155
+ _, _, flow_H, flow_W = predicted_flows.shape
156
+
157
+ resized_flow[:,0] *= (W / flow_W)
158
+ resized_flow[:,1] *= (H / flow_H)
159
+
160
+ return resized_flow.detach().cpu().squeeze()
161
+
162
+ def compute_flow_from_tensors(img1_batch_og, img2_batch_og):
163
+ if len(img1_batch_og.shape) < 4:
164
+ img1_batch_og = img1_batch_og.unsqueeze(0)
165
+ if len(img2_batch_og.shape) < 4:
166
+ img2_batch_og = img2_batch_og.unsqueeze(0)
167
+
168
+ B, C, H, W = img1_batch_og.shape
169
+ img1_batch, img2_batch = preprocess(img1_batch_og, img2_batch_og, transform_batch=False)
170
+ img1_batch_t, img2_batch_t = transforms(img1_batch, img2_batch)
171
+
172
+ # If you can, run this example on a GPU, it will be a lot faster.
173
+ with torch.no_grad():
174
+ list_of_flows = model(img1_batch_t.to(device), img2_batch_t.to(device))
175
+ predicted_flows = list_of_flows[-1]
176
+ # flows.append(predicted_flows)
177
+
178
+ resized_flow = F.resize(predicted_flows, size=(H, W), antialias=False)
179
+
180
+ _, _, flow_H, flow_W = predicted_flows.shape
181
+
182
+ resized_flow[:,0] *= (W / flow_W)
183
+ resized_flow[:,1] *= (H / flow_H)
184
+
185
+ return resized_flow.detach().cpu().squeeze()
186
+
187
+
188
+
189
+ # import run
190
+ backwarp_tenGrid = {}
191
+
192
+ def backwarp(tenIn, tenFlow):
193
+ if str(tenFlow.shape) not in backwarp_tenGrid:
194
+ tenHor = torch.linspace(start=-1.0, end=1.0, steps=tenFlow.shape[3], dtype=tenFlow.dtype, device=tenFlow.device).view(1, 1, 1, -1).repeat(1, 1, tenFlow.shape[2], 1)
195
+ tenVer = torch.linspace(start=-1.0, end=1.0, steps=tenFlow.shape[2], dtype=tenFlow.dtype, device=tenFlow.device).view(1, 1, -1, 1).repeat(1, 1, 1, tenFlow.shape[3])
196
+
197
+ backwarp_tenGrid[str(tenFlow.shape)] = torch.cat([tenHor, tenVer], 1).cuda()
198
+ # end
199
+
200
+ tenFlow = torch.cat([tenFlow[:, 0:1, :, :] / ((tenIn.shape[3] - 1.0) / 2.0), tenFlow[:, 1:2, :, :] / ((tenIn.shape[2] - 1.0) / 2.0)], 1)
201
+
202
+ return torch.nn.functional.grid_sample(input=tenIn, grid=(backwarp_tenGrid[str(tenFlow.shape)] + tenFlow).permute(0, 2, 3, 1), mode='bilinear', padding_mode='zeros', align_corners=True)
203
+
204
+ torch.backends.cudnn.enabled = True # make sure to use cudnn for computational performance
205
+
206
+ ##########################################################
207
+ def forward_splt(src, tgt, flow, partial=False):
208
+ tenTwo = tgt.unsqueeze(0).cuda() #torch.FloatTensor(numpy.ascontiguousarray(cv2.imread(filename='./images/one.png', flags=-1).transpose(2, 0, 1)[None, :, :, :].astype(numpy.float32) * (1.0 / 255.0))).cuda()
209
+ tenOne = src.unsqueeze(0).cuda() #torch.FloatTensor(numpy.ascontiguousarray(cv2.imread(filename='./images/two.png', flags=-1).transpose(2, 0, 1)[None, :, :, :].astype(numpy.float32) * (1.0 / 255.0))).cuda()
210
+ tenFlow = flow.unsqueeze(0).cuda() #torch.FloatTensor(numpy.ascontiguousarray(run.read_flo('./images/flow.flo').transpose(2, 0, 1)[None, :, :, :])).cuda()
211
+
212
+ if not partial:
213
+ tenMetric = torch.nn.functional.l1_loss(input=tenOne, target=backwarp(tenIn=tenTwo, tenFlow=tenFlow), reduction='none').mean([1], True)
214
+ else:
215
+ tenMetric = torch.nn.functional.l1_loss(input=tenOne[:,:3], target=backwarp(tenIn=tenTwo[:,:3], tenFlow=tenFlow[:,:3]), reduction='none').mean([1], True)
216
+ # for intTime, fltTime in enumerate(np.linspace(0.0, 1.0, 11).tolist()):
217
+ tenSoftmax = softsplat.softsplat(tenIn=tenOne, tenFlow=tenFlow , tenMetric=(-20.0 * tenMetric).clip(-20.0, 20.0), strMode='soft') # -20.0 is a hyperparameter, called 'alpha' in the paper, that could be learned using a torch.Parameter
218
+
219
+ return tenSoftmax.cpu()
220
+
221
+
222
+ def aggregate_frames(frames, pairwise_flows=None, agg_flow=None):
223
+ if pairwise_flows is None:
224
+ # store pairwise flows
225
+ pairwise_flows = []
226
+
227
+ if agg_flow is None:
228
+ start_idx = 0
229
+ else:
230
+ start_idx = len(pairwise_flows)
231
+
232
+ og_image = frames[start_idx]
233
+ prev_frame = og_image
234
+
235
+ for i in range(start_idx, len(frames)-1):
236
+ tgt_frame = frames[i+1]
237
+
238
+ if i < len(pairwise_flows):
239
+ flow = pairwise_flows[i]
240
+ else:
241
+ flow = compute_flow_from_tensors(prev_frame, tgt_frame)
242
+ pairwise_flows.append(flow.clone())
243
+
244
+ _, H, W = flow.shape
245
+ B=1
246
+
247
+ xx = torch.arange(0, W).view(1,-1).repeat(H,1)
248
+
249
+ yy = torch.arange(0, H).view(-1,1).repeat(1,W)
250
+
251
+ xx = xx.view(1,1,H,W).repeat(B,1,1,1)
252
+
253
+ yy = yy.view(1,1,H,W).repeat(B,1,1,1)
254
+
255
+ grid = torch.cat((xx,yy),1).float()
256
+
257
+ flow = flow.unsqueeze(0)
258
+ if agg_flow is None:
259
+ agg_flow = torch.zeros_like(flow)
260
+
261
+ vgrid = grid + agg_flow
262
+ vgrid[:,0,:,:] = 2.0*vgrid[:,0,:,:].clone() / max(W-1,1) - 1
263
+
264
+ vgrid[:,1,:,:] = 2.0*vgrid[:,1,:,:].clone() / max(H-1,1) - 1
265
+
266
+ flow_out = torch.nn.functional.grid_sample(flow, vgrid.permute(0,2,3,1), 'nearest')
267
+
268
+ agg_flow += flow_out
269
+
270
+
271
+ # mask = forward_splt(torch.ones_like(og_image), torch.ones_like(og_image), agg_flow.squeeze()).squeeze()
272
+ # blur_t = torchvision.transforms.GaussianBlur(kernel_size=(25,25), sigma=5.0)
273
+ # warping_mask = (blur_t(mask)[0:1] > 0.8)
274
+ # masks.append(warping_mask)
275
+ prev_frame = tgt_frame
276
+
277
+ return agg_flow, pairwise_flows #og_splatted_img, agg_flow, actual_warped_mask
278
+
279
+
280
+ def forward_warp(src_frame, tgt_frame, flow, grid=None, alpha_mask=None):
281
+ if alpha_mask is None:
282
+ alpha_mask = torch.ones_like(src_frame[:1])
283
+
284
+ if grid is not None:
285
+ src_list = [src_frame, grid, alpha_mask]
286
+ tgt_list = [tgt_frame, grid, alpha_mask]
287
+ else:
288
+ src_list = [src_frame, alpha_mask]
289
+ tgt_list = [tgt_frame, alpha_mask]
290
+
291
+ og_image_padded = torch.concat(src_list, dim=0)
292
+ tgt_frame_padded = torch.concat(tgt_list, dim=0)
293
+
294
+ og_splatted_img = forward_splt(og_image_padded, tgt_frame_padded, flow.squeeze(), partial=True).squeeze()
295
+ # print('og splatted image shape')
296
+ # grid_transformed = og_splatted_img[3:-1]
297
+ # print('grid transformed shape', grid_transformed)
298
+
299
+ # grid *= grid_size
300
+ # grid_transformed *= grid_size
301
+ actual_warped_mask = og_splatted_img[-1:]
302
+ splatted_rgb_grid = og_splatted_img[:-1]
303
+
304
+ return splatted_rgb_grid, actual_warped_mask
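(The helpers above compose roughly as follows — a minimal sketch, assuming the SAM/MiDaS/RAFT weights loaded at import time are available and a CUDA device is present; the frame filenames are placeholders.)

    import torchvision.transforms.functional as TF
    from PIL import Image

    src = TF.pil_to_tensor(Image.open('frame_000.png').convert('RGB')).float() / 255.0  # (3, H, W) in [0, 1]
    tgt = TF.pil_to_tensor(Image.open('frame_010.png').convert('RGB')).float() / 255.0

    flow = compute_flow('frame_000.png', 'frame_010.png')    # (2, H, W) forward flow, rescaled to full resolution
    warped, coverage = forward_warp(src, tgt, flow)           # softmax-splatted source and its coverage mask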
environment.yaml ADDED
@@ -0,0 +1,33 @@
1
+ name: MagicFixup
2
+ channels:
3
+ - pytorch
4
+ - defaults
5
+ dependencies:
6
+ - python=3.8.5
7
+ - pip=20.3
8
+ - cudatoolkit=11.3
9
+ - pytorch=1.11.0
10
+ - torchvision=0.12.0
11
+ - numpy=1.19.2
12
+ - pip:
13
+ - albumentations==0.4.3
14
+ - diffusers
15
+ - bezier
16
+ - gradio
17
+ - opencv-python==4.1.2.30
18
+ - pudb==2019.2
19
+ - invisible-watermark
20
+ - imageio==2.9.0
21
+ - imageio-ffmpeg==0.4.2
22
+ - pytorch-lightning==2.0.0
23
+ - omegaconf==2.1.1
24
+ - test-tube>=0.7.5
25
+ - streamlit>=0.73.1
26
+ - einops==0.3.0
27
+ - torch-fidelity==0.3.0
28
+ - transformers==4.19.2
29
+ - torchmetrics==0.7.0
30
+ - kornia==0.6
31
+ - -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
32
+ - -e git+https://github.com/openai/CLIP.git@main#egg=clip
33
+ - -e .
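(For reference, this environment is typically created with "conda env create -f environment.yaml" and activated with "conda activate MagicFixup", matching the name declared above; the two editable git dependencies require git to be available.)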
examples/dog_beach__edit__003.png ADDED
examples/dog_beach_og.png ADDED
examples/fox_drinking__edit__01.png ADDED
examples/fox_drinking__edit__02.png ADDED
examples/fox_drinking_og.png ADDED
examples/kingfisher__edit__001.png ADDED
examples/kingfisher_og.png ADDED
examples/log.csv ADDED
@@ -0,0 +1,6 @@
1
+ fox_drinking_og.png,fox_drinking__edit__01.png
2
+ palm_tree_og.png,palm_tree__edit__01.png
3
+ kingfisher_og.png,kingfisher__edit__001.png
4
+ pipes_og.png,pipes__edit__01.png
5
+ dog_beach_og.png,dog_beach__edit__003.png
6
+ fox_drinking_og.png,fox_drinking__edit__02.png
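(Each row above pairs an original photo with one coarse edit of it, presumably driving the example gallery in app.py. A minimal sketch of reading the pairs — illustrative only, not necessarily how the app consumes them:)

    import csv

    with open('examples/log.csv') as f:
        pairs = list(csv.reader(f))        # e.g. ['fox_drinking_og.png', 'fox_drinking__edit__01.png']
    for original, edited in pairs:
        print(original, '->', edited)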
examples/palm_tree__edit__01.png ADDED
examples/palm_tree_og.png ADDED
examples/pipes__edit__01.png ADDED
examples/pipes_og.png ADDED
ku.py DELETED
@@ -1 +0,0 @@
1
- jsj
 
 
ldm/data/__init__.py ADDED
File without changes
ldm/data/collage_dataset.py ADDED
@@ -0,0 +1,230 @@
1
+ # Copyright 2024 Adobe. All rights reserved.
2
+
3
+ import numpy as np
4
+ import torch
5
+ import matplotlib.pyplot as plt
6
+ import torchvision.transforms.functional as F
7
+ import glob
8
+ import torchvision
9
+ from PIL import Image
10
+ import time
11
+ import os
12
+ import tqdm
13
+ from torch.utils.data import Dataset
14
+ import pathlib
15
+ import cv2
16
+ from PIL import Image
17
+ import os
18
+ import json
19
+ import albumentations as A
20
+
21
+ def get_tensor(normalize=True, toTensor=True):
22
+ transform_list = []
23
+ if toTensor:
24
+ transform_list += [torchvision.transforms.ToTensor()]
25
+
26
+ if normalize:
27
+ # transform_list += [torchvision.transforms.Normalize((0.0, 0.0, 0.0),
28
+ # (10.0, 10.0, 10.0))]
29
+ transform_list += [torchvision.transforms.Normalize((0.5, 0.5, 0.5),
30
+ (0.5, 0.5, 0.5))]
31
+ return torchvision.transforms.Compose(transform_list)
32
+
33
+ def get_tensor_clip(normalize=True, toTensor=True):
34
+ transform_list = [torchvision.transforms.Resize((224,224))]
35
+ if toTensor:
36
+ transform_list += [torchvision.transforms.ToTensor()]
37
+
38
+ if normalize:
39
+ transform_list += [torchvision.transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
40
+ (0.26862954, 0.26130258, 0.27577711))]
41
+ return torchvision.transforms.Compose(transform_list)
42
+
43
+ def get_tensor_dino(normalize=True, toTensor=True):
44
+ transform_list = [torchvision.transforms.Resize((224,224))]
45
+ if toTensor:
46
+ transform_list += [torchvision.transforms.ToTensor()]
47
+
48
+ if normalize:
49
+ transform_list += [lambda x: 255.0 * x[:3],
50
+ torchvision.transforms.Normalize(
51
+ mean=(123.675, 116.28, 103.53),
52
+ std=(58.395, 57.12, 57.375),
53
+ )]
54
+ return torchvision.transforms.Compose(transform_list)
55
+
56
+ def crawl_folders(folder_path):
57
+ # glob crawl
58
+ all_files = []
59
+ folders = glob.glob(f'{folder_path}/*')
60
+
61
+ for folder in folders:
62
+ src_paths = glob.glob(f'{folder}/src_*png')
63
+ all_files.extend(src_paths)
64
+ return all_files
65
+
66
+ def get_grid(size):
67
+ y = np.repeat(np.arange(size)[None, ...], size)
68
+ y = y.reshape(size, size)
69
+ x = y.transpose()
70
+ out = np.stack([y,x], -1)
71
+ return out
72
+
73
+
74
+ class CollageDataset(Dataset):
75
+ def __init__(self, split_files, image_size, embedding_type, warping_type, blur_warped=False):
76
+ self.size = image_size
77
+ # depends on the embedding type
78
+ if embedding_type == 'clip':
79
+ self.get_embedding_vector = get_tensor_clip()
80
+ elif embedding_type == 'dino':
81
+ self.get_embedding_vector = get_tensor_dino()
82
+ self.get_tensor = get_tensor()
83
+ self.resize = torchvision.transforms.Resize(size=(image_size, image_size))
84
+ self.to_mask_tensor = get_tensor(normalize=False)
85
+
86
+ self.src_paths = crawl_folders(split_files)
87
+ print('current split size', len(self.src_paths))
88
+ print('for dir', split_files)
89
+
90
+ assert warping_type in ['collage', 'flow', 'mix']
91
+ self.warping_type = warping_type
92
+
93
+ self.mask_threshold = 0.85
94
+
95
+ self.blur_t = torchvision.transforms.GaussianBlur(kernel_size=51, sigma=20.0)
96
+ self.blur_warped = blur_warped
97
+
98
+ # self.save_folder = '/mnt/localssd/collage_out'
99
+ # os.makedirs(self.save_folder, exist_ok=True)
100
+ self.save_counter = 0
101
+ self.save_subfolder = None
102
+
103
+ def __len__(self):
104
+ return len(self.src_paths)
105
+
106
+
107
+ def __getitem__(self, idx, depth=0):
108
+
109
+ if self.warping_type == 'mix':
110
+ # randomly sample
111
+ warping_type = np.random.choice(['collage', 'flow'])
112
+ else:
113
+ warping_type = self.warping_type
114
+
115
+ src_path = self.src_paths[idx]
116
+ tgt_path = src_path.replace('src_', 'tgt_')
117
+
118
+ if warping_type == 'collage':
119
+ warped_path = src_path.replace('src_', 'composite_')
120
+ mask_path = src_path.replace('src_', 'composite_mask_')
121
+ corresp_path = src_path.replace('src_', 'composite_grid_')
122
+ corresp_path = corresp_path.split('.')[0]
123
+ corresp_path += '.npy'
124
+ elif warping_type == 'flow':
125
+ warped_path = src_path.replace('src_', 'flow_warped_')
126
+ mask_path = src_path.replace('src_', 'flow_mask_')
127
+ corresp_path = src_path.replace('src_', 'flow_warped_grid_')
128
+ corresp_path = corresp_path.split('.')[0]
129
+ corresp_path += '.npy'
130
+ else:
131
+ raise ValueError
132
+
133
+ # load reference image, warped image, and target GT image
134
+ reference_img = Image.open(src_path).convert('RGB')
135
+ gt_img = Image.open(tgt_path).convert('RGB')
136
+ warped_img = Image.open(warped_path).convert('RGB')
137
+ warping_mask = Image.open(mask_path).convert('RGB')
138
+
139
+ # resize all
140
+ reference_img = self.resize(reference_img)
141
+ gt_img = self.resize(gt_img)
142
+ warped_img = self.resize(warped_img)
143
+ warping_mask = self.resize(warping_mask)
144
+
145
+
146
+ # NO CROPPING PLEASE. ALL INPUTS ARE 512X512
147
+ # Random crop
148
+ # i, j, h, w = torchvision.transforms.RandomCrop.get_params(
149
+ # reference_img, output_size=(512, 512))
150
+
151
+ # reference_img = torchvision.transforms.functional.crop(reference_img, i, j, h, w)
152
+ # gt_img = torchvision.transforms.functional.crop(gt_img, i, j, h, w)
153
+ # warped_img = torchvision.transforms.functional.crop(warped_img, i, j, h, w)
154
+ # # TODO start using the warping mask
155
+ # warping_mask = torchvision.transforms.functional.crop(warping_mask, i, j, h, w)
156
+
157
+ grid_transformed = torch.tensor(np.load(corresp_path))
158
+ # grid_transformed = torchvision.transforms.functional.crop(grid_transformed, i, j, h, w)
159
+
160
+
161
+
162
+ # reference_t = to_tensor(reference_img)
163
+ gt_t = self.get_tensor(gt_img)
164
+ warped_t = self.get_tensor(warped_img)
165
+ warping_mask_t = self.to_mask_tensor(warping_mask)
166
+ clean_reference_t = self.get_tensor(reference_img)
167
+ # compute error to generate mask
168
+ blur_t = torchvision.transforms.GaussianBlur(kernel_size=(11,11), sigma=5.0)
169
+
170
+ reference_clip_img = self.get_embedding_vector(reference_img)
171
+
172
+ mask = torch.ones_like(gt_t)[:1]
173
+ warping_mask_t = warping_mask_t[:1]
174
+
175
+ good_region = torch.mean(warping_mask_t)
176
+ # print('good region', good_region)
177
+ # print('good region frac', good_region)
178
+ if good_region < 0.4 and depth < 3:
179
+ # example too hard, sample something else
180
+ # print('bad image, resampling..')
181
+ rand_idx = np.random.randint(len(self.src_paths))
182
+ return self.__getitem__(rand_idx, depth+1)
183
+
184
+ # if mask is too large then ignore
185
+
186
+ # #gaussian inpainting now
187
+ missing_mask = warping_mask_t[0] < 0.5
188
+
189
+
190
+ reference = (warped_t.clone() + 1) / 2.0
191
+ ref_cv = torch.moveaxis(reference, 0, -1).cpu().numpy()
192
+ ref_cv = (ref_cv * 255).astype(np.uint8)
193
+ cv_mask = missing_mask.int().squeeze().cpu().numpy().astype(np.uint8)
194
+ kernel = np.ones((7,7))
195
+ dilated_mask = cv2.dilate(cv_mask, kernel)
196
+ # cv_mask = np.stack([cv_mask]*3, axis=-1)
197
+ dst = cv2.inpaint(ref_cv,dilated_mask,5,cv2.INPAINT_NS)
198
+
199
+ mask_resized = torchvision.transforms.functional.resize(warping_mask_t, (64,64))
200
+ # print(mask_resized)
201
+ size=512
202
+ grid_np = (get_grid(size) / size).astype(np.float16)# 512 x 512 x 2
203
+ grid_t = torch.tensor(grid_np).moveaxis(-1, 0) # 512 x 512 x 2
204
+ grid_resized = torchvision.transforms.functional.resize(grid_t, (64,64)).to(torch.float16)
205
+ changed_pixels = torch.logical_or((torch.abs(grid_resized - grid_transformed)[0] > 0.04) , (torch.abs(grid_resized - grid_transformed)[1] > 0.04))
206
+ changed_pixels = changed_pixels.unsqueeze(0)
207
+ # changed_pixels = torch.logical_and(changed_pixels, (mask_resized >= 0.3))
208
+ changed_pixels = changed_pixels.float()
209
+
210
+ inpainted_warped = (torch.tensor(dst).moveaxis(-1, 0).float() / 255.0) * 2.0 - 1.0
211
+
212
+ if self.blur_warped:
213
+ inpainted_warped= self.blur_t(inpainted_warped)
214
+
215
+ out = {"GT": gt_t,"inpaint_image": inpainted_warped,"inpaint_mask": warping_mask_t, "ref_imgs": reference_clip_img, "clean_reference": clean_reference_t, 'grid_transformed': grid_transformed, "changed_pixels": changed_pixels}
216
+ # out = {"GT": gt_t,"inpaint_image": inpainted_warped * 0.0,"inpaint_mask": torch.ones_like(warping_mask_t), "ref_imgs": reference_clip_img * 0.0, "clean_reference": gt_t, 'grid_transformed': grid_transformed, "changed_pixels": changed_pixels}
217
+ # out = {"GT": gt_t,"inpaint_image": inpainted_warped * 0.0,"inpaint_mask": warping_mask_t, "ref_imgs": reference_clip_img * 0.0, "clean_reference": clean_reference_t, 'grid_transformed': grid_transformed, "changed_pixels": changed_pixels}
218
+
219
+ # out = {"GT": gt_t,"inpaint_image": warped_t,"inpaint_mask": warping_mask_t, "ref_imgs": reference_clip_img, "clean_reference": clean_reference_t, 'grid_transformed': grid_transformed, 'inpainted': inpainted_warped}
220
+ # out_half = {key: out[key].half() for key in out}
221
+ # if self.save_counter < 50:
222
+ # save_path = f'{self.save_folder}/output_{time.time()}.pt'
223
+ # torch.save(out, save_path)
224
+ # self.save_counter += 1
225
+
226
+ return out
227
+
228
+
229
+
230
+
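(A minimal sketch of how CollageDataset might be consumed, assuming a directory of per-clip folders holding the src_/tgt_/composite_/flow_warped_ files written by the processing scripts; the path and loader settings are placeholders.)

    from torch.utils.data import DataLoader

    dataset = CollageDataset(
        split_files='/path/to/processed_clips',  # placeholder directory
        image_size=512,
        embedding_type='clip',                   # or 'dino'
        warping_type='mix',                      # 'collage', 'flow', or 'mix'
    )
    loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=4)
    batch = next(iter(loader))
    print(batch['GT'].shape, batch['inpaint_image'].shape, batch['inpaint_mask'].shape)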
ldm/lr_scheduler.py ADDED
@@ -0,0 +1,111 @@
1
+ # This code is built from the Stable Diffusion repository: https://github.com/CompVis/stable-diffusion, and
2
+ # Paint-by-Example repo https://github.com/Fantasy-Studio/Paint-by-Example
3
+ # Copyright (c) 2022 Robin Rombach and Patrick Esser and contributors.
4
+ # CreativeML Open RAIL-M
5
+ #
6
+ # ==========================================================================================
7
+ #
8
+ # Adobe's modifications are Copyright 2024 Adobe Research. All rights reserved.
9
+ # Adobe's modifications are licensed under the Adobe Research License. To view a copy of the license, visit
10
+ # LICENSE.md.
11
+ #
12
+ # ==========================================================================================
13
+
14
+ import numpy as np
15
+
16
+
17
+ class LambdaWarmUpCosineScheduler:
18
+ """
19
+ note: use with a base_lr of 1.0
20
+ """
21
+ def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0):
22
+ self.lr_warm_up_steps = warm_up_steps
23
+ self.lr_start = lr_start
24
+ self.lr_min = lr_min
25
+ self.lr_max = lr_max
26
+ self.lr_max_decay_steps = max_decay_steps
27
+ self.last_lr = 0.
28
+ self.verbosity_interval = verbosity_interval
29
+
30
+ def schedule(self, n, **kwargs):
31
+ if self.verbosity_interval > 0:
32
+ if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
33
+ if n < self.lr_warm_up_steps:
34
+ lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start
35
+ self.last_lr = lr
36
+ return lr
37
+ else:
38
+ t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps)
39
+ t = min(t, 1.0)
40
+ lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
41
+ 1 + np.cos(t * np.pi))
42
+ self.last_lr = lr
43
+ return lr
44
+
45
+ def __call__(self, n, **kwargs):
46
+ return self.schedule(n,**kwargs)
47
+
48
+
49
+ class LambdaWarmUpCosineScheduler2:
50
+ """
51
+ supports repeated iterations, configurable via lists
52
+ note: use with a base_lr of 1.0.
53
+ """
54
+ def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0):
55
+ assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths)
56
+ self.lr_warm_up_steps = warm_up_steps
57
+ self.f_start = f_start
58
+ self.f_min = f_min
59
+ self.f_max = f_max
60
+ self.cycle_lengths = cycle_lengths
61
+ self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
62
+ self.last_f = 0.
63
+ self.verbosity_interval = verbosity_interval
64
+
65
+ def find_in_interval(self, n):
66
+ interval = 0
67
+ for cl in self.cum_cycles[1:]:
68
+ if n <= cl:
69
+ return interval
70
+ interval += 1
71
+
72
+ def schedule(self, n, **kwargs):
73
+ cycle = self.find_in_interval(n)
74
+ n = n - self.cum_cycles[cycle]
75
+ if self.verbosity_interval > 0:
76
+ if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
77
+ f"current cycle {cycle}")
78
+ if n < self.lr_warm_up_steps[cycle]:
79
+ f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
80
+ self.last_f = f
81
+ return f
82
+ else:
83
+ t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle])
84
+ t = min(t, 1.0)
85
+ f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (
86
+ 1 + np.cos(t * np.pi))
87
+ self.last_f = f
88
+ return f
89
+
90
+ def __call__(self, n, **kwargs):
91
+ return self.schedule(n, **kwargs)
92
+
93
+
94
+ class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
95
+
96
+ def schedule(self, n, **kwargs):
97
+ cycle = self.find_in_interval(n)
98
+ n = n - self.cum_cycles[cycle]
99
+ if self.verbosity_interval > 0:
100
+ if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
101
+ f"current cycle {cycle}")
102
+
103
+ if n < self.lr_warm_up_steps[cycle]:
104
+ f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
105
+ self.last_f = f
106
+ return f
107
+ else:
108
+ f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle])
109
+ self.last_f = f
110
+ return f
111
+
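(These schedulers return a per-step learning-rate multiplier, so they are meant to be wrapped in torch's LambdaLR; a minimal sketch with illustrative single-cycle values rather than this repo's config.)

    import torch
    from torch.optim.lr_scheduler import LambdaLR

    multiplier = LambdaLinearScheduler(
        warm_up_steps=[1000], f_start=[1.0e-6], f_min=[1.0], f_max=[1.0],
        cycle_lengths=[10_000_000],
    )
    params = [torch.nn.Parameter(torch.zeros(1))]
    optimizer = torch.optim.AdamW(params, lr=1.0e-4)                # the multiplier scales this base learning rate
    scheduler = LambdaLR(optimizer, lr_lambda=multiplier.schedule)  # call scheduler.step() once per training step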
ldm/models/autoencoder.py ADDED
@@ -0,0 +1,456 @@
1
+ # This code is built from the Stable Diffusion repository: https://github.com/CompVis/stable-diffusion, and
2
+ # Paint-by-Example repo https://github.com/Fantasy-Studio/Paint-by-Example
3
+ # Copyright (c) 2022 Robin Rombach and Patrick Esser and contributors.
4
+ # CreativeML Open RAIL-M
5
+ #
6
+ # ==========================================================================================
7
+ #
8
+ # Adobe's modifications are Copyright 2024 Adobe Research. All rights reserved.
9
+ # Adobe's modifications are licensed under the Adobe Research License. To view a copy of the license, visit
10
+ # LICENSE.md.
11
+ #
12
+ # ==========================================================================================
13
+
14
+ import torch
15
+ import pytorch_lightning as pl
16
+ import torch.nn.functional as F
17
+ from contextlib import contextmanager
18
+
19
+ from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
20
+
21
+ from ldm.modules.diffusionmodules.model import Encoder, Decoder
22
+ from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
23
+
24
+ from ldm.util import instantiate_from_config
+ # the EMA, batch-resize and LR-scheduler code paths below also need these:
+ import numpy as np
+ from packaging import version
+ from torch.optim.lr_scheduler import LambdaLR
+ from ldm.modules.ema import LitEma
25
+
26
+
27
+ class VQModel(pl.LightningModule):
28
+ def __init__(self,
29
+ ddconfig,
30
+ lossconfig,
31
+ n_embed,
32
+ embed_dim,
33
+ ckpt_path=None,
34
+ ignore_keys=[],
35
+ image_key="image",
36
+ colorize_nlabels=None,
37
+ monitor=None,
38
+ batch_resize_range=None,
39
+ scheduler_config=None,
40
+ lr_g_factor=1.0,
41
+ remap=None,
42
+ sane_index_shape=False, # tell vector quantizer to return indices as bhw
43
+ use_ema=False
44
+ ):
45
+ super().__init__()
46
+ self.embed_dim = embed_dim
47
+ self.n_embed = n_embed
48
+ self.image_key = image_key
49
+ self.encoder = Encoder(**ddconfig)
50
+ self.decoder = Decoder(**ddconfig)
51
+ self.loss = instantiate_from_config(lossconfig)
52
+ self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25,
53
+ remap=remap,
54
+ sane_index_shape=sane_index_shape)
55
+ self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
56
+ self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
57
+ if colorize_nlabels is not None:
58
+ assert type(colorize_nlabels)==int
59
+ self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
60
+ if monitor is not None:
61
+ self.monitor = monitor
62
+ self.batch_resize_range = batch_resize_range
63
+ if self.batch_resize_range is not None:
64
+ print(f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.")
65
+
66
+ self.use_ema = use_ema
67
+ if self.use_ema:
68
+ self.model_ema = LitEma(self)
69
+ print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
70
+
71
+ if ckpt_path is not None:
72
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
73
+ self.scheduler_config = scheduler_config
74
+ self.lr_g_factor = lr_g_factor
75
+
76
+ @contextmanager
77
+ def ema_scope(self, context=None):
78
+ if self.use_ema:
79
+ self.model_ema.store(self.parameters())
80
+ self.model_ema.copy_to(self)
81
+ if context is not None:
82
+ print(f"{context}: Switched to EMA weights")
83
+ try:
84
+ yield None
85
+ finally:
86
+ if self.use_ema:
87
+ self.model_ema.restore(self.parameters())
88
+ if context is not None:
89
+ print(f"{context}: Restored training weights")
90
+
91
+ def init_from_ckpt(self, path, ignore_keys=list()):
92
+ sd = torch.load(path, map_location="cpu")["state_dict"]
93
+ keys = list(sd.keys())
94
+ for k in keys:
95
+ for ik in ignore_keys:
96
+ if k.startswith(ik):
97
+ print("Deleting key {} from state_dict.".format(k))
98
+ del sd[k]
99
+ missing, unexpected = self.load_state_dict(sd, strict=False)
100
+ print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
101
+ if len(missing) > 0:
102
+ print(f"Missing Keys: {missing}")
103
+ print(f"Unexpected Keys: {unexpected}")
104
+
105
+ def on_train_batch_end(self, *args, **kwargs):
106
+ if self.use_ema:
107
+ self.model_ema(self)
108
+
109
+ def encode(self, x):
110
+ h = self.encoder(x)
111
+ h = self.quant_conv(h)
112
+ quant, emb_loss, info = self.quantize(h)
113
+ return quant, emb_loss, info
114
+
115
+ def encode_to_prequant(self, x):
116
+ h = self.encoder(x)
117
+ h = self.quant_conv(h)
118
+ return h
119
+
120
+ def decode(self, quant):
121
+ quant = self.post_quant_conv(quant)
122
+ dec = self.decoder(quant)
123
+ return dec
124
+
125
+ def decode_code(self, code_b):
126
+ quant_b = self.quantize.embed_code(code_b)
127
+ dec = self.decode(quant_b)
128
+ return dec
129
+
130
+ def forward(self, input, return_pred_indices=False):
131
+ quant, diff, (_,_,ind) = self.encode(input)
132
+ dec = self.decode(quant)
133
+ if return_pred_indices:
134
+ return dec, diff, ind
135
+ return dec, diff
136
+
137
+ def get_input(self, batch, k):
138
+ x = batch[k]
139
+ if len(x.shape) == 3:
140
+ x = x[..., None]
141
+ x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
142
+ if self.batch_resize_range is not None:
143
+ lower_size = self.batch_resize_range[0]
144
+ upper_size = self.batch_resize_range[1]
145
+ if self.global_step <= 4:
146
+ # do the first few batches with max size to avoid later oom
147
+ new_resize = upper_size
148
+ else:
149
+ new_resize = np.random.choice(np.arange(lower_size, upper_size+16, 16))
150
+ if new_resize != x.shape[2]:
151
+ x = F.interpolate(x, size=new_resize, mode="bicubic")
152
+ x = x.detach()
153
+ return x
154
+
155
+ def training_step(self, batch, batch_idx, optimizer_idx):
156
+ # https://github.com/pytorch/pytorch/issues/37142
157
+ # try not to fool the heuristics
158
+ x = self.get_input(batch, self.image_key)
159
+ xrec, qloss, ind = self(x, return_pred_indices=True)
160
+
161
+ if optimizer_idx == 0:
162
+ # autoencode
163
+ aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
164
+ last_layer=self.get_last_layer(), split="train",
165
+ predicted_indices=ind)
166
+
167
+ self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
168
+ return aeloss
169
+
170
+ if optimizer_idx == 1:
171
+ # discriminator
172
+ discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
173
+ last_layer=self.get_last_layer(), split="train")
174
+ self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
175
+ return discloss
176
+
177
+ def validation_step(self, batch, batch_idx):
178
+ log_dict = self._validation_step(batch, batch_idx)
179
+ with self.ema_scope():
180
+ log_dict_ema = self._validation_step(batch, batch_idx, suffix="_ema")
181
+ return log_dict
182
+
183
+ def _validation_step(self, batch, batch_idx, suffix=""):
184
+ x = self.get_input(batch, self.image_key)
185
+ xrec, qloss, ind = self(x, return_pred_indices=True)
186
+ aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0,
187
+ self.global_step,
188
+ last_layer=self.get_last_layer(),
189
+ split="val"+suffix,
190
+ predicted_indices=ind
191
+ )
192
+
193
+ discloss, log_dict_disc = self.loss(qloss, x, xrec, 1,
194
+ self.global_step,
195
+ last_layer=self.get_last_layer(),
196
+ split="val"+suffix,
197
+ predicted_indices=ind
198
+ )
199
+ rec_loss = log_dict_ae[f"val{suffix}/rec_loss"]
200
+ self.log(f"val{suffix}/rec_loss", rec_loss,
201
+ prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
202
+ self.log(f"val{suffix}/aeloss", aeloss,
203
+ prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
204
+ if version.parse(pl.__version__) >= version.parse('1.4.0'):
205
+ del log_dict_ae[f"val{suffix}/rec_loss"]
206
+ self.log_dict(log_dict_ae)
207
+ self.log_dict(log_dict_disc)
208
+ return self.log_dict
209
+
210
+ def configure_optimizers(self):
211
+ lr_d = self.learning_rate
212
+ lr_g = self.lr_g_factor*self.learning_rate
213
+ print("lr_d", lr_d)
214
+ print("lr_g", lr_g)
215
+ opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
216
+ list(self.decoder.parameters())+
217
+ list(self.quantize.parameters())+
218
+ list(self.quant_conv.parameters())+
219
+ list(self.post_quant_conv.parameters()),
220
+ lr=lr_g, betas=(0.5, 0.9))
221
+ opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
222
+ lr=lr_d, betas=(0.5, 0.9))
223
+
224
+ if self.scheduler_config is not None:
225
+ scheduler = instantiate_from_config(self.scheduler_config)
226
+
227
+ print("Setting up LambdaLR scheduler...")
228
+ scheduler = [
229
+ {
230
+ 'scheduler': LambdaLR(opt_ae, lr_lambda=scheduler.schedule),
231
+ 'interval': 'step',
232
+ 'frequency': 1
233
+ },
234
+ {
235
+ 'scheduler': LambdaLR(opt_disc, lr_lambda=scheduler.schedule),
236
+ 'interval': 'step',
237
+ 'frequency': 1
238
+ },
239
+ ]
240
+ return [opt_ae, opt_disc], scheduler
241
+ return [opt_ae, opt_disc], []
242
+
243
+ def get_last_layer(self):
244
+ return self.decoder.conv_out.weight
245
+
246
+ def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs):
247
+ log = dict()
248
+ x = self.get_input(batch, self.image_key)
249
+ x = x.to(self.device)
250
+ if only_inputs:
251
+ log["inputs"] = x
252
+ return log
253
+ xrec, _ = self(x)
254
+ if x.shape[1] > 3:
255
+ # colorize with random projection
256
+ assert xrec.shape[1] > 3
257
+ x = self.to_rgb(x)
258
+ xrec = self.to_rgb(xrec)
259
+ log["inputs"] = x
260
+ log["reconstructions"] = xrec
261
+ if plot_ema:
262
+ with self.ema_scope():
263
+ xrec_ema, _ = self(x)
264
+ if x.shape[1] > 3: xrec_ema = self.to_rgb(xrec_ema)
265
+ log["reconstructions_ema"] = xrec_ema
266
+ return log
267
+
268
+ def to_rgb(self, x):
269
+ assert self.image_key == "segmentation"
270
+ if not hasattr(self, "colorize"):
271
+ self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
272
+ x = F.conv2d(x, weight=self.colorize)
273
+ x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
274
+ return x
275
+
276
+
277
+ class VQModelInterface(VQModel):
278
+ def __init__(self, embed_dim, *args, **kwargs):
279
+ super().__init__(embed_dim=embed_dim, *args, **kwargs)
280
+ self.embed_dim = embed_dim
281
+
282
+ def encode(self, x):
283
+ h = self.encoder(x)
284
+ h = self.quant_conv(h)
285
+ return h
286
+
287
+ def decode(self, h, force_not_quantize=False):
288
+ # also go through quantization layer
289
+ if not force_not_quantize:
290
+ quant, emb_loss, info = self.quantize(h)
291
+ else:
292
+ quant = h
293
+ quant = self.post_quant_conv(quant)
294
+ dec = self.decoder(quant)
295
+ return dec
296
+
297
+
298
+ class AutoencoderKL(pl.LightningModule):
299
+ def __init__(self,
300
+ ddconfig,
301
+ lossconfig,
302
+ embed_dim,
303
+ ckpt_path=None,
304
+ ignore_keys=[],
305
+ image_key="image",
306
+ colorize_nlabels=None,
307
+ monitor=None,
308
+ ):
309
+ super().__init__()
310
+ self.image_key = image_key
311
+ self.encoder = Encoder(**ddconfig)
312
+ self.decoder = Decoder(**ddconfig)
313
+ self.loss = instantiate_from_config(lossconfig)
314
+ assert ddconfig["double_z"]
315
+ self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
316
+ self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
317
+ self.embed_dim = embed_dim
318
+ if colorize_nlabels is not None:
319
+ assert type(colorize_nlabels)==int
320
+ self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
321
+ if monitor is not None:
322
+ self.monitor = monitor
323
+ if ckpt_path is not None:
324
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
325
+
326
+ def init_from_ckpt(self, path, ignore_keys=list()):
327
+ sd = torch.load(path, map_location="cpu")["state_dict"]
328
+ keys = list(sd.keys())
329
+ for k in keys:
330
+ for ik in ignore_keys:
331
+ if k.startswith(ik):
332
+ print("Deleting key {} from state_dict.".format(k))
333
+ del sd[k]
334
+ self.load_state_dict(sd, strict=False)
335
+ print(f"Restored from {path}")
336
+
337
+ def encode(self, x):
338
+ h = self.encoder(x)
339
+ moments = self.quant_conv(h)
340
+ posterior = DiagonalGaussianDistribution(moments)
341
+ return posterior
342
+
343
+ def decode(self, z):
344
+ z = self.post_quant_conv(z)
345
+ dec = self.decoder(z)
346
+ return dec
347
+
348
+ def forward(self, input, sample_posterior=True):
349
+ posterior = self.encode(input)
350
+ if sample_posterior:
351
+ z = posterior.sample()
352
+ else:
353
+ z = posterior.mode()
354
+ dec = self.decode(z)
355
+ return dec, posterior
356
+
357
+ def get_input(self, batch, k):
358
+ x = batch[k]
359
+ if len(x.shape) == 3:
360
+ x = x[..., None]
361
+ x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
362
+ return x
363
+
364
+ def training_step(self, batch, batch_idx, optimizer_idx):
365
+ inputs = self.get_input(batch, self.image_key)
366
+ reconstructions, posterior = self(inputs)
367
+
368
+ if optimizer_idx == 0:
369
+ # train encoder+decoder+logvar
370
+ aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
371
+ last_layer=self.get_last_layer(), split="train")
372
+ self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
373
+ self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
374
+ return aeloss
375
+
376
+ if optimizer_idx == 1:
377
+ # train the discriminator
378
+ discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
379
+ last_layer=self.get_last_layer(), split="train")
380
+
381
+ self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
382
+ self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
383
+ return discloss
384
+
385
+ def validation_step(self, batch, batch_idx):
386
+ inputs = self.get_input(batch, self.image_key)
387
+ reconstructions, posterior = self(inputs)
388
+ aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
389
+ last_layer=self.get_last_layer(), split="val")
390
+
391
+ discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
392
+ last_layer=self.get_last_layer(), split="val")
393
+
394
+ self.log("val/rec_loss", log_dict_ae["val/rec_loss"])
395
+ self.log_dict(log_dict_ae)
396
+ self.log_dict(log_dict_disc)
397
+ return self.log_dict
398
+
399
+ def configure_optimizers(self):
400
+ lr = self.learning_rate
401
+ opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
402
+ list(self.decoder.parameters())+
403
+ list(self.quant_conv.parameters())+
404
+ list(self.post_quant_conv.parameters()),
405
+ lr=lr, betas=(0.5, 0.9))
406
+ opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
407
+ lr=lr, betas=(0.5, 0.9))
408
+ return [opt_ae, opt_disc], []
409
+
410
+ def get_last_layer(self):
411
+ return self.decoder.conv_out.weight
412
+
413
+ @torch.no_grad()
414
+ def log_images(self, batch, only_inputs=False, **kwargs):
415
+ log = dict()
416
+ x = self.get_input(batch, self.image_key)
417
+ x = x.to(self.device)
418
+ if not only_inputs:
419
+ xrec, posterior = self(x)
420
+ if x.shape[1] > 3:
421
+ # colorize with random projection
422
+ assert xrec.shape[1] > 3
423
+ x = self.to_rgb(x)
424
+ xrec = self.to_rgb(xrec)
425
+ log["samples"] = self.decode(torch.randn_like(posterior.sample()))
426
+ log["reconstructions"] = xrec
427
+ log["inputs"] = x
428
+ return log
429
+
430
+ def to_rgb(self, x):
431
+ assert self.image_key == "segmentation"
432
+ if not hasattr(self, "colorize"):
433
+ self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
434
+ x = F.conv2d(x, weight=self.colorize)
435
+ x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
436
+ return x
437
+
438
+
439
+ class IdentityFirstStage(torch.nn.Module):
440
+ def __init__(self, *args, vq_interface=False, **kwargs):
441
+ self.vq_interface = vq_interface # TODO: Should be true by default but check to not break older stuff
442
+ super().__init__()
443
+
444
+ def encode(self, x, *args, **kwargs):
445
+ return x
446
+
447
+ def decode(self, x, *args, **kwargs):
448
+ return x
449
+
450
+ def quantize(self, x, *args, **kwargs):
451
+ if self.vq_interface:
452
+ return x, None, [None, None, None]
453
+ return x
454
+
455
+ def forward(self, x, *args, **kwargs):
456
+ return x
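(A hypothetical round trip through AutoencoderKL; the ddconfig mirrors common KL-f8 first-stage settings rather than anything verified from this repo's configs, and torch.nn.Identity stands in for the training loss when the model is used for inference only.)

    import torch

    ddconfig = dict(double_z=True, z_channels=4, resolution=256, in_channels=3, out_ch=3,
                    ch=128, ch_mult=[1, 2, 4, 4], num_res_blocks=2, attn_resolutions=[], dropout=0.0)
    ae = AutoencoderKL(ddconfig, lossconfig={'target': 'torch.nn.Identity'}, embed_dim=4).eval()

    x = torch.randn(1, 3, 256, 256)              # stand-in for images scaled to [-1, 1]
    with torch.no_grad():
        posterior = ae.encode(x)                 # DiagonalGaussianDistribution over the latent
        recon = ae.decode(posterior.sample())    # back to (1, 3, 256, 256)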
ldm/models/diffusion/__init__.py ADDED
File without changes
ldm/models/diffusion/classifier.py ADDED
@@ -0,0 +1,280 @@
1
+ # This code is built from the Stable Diffusion repository: https://github.com/CompVis/stable-diffusion, and
2
+ # Paint-by-Example repo https://github.com/Fantasy-Studio/Paint-by-Example
3
+ # Copyright (c) 2022 Robin Rombach and Patrick Esser and contributors.
4
+ # CreativeML Open RAIL-M
5
+ #
6
+ # ==========================================================================================
7
+ #
8
+ # Adobe's modifications are Copyright 2024 Adobe Research. All rights reserved.
9
+ # Adobe's modifications are licensed under the Adobe Research License. To view a copy of the license, visit
10
+ # LICENSE.md.
11
+ #
12
+ # ==========================================================================================
13
+
14
+ import os
15
+ import torch
16
+ import pytorch_lightning as pl
17
+ from omegaconf import OmegaConf
18
+ from torch.nn import functional as F
19
+ from torch.optim import AdamW
20
+ from torch.optim.lr_scheduler import LambdaLR
21
+ from copy import deepcopy
22
+ from einops import rearrange
23
+ from glob import glob
24
+ from natsort import natsorted
25
+
26
+ from ldm.modules.diffusionmodules.openaimodel import EncoderUNetModel, UNetModel
27
+ from ldm.util import log_txt_as_img, default, ismap, instantiate_from_config
28
+
29
+ __models__ = {
30
+ 'class_label': EncoderUNetModel,
31
+ 'segmentation': UNetModel
32
+ }
33
+
34
+
35
+ def disabled_train(self, mode=True):
36
+ """Overwrite model.train with this function to make sure train/eval mode
37
+ does not change anymore."""
38
+ return self
39
+
40
+
41
+ class NoisyLatentImageClassifier(pl.LightningModule):
42
+
43
+ def __init__(self,
44
+ diffusion_path,
45
+ num_classes,
46
+ ckpt_path=None,
47
+ pool='attention',
48
+ label_key=None,
49
+ diffusion_ckpt_path=None,
50
+ scheduler_config=None,
51
+ weight_decay=1.e-2,
52
+ log_steps=10,
53
+ monitor='val/loss',
54
+ *args,
55
+ **kwargs):
56
+ super().__init__(*args, **kwargs)
57
+ self.num_classes = num_classes
58
+ # get latest config of diffusion model
59
+ diffusion_config = natsorted(glob(os.path.join(diffusion_path, 'configs', '*-project.yaml')))[-1]
60
+ self.diffusion_config = OmegaConf.load(diffusion_config).model
61
+ self.diffusion_config.params.ckpt_path = diffusion_ckpt_path
62
+ self.load_diffusion()
63
+
64
+ self.monitor = monitor
65
+ self.numd = self.diffusion_model.first_stage_model.encoder.num_resolutions - 1
66
+ self.log_time_interval = self.diffusion_model.num_timesteps // log_steps
67
+ self.log_steps = log_steps
68
+
69
+ self.label_key = label_key if not hasattr(self.diffusion_model, 'cond_stage_key') \
70
+ else self.diffusion_model.cond_stage_key
71
+
72
+ assert self.label_key is not None, 'label_key neither in diffusion model nor in model.params'
73
+
74
+ if self.label_key not in __models__:
75
+ raise NotImplementedError()
76
+
77
+ self.load_classifier(ckpt_path, pool)
78
+
79
+ self.scheduler_config = scheduler_config
80
+ self.use_scheduler = self.scheduler_config is not None
81
+ self.weight_decay = weight_decay
82
+
83
+ def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
84
+ sd = torch.load(path, map_location="cpu")
85
+ if "state_dict" in list(sd.keys()):
86
+ sd = sd["state_dict"]
87
+ keys = list(sd.keys())
88
+ for k in keys:
89
+ for ik in ignore_keys:
90
+ if k.startswith(ik):
91
+ print("Deleting key {} from state_dict.".format(k))
92
+ del sd[k]
93
+ missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
94
+ sd, strict=False)
95
+ print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
96
+ if len(missing) > 0:
97
+ print(f"Missing Keys: {missing}")
98
+ if len(unexpected) > 0:
99
+ print(f"Unexpected Keys: {unexpected}")
100
+
101
+ def load_diffusion(self):
102
+ model = instantiate_from_config(self.diffusion_config)
103
+ self.diffusion_model = model.eval()
104
+ self.diffusion_model.train = disabled_train
105
+ for param in self.diffusion_model.parameters():
106
+ param.requires_grad = False
107
+
108
+ def load_classifier(self, ckpt_path, pool):
109
+ model_config = deepcopy(self.diffusion_config.params.unet_config.params)
110
+ model_config.in_channels = self.diffusion_config.params.unet_config.params.out_channels
111
+ model_config.out_channels = self.num_classes
112
+ if self.label_key == 'class_label':
113
+ model_config.pool = pool
114
+
115
+ self.model = __models__[self.label_key](**model_config)
116
+ if ckpt_path is not None:
117
+ print('#####################################################################')
118
+ print(f'load from ckpt "{ckpt_path}"')
119
+ print('#####################################################################')
120
+ self.init_from_ckpt(ckpt_path)
121
+
122
+ @torch.no_grad()
123
+ def get_x_noisy(self, x, t, noise=None):
124
+ noise = default(noise, lambda: torch.randn_like(x))
125
+ continuous_sqrt_alpha_cumprod = None
126
+ if self.diffusion_model.use_continuous_noise:
127
+ continuous_sqrt_alpha_cumprod = self.diffusion_model.sample_continuous_noise_level(x.shape[0], t + 1)
128
+ # todo: make sure t+1 is correct here
129
+
130
+ return self.diffusion_model.q_sample(x_start=x, t=t, noise=noise,
131
+ continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod)
132
+
133
+ def forward(self, x_noisy, t, *args, **kwargs):
134
+ return self.model(x_noisy, t)
135
+
136
+ @torch.no_grad()
137
+ def get_input(self, batch, k):
138
+ x = batch[k]
139
+ if len(x.shape) == 3:
140
+ x = x[..., None]
141
+ x = rearrange(x, 'b h w c -> b c h w')
142
+ x = x.to(memory_format=torch.contiguous_format).float()
143
+ return x
144
+
145
+ @torch.no_grad()
146
+ def get_conditioning(self, batch, k=None):
147
+ if k is None:
148
+ k = self.label_key
149
+ assert k is not None, 'Needs to provide label key'
150
+
151
+ targets = batch[k].to(self.device)
152
+
153
+ if self.label_key == 'segmentation':
154
+ targets = rearrange(targets, 'b h w c -> b c h w')
155
+ for down in range(self.numd):
156
+ h, w = targets.shape[-2:]
157
+ targets = F.interpolate(targets, size=(h // 2, w // 2), mode='nearest')
158
+
159
+ # targets = rearrange(targets,'b c h w -> b h w c')
160
+
161
+ return targets
162
+
163
+ def compute_top_k(self, logits, labels, k, reduction="mean"):
164
+ _, top_ks = torch.topk(logits, k, dim=1)
165
+ if reduction == "mean":
166
+ return (top_ks == labels[:, None]).float().sum(dim=-1).mean().item()
167
+ elif reduction == "none":
168
+ return (top_ks == labels[:, None]).float().sum(dim=-1)
169
+
170
+ def on_train_epoch_start(self):
171
+ # save some memory
172
+ self.diffusion_model.model.to('cpu')
173
+
174
+ @torch.no_grad()
175
+ def write_logs(self, loss, logits, targets):
176
+ log_prefix = 'train' if self.training else 'val'
177
+ log = {}
178
+ log[f"{log_prefix}/loss"] = loss.mean()
179
+ log[f"{log_prefix}/acc@1"] = self.compute_top_k(
180
+ logits, targets, k=1, reduction="mean"
181
+ )
182
+ log[f"{log_prefix}/acc@5"] = self.compute_top_k(
183
+ logits, targets, k=5, reduction="mean"
184
+ )
185
+
186
+ self.log_dict(log, prog_bar=False, logger=True, on_step=self.training, on_epoch=True)
187
+ self.log('loss', log[f"{log_prefix}/loss"], prog_bar=True, logger=False)
188
+ self.log('global_step', self.global_step, logger=False, on_epoch=False, prog_bar=True)
189
+ lr = self.optimizers().param_groups[0]['lr']
190
+ self.log('lr_abs', lr, on_step=True, logger=True, on_epoch=False, prog_bar=True)
191
+
192
+ def shared_step(self, batch, t=None):
193
+ x, *_ = self.diffusion_model.get_input(batch, k=self.diffusion_model.first_stage_key)
194
+ targets = self.get_conditioning(batch)
195
+ if targets.dim() == 4:
196
+ targets = targets.argmax(dim=1)
197
+ if t is None:
198
+ t = torch.randint(0, self.diffusion_model.num_timesteps, (x.shape[0],), device=self.device).long()
199
+ else:
200
+ t = torch.full(size=(x.shape[0],), fill_value=t, device=self.device).long()
201
+ x_noisy = self.get_x_noisy(x, t)
202
+ logits = self(x_noisy, t)
203
+
204
+ loss = F.cross_entropy(logits, targets, reduction='none')
205
+
206
+ self.write_logs(loss.detach(), logits.detach(), targets.detach())
207
+
208
+ loss = loss.mean()
209
+ return loss, logits, x_noisy, targets
210
+
211
+ def training_step(self, batch, batch_idx):
212
+ loss, *_ = self.shared_step(batch)
213
+ return loss
214
+
215
+ def reset_noise_accs(self):
216
+ self.noisy_acc = {t: {'acc@1': [], 'acc@5': []} for t in
217
+ range(0, self.diffusion_model.num_timesteps, self.diffusion_model.log_every_t)}
218
+
219
+ def on_validation_start(self):
220
+ self.reset_noise_accs()
221
+
222
+ @torch.no_grad()
223
+ def validation_step(self, batch, batch_idx):
224
+ loss, *_ = self.shared_step(batch)
225
+
226
+ for t in self.noisy_acc:
227
+ _, logits, _, targets = self.shared_step(batch, t)
228
+ self.noisy_acc[t]['acc@1'].append(self.compute_top_k(logits, targets, k=1, reduction='mean'))
229
+ self.noisy_acc[t]['acc@5'].append(self.compute_top_k(logits, targets, k=5, reduction='mean'))
230
+
231
+ return loss
232
+
233
+ def configure_optimizers(self):
234
+ optimizer = AdamW(self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)
235
+
236
+ if self.use_scheduler:
237
+ scheduler = instantiate_from_config(self.scheduler_config)
238
+
239
+ print("Setting up LambdaLR scheduler...")
240
+ scheduler = [
241
+ {
242
+ 'scheduler': LambdaLR(optimizer, lr_lambda=scheduler.schedule),
243
+ 'interval': 'step',
244
+ 'frequency': 1
245
+ }]
246
+ return [optimizer], scheduler
247
+
248
+ return optimizer
249
+
250
+ @torch.no_grad()
251
+ def log_images(self, batch, N=8, *args, **kwargs):
252
+ log = dict()
253
+ x = self.get_input(batch, self.diffusion_model.first_stage_key)
254
+ log['inputs'] = x
255
+
256
+ y = self.get_conditioning(batch)
257
+
258
+ if self.label_key == 'class_label':
259
+ y = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"])
260
+ log['labels'] = y
261
+
262
+ if ismap(y):
263
+ log['labels'] = self.diffusion_model.to_rgb(y)
264
+
265
+ for step in range(self.log_steps):
266
+ current_time = step * self.log_time_interval
267
+
268
+ _, logits, x_noisy, _ = self.shared_step(batch, t=current_time)
269
+
270
+ log[f'inputs@t{current_time}'] = x_noisy
271
+
272
+ pred = F.one_hot(logits.argmax(dim=1), num_classes=self.num_classes)
273
+ pred = rearrange(pred, 'b h w c -> b c h w')
274
+
275
+ log[f'pred@t{current_time}'] = self.diffusion_model.to_rgb(pred)
276
+
277
+ for key in log:
278
+ log[key] = log[key][:N]
279
+
280
+ return log
ldm/models/diffusion/ddim.py ADDED
@@ -0,0 +1,296 @@
1
+ # This code is built from the Stable Diffusion repository: https://github.com/CompVis/stable-diffusion, and
2
+ # Paint-by-Example repo https://github.com/Fantasy-Studio/Paint-by-Example
3
+ # Copyright (c) 2022 Robin Rombach and Patrick Esser and contributors.
4
+ # CreativeML Open RAIL-M
5
+ #
6
+ # ==========================================================================================
7
+ #
8
+ # Adobe's modifications are Copyright 2024 Adobe Research. All rights reserved.
9
+ # Adobe's modifications are licensed under the Adobe Research License. To view a copy of the license, visit
10
+ # LICENSE.md.
11
+ #
12
+ # ==========================================================================================
13
+
14
+ """SAMPLING ONLY."""
15
+
16
+ import torch
17
+ import numpy as np
18
+ from tqdm import tqdm
19
+ from functools import partial
20
+
21
+ from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, \
22
+ extract_into_tensor
23
+
24
+
25
+ class DDIMSampler(object):
26
+ def __init__(self, model, schedule="linear", **kwargs):
27
+ super().__init__()
28
+ self.model = model
29
+ self.ddpm_num_timesteps = model.num_timesteps
30
+ self.schedule = schedule
31
+
32
+ def register_buffer(self, name, attr):
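+ # Note: DDIMSampler is a plain object (not an nn.Module), so this simply stores the schedule tensors as attributes, moved to CUDA.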
33
+ if type(attr) == torch.Tensor:
34
+ if attr.device != torch.device("cuda"):
35
+ attr = attr.to(torch.device("cuda"))
36
+ setattr(self, name, attr)
37
+
38
+ def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True, steps=None):
39
+ self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
40
+ num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose, steps=steps)
41
+ alphas_cumprod = self.model.alphas_cumprod
42
+ assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
43
+ to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
44
+
45
+ self.register_buffer('betas', to_torch(self.model.betas))
46
+ self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
47
+ self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
48
+
49
+ # calculations for diffusion q(x_t | x_{t-1}) and others
50
+ self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
51
+ self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
52
+ self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
53
+ self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
54
+ self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
55
+
56
+ # ddim sampling parameters
57
+ ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
58
+ ddim_timesteps=self.ddim_timesteps,
59
+ eta=ddim_eta,verbose=verbose)
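+ # With ddim_eta = 0 these sigmas are all zero and sampling is deterministic DDIM; larger eta re-introduces stochastic noise.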
60
+ self.register_buffer('ddim_sigmas', ddim_sigmas)
61
+ self.register_buffer('ddim_alphas', ddim_alphas)
62
+ self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
63
+ self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
64
+ sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
65
+ (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
66
+ 1 - self.alphas_cumprod / self.alphas_cumprod_prev))
67
+ self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
68
+
69
+ @torch.no_grad()
70
+ def sample(self,
71
+ S,
72
+ batch_size,
73
+ shape,
74
+ conditioning=None,
75
+ callback=None,
76
+ normals_sequence=None,
77
+ img_callback=None,
78
+ quantize_x0=False,
79
+ eta=0.,
80
+ mask=None,
81
+ x0=None,
82
+ temperature=1.,
83
+ noise_dropout=0.,
84
+ score_corrector=None,
85
+ corrector_kwargs=None,
86
+ verbose=True,
87
+ x_T=None,
88
+ log_every_t=100,
89
+ unconditional_guidance_scale=1.,
90
+ unconditional_conditioning=None,
91
+ z_ref=None,
92
+ ddim_discretize='uniform',
93
+ schedule_steps=None,
94
+ # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
95
+ **kwargs
96
+ ):
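+ # Illustrative usage (all names below are placeholders, not part of this repo):
+ #   sampler = DDIMSampler(model)
+ #   samples, _ = sampler.sample(S=50, batch_size=1, shape=(4, 64, 64), conditioning=c, eta=0.,
+ #                               unconditional_guidance_scale=5., unconditional_conditioning=uc,
+ #                               z_ref=z_ref, test_model_kwargs=test_kwargs)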
97
+ if conditioning is not None:
98
+ if isinstance(conditioning, dict):
99
+ cbs = conditioning[list(conditioning.keys())[0]].shape[0]
100
+ if cbs != batch_size:
101
+ print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
102
+ else:
103
+ if conditioning.shape[0] != batch_size:
104
+ print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
105
+
106
+ self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose, ddim_discretize=ddim_discretize, steps=schedule_steps)
107
+ # sampling
108
+ C, H, W = shape
109
+ size = (batch_size, C, H, W)
110
+
111
+ samples, intermediates = self.ddim_sampling(conditioning, size,
112
+ callback=callback,
113
+ img_callback=img_callback,
114
+ quantize_denoised=quantize_x0,
115
+ mask=mask, x0=x0,
116
+ ddim_use_original_steps=False,
117
+ noise_dropout=noise_dropout,
118
+ temperature=temperature,
119
+ score_corrector=score_corrector,
120
+ corrector_kwargs=corrector_kwargs,
121
+ x_T=x_T,
122
+ log_every_t=log_every_t,
123
+ unconditional_guidance_scale=unconditional_guidance_scale,
124
+ unconditional_conditioning=unconditional_conditioning,
125
+ z_ref=z_ref,
126
+ **kwargs
127
+ )
128
+ return samples, intermediates
129
+
130
+ @torch.no_grad()
131
+ def ddim_sampling(self, cond, shape,
132
+ x_T=None, ddim_use_original_steps=False,
133
+ callback=None, timesteps=None, quantize_denoised=False,
134
+ mask=None, x0=None, x0_step=None, img_callback=None, log_every_t=100,
135
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
136
+ unconditional_guidance_scale=1., unconditional_conditioning=None, z_ref=None,**kwargs):
137
+ device = self.model.betas.device
138
+ b = shape[0]
139
+ if x_T is None:
140
+ img = torch.randn(shape, device=device)
141
+ else:
142
+ img = x_T
143
+
144
+ if timesteps is None:
145
+ timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
146
+ elif timesteps is not None and not ddim_use_original_steps:
147
+ subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
148
+ timesteps = self.ddim_timesteps[:subset_end]
149
+
150
+ intermediates = {'x_inter': [img], 'pred_x0': [img]}
151
+ time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps)
152
+ total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
153
+ print(f"Running DDIM Sampling with {total_steps} timesteps")
154
+
155
+ iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
156
+
157
+ for i, step in enumerate(iterator):
158
+ index = total_steps - i - 1
159
+ ts = torch.full((b,), step, device=device, dtype=torch.long)
160
+
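+ # Optional warm start: for the first x0_step iterations the latent is replaced by a freshly re-noised x0 (an SDEdit-style initialization).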
161
+ if x0_step is not None and i < x0_step:
162
+ assert x0 is not None
163
+ img = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
164
+ # img = img_orig * mask + (1. - mask) * img
165
+ outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
166
+ quantize_denoised=quantize_denoised, temperature=temperature,
167
+ noise_dropout=noise_dropout, score_corrector=score_corrector,
168
+ corrector_kwargs=corrector_kwargs,
169
+ unconditional_guidance_scale=unconditional_guidance_scale,
170
+ z_ref=z_ref,
171
+ unconditional_conditioning=unconditional_conditioning,**kwargs)
172
+ img, pred_x0 = outs
173
+ if callback: callback(i)
174
+ if img_callback: img_callback(pred_x0, i)
175
+
176
+ if index % log_every_t == 0 or index == total_steps - 1:
177
+ intermediates['x_inter'].append(img)
178
+ intermediates['pred_x0'].append(pred_x0)
179
+
180
+ return img, intermediates
181
+
182
+ @torch.no_grad()
183
+ def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
184
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
185
+ unconditional_guidance_scale=1., unconditional_conditioning=None, z_ref=None, drop_latent_guidance=1.0,**kwargs):
186
+ b, *_, device = *x.shape, x.device
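+ # Below, the 4-channel latent x is concatenated with the inpaint latent, mask, and any extra conditioning channels the UNet expects.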
187
+ if 'test_model_kwargs' in kwargs:
188
+ kwargs=kwargs['test_model_kwargs']
189
+ if f'inpaint_mask_{index}' in kwargs:
190
+ x = torch.cat([x, kwargs['inpaint_image'], kwargs[f'inpaint_mask_{index}']],dim=1)
191
+ print('using proxy mask', index)
192
+ else:
193
+ x = torch.cat([x, kwargs['inpaint_image'], kwargs['inpaint_mask']], dim=1)
194
+ if 'changed_pixels' in kwargs:
195
+ x = torch.cat([x, kwargs['changed_pixels']],dim=1)
196
+ elif 'rest' in kwargs:
197
+ x = torch.cat((x, kwargs['rest']), dim=1)
198
+ else:
199
+ raise Exception("kwargs must contain either 'test_model_kwargs' or 'rest' key")
200
+
201
+ # maybe should assert that standard classifier-free guidance and latent-dropping guidance are not both enabled
202
+ # print('index', index)
203
+ if isinstance(drop_latent_guidance, list):
204
+ cur_drop_latent_guidance = drop_latent_guidance[index]
205
+ else:
206
+ cur_drop_latent_guidance = drop_latent_guidance
207
+ # print('cur drop guidance', cur_drop_latent_guidance)
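+ # Three cases follow: (1) plain conditional prediction, (2) guidance over the concatenated latent channels (channels 4:9 are zeroed for the "unconditional" branch), (3) standard classifier-free guidance on the cross-attention conditioning.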
208
+
209
+ if (unconditional_conditioning is None or unconditional_guidance_scale == 1.) and cur_drop_latent_guidance == 1.:
210
+ e_t = self.model.apply_model(x, t, c, z_ref=z_ref)
211
+ elif cur_drop_latent_guidance != 1.:
212
+ assert (unconditional_conditioning is None or unconditional_guidance_scale == 1.)
213
+ x_dropped = x.clone()
214
+ # print('x dropped shape', x_dropped.shape)
215
+ x_dropped[:,4:9] *= 0.0
216
+ x_in = torch.cat([x_dropped, x])
217
+ t_in = torch.cat([t] * 2)
218
+ z_ref_in = torch.cat([z_ref] * 2)
219
+ c_in = torch.cat([c] * 2)
220
+ e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in, z_ref=z_ref_in).chunk(2)
221
+ e_t = e_t_uncond + cur_drop_latent_guidance * (e_t - e_t_uncond)
222
+
223
+ else:
224
+ x_in = torch.cat([x] * 2)
225
+ t_in = torch.cat([t] * 2)
226
+ z_ref_in = torch.cat([z_ref] * 2)
227
+ # print('uncond shape', unconditional_conditioning.shape, 'c shape', c.shape)
228
+ c_in = torch.cat([unconditional_conditioning, c])
229
+ e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in, z_ref=z_ref_in).chunk(2)
230
+ e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
231
+
232
+ if score_corrector is not None:
233
+ assert self.model.parameterization == "eps"
234
+ e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
235
+
236
+ alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
237
+ alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
238
+ sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
239
+ sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
240
+ # select parameters corresponding to the currently considered timestep
241
+ a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
242
+ a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
243
+ sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
244
+ sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
245
+
246
+ # current prediction for x_0
247
+ if x.shape[1]!=4:
248
+ pred_x0 = (x[:,:4,:,:] - sqrt_one_minus_at * e_t) / a_t.sqrt()
249
+ else:
250
+ pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
251
+ if quantize_denoised:
252
+ pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
253
+ # direction pointing to x_t
254
+ dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
255
+ noise = sigma_t * noise_like(dir_xt.shape, device, repeat_noise) * temperature
256
+ if noise_dropout > 0.:
257
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
258
+ x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
259
+ return x_prev, pred_x0
260
+
261
+ @torch.no_grad()
262
+ def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
263
+ # fast, but does not allow for exact reconstruction
264
+ # t serves as an index to gather the correct alphas
265
+ if use_original_steps:
266
+ sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
267
+ sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
268
+ else:
269
+ sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
270
+ sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
271
+
272
+ if noise is None:
273
+ noise = torch.randn_like(x0)
274
+ return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 +
275
+ extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise)
276
+
277
+ @torch.no_grad()
278
+ def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
279
+ use_original_steps=False):
280
+
281
+ timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
282
+ timesteps = timesteps[:t_start]
283
+
284
+ time_range = np.flip(timesteps)
285
+ total_steps = timesteps.shape[0]
286
+ print(f"Running DDIM Sampling with {total_steps} timesteps")
287
+
288
+ iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
289
+ x_dec = x_latent
290
+ for i, step in enumerate(iterator):
291
+ index = total_steps - i - 1
292
+ ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
293
+ x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
294
+ unconditional_guidance_scale=unconditional_guidance_scale,
295
+ unconditional_conditioning=unconditional_conditioning)
296
+ return x_dec
ldm/models/diffusion/ddpm.py ADDED
@@ -0,0 +1,1877 @@
1
+ # This code is built from the Stable Diffusion repository: https://github.com/CompVis/stable-diffusion, and
2
+ # Paint-by-Example repo https://github.com/Fantasy-Studio/Paint-by-Example
3
+ # Copyright (c) 2022 Robin Rombach and Patrick Esser and contributors.
4
+ # CreativeML Open RAIL-M
5
+ #
6
+ # ==========================================================================================
7
+ #
8
+ # Adobe’s modifications are Copyright 2024 Adobe Research. All rights reserved.
9
+ # Adobe’s modifications are licensed under the Adobe Research License. To view a copy of the license, visit
10
+ # LICENSE.md.
11
+ #
12
+ # ==========================================================================================
13
+
14
+ """
15
+ wild mixture of
16
+ https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
17
+ https://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_diffusion/gaussian_diffusion.py
18
+ https://github.com/CompVis/taming-transformers
19
+ -- merci
20
+ """
21
+
22
+ import torch
23
+ import torch.nn as nn
24
+ import torchvision
25
+ import numpy as np
26
+ import pytorch_lightning as pl
27
+ from torch.optim.lr_scheduler import LambdaLR
28
+ from einops import rearrange, repeat
29
+ from contextlib import contextmanager
30
+ from functools import partial
31
+ from tqdm import tqdm
32
+ from torchvision.utils import make_grid
33
+ # from pytorch_lightning.utilities.distributed import rank_zero_only
34
+ from pytorch_lightning.utilities.rank_zero import rank_zero_only
35
+ from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config
36
+ from ldm.modules.ema import LitEma
37
+ from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution
38
+ from ldm.models.autoencoder import VQModelInterface, IdentityFirstStage, AutoencoderKL
39
+ from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like
40
+ from ldm.models.diffusion.ddim import DDIMSampler
41
+ from torchvision.transforms import Resize
42
+ import math
43
+ import time
44
+ import random
45
+ from torch.autograd import Variable
46
+ import copy
47
+ import os
48
+
49
+ __conditioning_keys__ = {'concat': 'c_concat',
50
+ 'crossattn': 'c_crossattn',
51
+ 'adm': 'y'}
52
+
53
+
54
+ def disabled_train(self, mode=True):
55
+ """Overwrite model.train with this function to make sure train/eval mode
56
+ does not change anymore."""
57
+ return self
58
+
59
+
60
+ def uniform_on_device(r1, r2, shape, device):
61
+ return (r1 - r2) * torch.rand(*shape, device=device) + r2
62
+
63
+
64
+ def rescale_zero_terminal_snr(betas):
65
+ """
66
+ Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)
67
+
68
+
69
+ Args:
70
+ betas (`torch.FloatTensor`):
71
+ the betas that the scheduler is being initialized with.
72
+
73
+ Returns:
74
+ `torch.FloatTensor`: rescaled betas with zero terminal SNR
75
+ """
76
+ # Convert betas to alphas_bar_sqrt
77
+ alphas = 1.0 - betas
78
+ alphas_cumprod = torch.cumprod(alphas, dim=0)
79
+ alphas_bar_sqrt = alphas_cumprod.sqrt()
80
+
81
+ # Store old values.
82
+ alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
83
+ alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
84
+
85
+ # Shift so the last timestep is zero.
86
+ alphas_bar_sqrt -= alphas_bar_sqrt_T
87
+
88
+ # Scale so the first timestep is back to the old value.
89
+ alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
90
+
91
+ # Convert alphas_bar_sqrt to betas
92
+ alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
93
+ alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod
94
+ alphas = torch.cat([alphas_bar[0:1], alphas])
95
+ betas = 1 - alphas
96
+
97
+ return betas
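+ # Used in register_schedule below as: betas = rescale_zero_terminal_snr(torch.tensor(betas)).numpy()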
98
+
99
+
100
+ class DDPM(pl.LightningModule):
101
+ # classic DDPM with Gaussian diffusion, in image space
102
+ def __init__(self,
103
+ unet_config,
104
+ timesteps=1000,
105
+ beta_schedule="linear",
106
+ loss_type="l2",
107
+ ckpt_path=None,
108
+ ignore_keys=[],
109
+ load_only_unet=False,
110
+ monitor="val/loss",
111
+ use_ema=True,
112
+ first_stage_key="image",
113
+ image_size=256,
114
+ channels=3,
115
+ log_every_t=100,
116
+ clip_denoised=True,
117
+ linear_start=1e-4,
118
+ linear_end=2e-2,
119
+ cosine_s=8e-3,
120
+ given_betas=None,
121
+ original_elbo_weight=0.,
122
+ v_posterior=0., # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta
123
+ l_simple_weight=1.,
124
+ conditioning_key=None,
125
+ parameterization="eps", # all assuming fixed variance schedules
126
+ scheduler_config=None,
127
+ use_positional_encodings=False,
128
+ learn_logvar=False,
129
+ logvar_init=0.,
130
+ u_cond_percent=0,
131
+ dropping_warped_latent_prob=0.,
132
+ remove_warped_latent=False,
133
+ gt_flag='GT',
134
+ sd_edit_step=850
135
+ ):
136
+ super().__init__()
137
+ assert parameterization in ["eps", "x0"], 'currently only supporting "eps" and "x0"'
138
+ self.parameterization = parameterization
139
+ print(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode")
140
+ self.cond_stage_model = None
141
+ self.clip_denoised = clip_denoised
142
+ self.log_every_t = log_every_t
143
+ self.first_stage_key = first_stage_key
144
+ self.image_size = image_size
145
+ self.channels = channels
146
+ self.u_cond_percent=u_cond_percent
147
+ self.use_positional_encodings = use_positional_encodings
148
+ self.gt_flag = gt_flag
149
+ self.sd_edit_step = sd_edit_step
150
+
151
+ self.remove_warped_latent = remove_warped_latent
152
+ self.dropping_warped_latent_prob = dropping_warped_latent_prob
153
+
154
+ if dropping_warped_latent_prob > 0.0:
155
+ assert not self.remove_warped_latent
156
+
157
+
158
+ self.use_ema = use_ema
159
+ if self.use_ema:
160
+ self.model_ema = LitEma(self.model)
161
+ print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
162
+
163
+ self.use_scheduler = scheduler_config is not None
164
+ if self.use_scheduler:
165
+ self.scheduler_config = scheduler_config
166
+
167
+ self.v_posterior = v_posterior
168
+ self.original_elbo_weight = original_elbo_weight
169
+ self.l_simple_weight = l_simple_weight
170
+
171
+ if monitor is not None:
172
+ self.monitor = monitor
173
+ if ckpt_path is not None:
174
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet)
175
+
176
+ self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps,
177
+ linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
178
+
179
+ self.model = DiffusionWrapper(unet_config, conditioning_key, ddpm_parent=self,
180
+ sqrt_alphas_cumprod=self.sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod=self.sqrt_one_minus_alphas_cumprod)
181
+ count_params(self.model, verbose=True)
182
+
183
+ self.loss_type = loss_type
184
+
185
+ self.learn_logvar = learn_logvar
186
+ self.logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,))
187
+ if self.learn_logvar:
188
+ self.logvar = nn.Parameter(self.logvar, requires_grad=True)
189
+
190
+
191
+ def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
192
+ linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
193
+ if exists(given_betas):
194
+ betas = given_betas
195
+ else:
196
+ betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
197
+ cosine_s=cosine_s)
198
+
199
+ # rescale beta
200
+ rescale_beta = True
201
+ if rescale_beta:
202
+ betas = rescale_zero_terminal_snr(torch.tensor(betas)).numpy()
203
+
204
+ alphas = 1. - betas
205
+ alphas_cumprod = np.cumprod(alphas, axis=0)
206
+ alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
207
+
208
+ timesteps, = betas.shape
209
+ self.num_timesteps = int(timesteps)
210
+ self.linear_start = linear_start
211
+ self.linear_end = linear_end
212
+ assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'
213
+
214
+ to_torch = partial(torch.tensor, dtype=torch.float32)
215
+
216
+ self.register_buffer('betas', to_torch(betas))
217
+ self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
218
+ self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
219
+
220
+ # calculations for diffusion q(x_t | x_{t-1}) and others
221
+ self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
222
+ self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
223
+ self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
224
+ self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
225
+ self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
226
+
227
+ # calculations for posterior q(x_{t-1} | x_t, x_0)
228
+ posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / (
229
+ 1. - alphas_cumprod) + self.v_posterior * betas
230
+ # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
231
+ self.register_buffer('posterior_variance', to_torch(posterior_variance))
232
+ # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
233
+ self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20))))
234
+ self.register_buffer('posterior_mean_coef1', to_torch(
235
+ betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
236
+ self.register_buffer('posterior_mean_coef2', to_torch(
237
+ (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))
238
+
239
+ if self.parameterization == "eps":
240
+ lvlb_weights = self.betas ** 2 / (
241
+ 2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod))
242
+ elif self.parameterization == "x0":
243
+ lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. * 1 - torch.Tensor(alphas_cumprod))
244
+ else:
245
+ raise NotImplementedError("mu not supported")
246
+ # TODO: how to choose this term
247
+ lvlb_weights[0] = lvlb_weights[1]
248
+ self.register_buffer('lvlb_weights', lvlb_weights, persistent=False)
249
+ assert not torch.isnan(self.lvlb_weights).all()
250
+
251
+ @contextmanager
252
+ def ema_scope(self, context=None):
253
+ if self.use_ema:
254
+ self.model_ema.store(self.model.parameters())
255
+ self.model_ema.copy_to(self.model)
256
+ if context is not None:
257
+ print(f"{context}: Switched to EMA weights")
258
+ try:
259
+ yield None
260
+ finally:
261
+ if self.use_ema:
262
+ self.model_ema.restore(self.model.parameters())
263
+ if context is not None:
264
+ print(f"{context}: Restored training weights")
265
+
266
+ def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
267
+ sd = torch.load(path, map_location="cpu")
268
+ if "state_dict" in list(sd.keys()):
269
+ sd = sd["state_dict"]
270
+ keys = list(sd.keys())
271
+ for k in keys:
272
+ for ik in ignore_keys:
273
+ if k.startswith(ik):
274
+ print("Deleting key {} from state_dict.".format(k))
275
+ del sd[k]
276
+ missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
277
+ sd, strict=False)
278
+ print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
279
+ if len(missing) > 0:
280
+ print(f"Missing Keys: {missing}")
281
+ if len(unexpected) > 0:
282
+ print(f"Unexpected Keys: {unexpected}")
283
+
284
+ def q_mean_variance(self, x_start, t):
285
+ """
286
+ Get the distribution q(x_t | x_0).
287
+ :param x_start: the [N x C x ...] tensor of noiseless inputs.
288
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
289
+ :return: A tuple (mean, variance, log_variance), all of x_start's shape.
290
+ """
291
+ mean = (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start)
292
+ variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
293
+ log_variance = extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
294
+ return mean, variance, log_variance
295
+
296
+ def predict_start_from_noise(self, x_t, t, noise):
297
+ return (
298
+ extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
299
+ extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
300
+ )
301
+
302
+ def q_posterior(self, x_start, x_t, t):
303
+ posterior_mean = (
304
+ extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start +
305
+ extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
306
+ )
307
+ posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape)
308
+ posterior_log_variance_clipped = extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape)
309
+ return posterior_mean, posterior_variance, posterior_log_variance_clipped
310
+
311
+ def p_mean_variance(self, x, t, clip_denoised: bool):
312
+ model_out = self.model(x, t)
313
+ if self.parameterization == "eps":
314
+ x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
315
+ elif self.parameterization == "x0":
316
+ x_recon = model_out
317
+ if clip_denoised:
318
+ x_recon.clamp_(-1., 1.)
319
+
320
+ model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
321
+ return model_mean, posterior_variance, posterior_log_variance
322
+
323
+ @torch.no_grad()
324
+ def p_sample(self, x, t, clip_denoised=True, repeat_noise=False):
325
+ b, *_, device = *x.shape, x.device
326
+ model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, clip_denoised=clip_denoised)
327
+ noise = noise_like(x.shape, device, repeat_noise)
328
+ # no noise when t == 0
329
+ nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
330
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
331
+
332
+ @torch.no_grad()
333
+ def p_sample_loop(self, shape, return_intermediates=False):
334
+ device = self.betas.device
335
+ b = shape[0]
336
+ img = torch.randn(shape, device=device)
337
+ intermediates = [img]
338
+ for i in tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps):
339
+ img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long),
340
+ clip_denoised=self.clip_denoised)
341
+ if i % self.log_every_t == 0 or i == self.num_timesteps - 1:
342
+ intermediates.append(img)
343
+ if return_intermediates:
344
+ return img, intermediates
345
+ return img
346
+
347
+ @torch.no_grad()
348
+ def sample(self, batch_size=16, return_intermediates=False):
349
+ image_size = self.image_size
350
+ channels = self.channels
351
+ return self.p_sample_loop((batch_size, channels, image_size, image_size),
352
+ return_intermediates=return_intermediates)
353
+
354
+ def q_sample(self, x_start, t, noise=None):
355
+ noise = default(noise, lambda: torch.randn_like(x_start))
356
+ return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
357
+ extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
358
+
359
+ def get_loss(self, pred, target, mean=True):
360
+ if self.loss_type == 'l1':
361
+ loss = (target - pred).abs()
362
+ if mean:
363
+ loss = loss.mean()
364
+ elif self.loss_type == 'l2':
365
+ if mean:
366
+ loss = torch.nn.functional.mse_loss(target, pred)
367
+ else:
368
+ loss = torch.nn.functional.mse_loss(target, pred, reduction='none')
369
+ else:
370
+ raise NotImplementedError("unknown loss type '{loss_type}'")
371
+
372
+ return loss
373
+
374
+ def p_losses(self, x_start, t, noise=None):
375
+ noise = default(noise, lambda: torch.randn_like(x_start))
376
+ x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
377
+ model_out = self.model(x_noisy, t)
378
+
379
+ loss_dict = {}
380
+ if self.parameterization == "eps":
381
+ target = noise
382
+ elif self.parameterization == "x0":
383
+ target = x_start
384
+ else:
385
+ raise NotImplementedError(f"Paramterization {self.parameterization} not yet supported")
386
+
387
+ loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3])
388
+
389
+ log_prefix = 'train' if self.training else 'val'
390
+
391
+ loss_dict.update({f'{log_prefix}/loss_simple': loss.mean()})
392
+ loss_simple = loss.mean() * self.l_simple_weight
393
+
394
+ loss_vlb = (self.lvlb_weights[t] * loss).mean()
395
+ loss_dict.update({f'{log_prefix}/loss_vlb': loss_vlb})
396
+
397
+ loss = loss_simple + self.original_elbo_weight * loss_vlb
398
+
399
+ loss_dict.update({f'{log_prefix}/loss': loss})
400
+
401
+ return loss, loss_dict
402
+
403
+ def forward(self, x, *args, **kwargs):
404
+ # b, c, h, w, device, img_size, = *x.shape, x.device, self.image_size
405
+ # assert h == img_size and w == img_size, f'height and width of image must be {img_size}'
406
+ t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
407
+ return self.p_losses(x, t, *args, **kwargs)
408
+
409
+ def get_input(self, batch, k):
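+ # For the "inpaint" key the batch supplies the ground-truth frame plus the collage inputs: inpaint image, mask, reference crops, clean reference, warped grid, and changed-pixel map.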
410
+ if k == "inpaint":
411
+ x = batch[self.gt_flag]
412
+ mask = batch['inpaint_mask']
413
+ inpaint = batch['inpaint_image']
414
+ reference = batch['ref_imgs']
415
+ clean_reference = batch['clean_reference']
416
+ grid_transformed = batch['grid_transformed']
417
+ changed_pixels = batch['changed_pixels']
418
+ else:
419
+ x = batch[k]
420
+ if len(x.shape) == 3:
421
+ x = x[..., None]
422
+ # x = rearrange(x, 'b h w c -> b c h w')
423
+ x = x.to(memory_format=torch.contiguous_format).float()
424
+ mask = mask.to(memory_format=torch.contiguous_format).float()
425
+ inpaint = inpaint.to(memory_format=torch.contiguous_format).float()
426
+ reference = reference.to(memory_format=torch.contiguous_format).float()
427
+ clean_reference = clean_reference.to(memory_format=torch.contiguous_format).float()
428
+ grid_transformed = grid_transformed.to(memory_format=torch.contiguous_format).float()
429
+ return x,inpaint,mask,reference, clean_reference, grid_transformed, changed_pixels
430
+
431
+ def shared_step(self, batch):
432
+ x = self.get_input(batch, self.first_stage_key)
433
+ loss, loss_dict = self(x)
434
+ return loss, loss_dict
435
+
436
+ def training_step(self, batch, batch_idx):
437
+ loss, loss_dict = self.shared_step(batch)
438
+
439
+ self.log_dict(loss_dict, prog_bar=True,
440
+ logger=True, on_step=True, on_epoch=True)
441
+
442
+ self.log("global_step", self.global_step,
443
+ prog_bar=True, logger=True, on_step=True, on_epoch=False)
444
+
445
+ if self.use_scheduler:
446
+ lr = self.optimizers().param_groups[0]['lr']
447
+ self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False)
448
+
449
+ return loss
450
+
451
+ @torch.no_grad()
452
+ def validation_step(self, batch, batch_idx):
453
+ _, loss_dict_no_ema = self.shared_step(batch)
454
+ with self.ema_scope():
455
+ _, loss_dict_ema = self.shared_step(batch)
456
+ loss_dict_ema = {key + '_ema': loss_dict_ema[key] for key in loss_dict_ema}
457
+ self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
458
+ self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
459
+
460
+ def on_train_batch_end(self, *args, **kwargs):
461
+ if self.use_ema:
462
+ self.model_ema(self.model)
463
+
464
+ def _get_rows_from_list(self, samples):
465
+ n_imgs_per_row = len(samples)
466
+ denoise_grid = rearrange(samples, 'n b c h w -> b n c h w')
467
+ denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')
468
+ denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
469
+ return denoise_grid
470
+
471
+ @torch.no_grad()
472
+ def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs):
473
+ log = dict()
474
+ x = self.get_input(batch, self.first_stage_key)
475
+ N = min(x.shape[0], N)
476
+ n_row = min(x.shape[0], n_row)
477
+ x = x.to(self.device)[:N]
478
+ log["inputs"] = x
479
+
480
+ # get diffusion row
481
+ diffusion_row = list()
482
+ x_start = x[:n_row]
483
+
484
+ for t in range(self.num_timesteps):
485
+ if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
486
+ t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
487
+ t = t.to(self.device).long()
488
+ noise = torch.randn_like(x_start)
489
+ x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
490
+ diffusion_row.append(x_noisy)
491
+
492
+ log["diffusion_row"] = self._get_rows_from_list(diffusion_row)
493
+
494
+ if sample:
495
+ # get denoise row
496
+ with self.ema_scope("Plotting"):
497
+ samples, denoise_row = self.sample(batch_size=N, return_intermediates=True)
498
+
499
+ log["samples"] = samples
500
+ log["denoise_row"] = self._get_rows_from_list(denoise_row)
501
+
502
+ if return_keys:
503
+ if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
504
+ return log
505
+ else:
506
+ return {key: log[key] for key in return_keys}
507
+ return log
508
+
509
+ def configure_optimizers(self):
510
+ lr = self.learning_rate
511
+ params = list(self.model.parameters())
512
+ if self.learn_logvar:
513
+ params = params + [self.logvar]
514
+ opt = torch.optim.AdamW(params, lr=lr)
515
+ return opt
516
+
517
+
518
+ class LatentDiffusion(DDPM):
519
+ """main class"""
520
+ def __init__(self,
521
+ first_stage_config,
522
+ cond_stage_config,
523
+ num_timesteps_cond=None,
524
+ cond_stage_key="image",
525
+ cond_stage_trainable=False,
526
+ concat_mode=True,
527
+ cond_stage_forward=None,
528
+ conditioning_key=None,
529
+ scale_factor=1.0,
530
+ scale_by_std=False,
531
+ context_embedding_dim=1024, # dim used for clip image encoder
532
+ *args, **kwargs):
533
+ self.num_timesteps_cond = default(num_timesteps_cond, 1)
534
+ self.scale_by_std = scale_by_std
535
+ assert self.num_timesteps_cond <= kwargs['timesteps']
536
+ # for backwards compatibility after implementation of DiffusionWrapper
537
+ if conditioning_key is None:
538
+ conditioning_key = 'concat' if concat_mode else 'crossattn'
539
+ if cond_stage_config == '__is_unconditional__':
540
+ conditioning_key = None
541
+ ckpt_path = kwargs.pop("ckpt_path", None)
542
+ ignore_keys = kwargs.pop("ignore_keys", [])
543
+ super().__init__(conditioning_key=conditioning_key, *args, **kwargs)
544
+ self.learnable_vector = nn.Parameter(torch.randn((1,1,768)), requires_grad=True)
545
+ self.proj_out=nn.Linear(context_embedding_dim, 768)
546
+ self.concat_mode = concat_mode
547
+ self.cond_stage_trainable = cond_stage_trainable
548
+ self.cond_stage_key = cond_stage_key
549
+ try:
550
+ self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
551
+ except:
552
+ self.num_downs = 0
553
+ if not scale_by_std:
554
+ self.scale_factor = scale_factor
555
+ else:
556
+ self.register_buffer('scale_factor', torch.tensor(scale_factor))
557
+ self.instantiate_first_stage(first_stage_config)
558
+ self.instantiate_cond_stage(cond_stage_config)
559
+ self.cond_stage_forward = cond_stage_forward
560
+ self.clip_denoised = False
561
+ self.bbox_tokenizer = None
562
+
563
+ self.restarted_from_ckpt = False
564
+ if ckpt_path is not None:
565
+ self.init_from_ckpt(ckpt_path, ignore_keys)
566
+ self.restarted_from_ckpt = True
567
+
568
+ def make_cond_schedule(self, ):
569
+ self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long)
570
+ ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long()
571
+ self.cond_ids[:self.num_timesteps_cond] = ids
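+ # cond_ids maps each diffusion timestep to a conditioning-noise timestep; with num_timesteps_cond == 1 the shortened schedule is never used.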
572
+
573
+ @rank_zero_only
574
+ @torch.no_grad()
575
+ def on_train_batch_start(self, batch, batch_idx):
576
+ # only for very first batch
577
+ if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and not self.restarted_from_ckpt:
578
+ assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously'
579
+ # set rescale weight to 1./std of encodings
580
+ print("### USING STD-RESCALING ###")
581
+ x = super().get_input(batch, self.first_stage_key)
582
+ x = x.to(self.device)
583
+ encoder_posterior = self.encode_first_stage(x)
584
+ z = self.get_first_stage_encoding(encoder_posterior).detach()
585
+ del self.scale_factor
586
+ self.register_buffer('scale_factor', 1. / z.flatten().std())
587
+ print(f"setting self.scale_factor to {self.scale_factor}")
588
+ print("### USING STD-RESCALING ###")
589
+
590
+ def register_schedule(self,
591
+ given_betas=None, beta_schedule="linear", timesteps=1000,
592
+ linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
593
+ super().register_schedule(given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s)
594
+
595
+ self.shorten_cond_schedule = self.num_timesteps_cond > 1
596
+ if self.shorten_cond_schedule:
597
+ self.make_cond_schedule()
598
+
599
+ def instantiate_first_stage(self, config):
600
+ model = instantiate_from_config(config)
601
+ self.first_stage_model = model.eval()
602
+ self.first_stage_model.train = disabled_train
603
+ for param in self.first_stage_model.parameters():
604
+ param.requires_grad = False
605
+
606
+ def instantiate_cond_stage(self, config):
607
+ if not self.cond_stage_trainable:
608
+ if config == "__is_first_stage__":
609
+ print("Using first stage also as cond stage.")
610
+ self.cond_stage_model = self.first_stage_model
611
+ elif config == "__is_unconditional__":
612
+ print(f"Training {self.__class__.__name__} as an unconditional model.")
613
+ self.cond_stage_model = None
614
+ # self.be_unconditional = True
615
+ else:
616
+ model = instantiate_from_config(config)
617
+ self.cond_stage_model = model.eval()
618
+ self.cond_stage_model.train = disabled_train
619
+ for param in self.cond_stage_model.parameters():
620
+ param.requires_grad = False
621
+ else:
622
+ assert config != '__is_first_stage__'
623
+ assert config != '__is_unconditional__'
624
+ model = instantiate_from_config(config)
625
+ self.cond_stage_model = model
626
+
627
+
628
+ def _get_denoise_row_from_list(self, samples, desc='', force_no_decoder_quantization=False):
629
+ denoise_row = []
630
+ for zd in tqdm(samples, desc=desc):
631
+ denoise_row.append(self.decode_first_stage(zd.to(self.device),
632
+ force_not_quantize=force_no_decoder_quantization))
633
+ n_imgs_per_row = len(denoise_row)
634
+ denoise_row = torch.stack(denoise_row) # n_log_step, n_row, C, H, W
635
+ denoise_grid = rearrange(denoise_row, 'n b c h w -> b n c h w')
636
+ denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')
637
+ denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
638
+ return denoise_grid
639
+
640
+ def get_first_stage_encoding(self, encoder_posterior):
641
+ if isinstance(encoder_posterior, DiagonalGaussianDistribution):
642
+ z = encoder_posterior.sample()
643
+ elif isinstance(encoder_posterior, torch.Tensor):
644
+ z = encoder_posterior
645
+ else:
646
+ raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented")
647
+ return self.scale_factor * z
648
+
649
+ def get_learned_conditioning(self, c):
650
+ if self.cond_stage_forward is None:
651
+ if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode):
652
+ c = self.cond_stage_model.encode(c)
653
+ if isinstance(c, DiagonalGaussianDistribution):
654
+ c = c.mode()
655
+ else:
656
+ c = self.cond_stage_model(c)
657
+ else:
658
+ assert hasattr(self.cond_stage_model, self.cond_stage_forward)
659
+ c = getattr(self.cond_stage_model, self.cond_stage_forward)(c)
660
+ return c
661
+
662
+
663
+ def meshgrid(self, h, w):
664
+ y = torch.arange(0, h).view(h, 1, 1).repeat(1, w, 1)
665
+ x = torch.arange(0, w).view(1, w, 1).repeat(h, 1, 1)
666
+
667
+ arr = torch.cat([y, x], dim=-1)
668
+ return arr
669
+
670
+ def delta_border(self, h, w):
671
+ """
672
+ :param h: height
673
+ :param w: width
674
+ :return: normalized distance to image border,
675
+ with min distance = 0 at the border and max distance = 0.5 at the image center
676
+ """
677
+ lower_right_corner = torch.tensor([h - 1, w - 1]).view(1, 1, 2)
678
+ arr = self.meshgrid(h, w) / lower_right_corner
679
+ dist_left_up = torch.min(arr, dim=-1, keepdims=True)[0]
680
+ dist_right_down = torch.min(1 - arr, dim=-1, keepdims=True)[0]
681
+ edge_dist = torch.min(torch.cat([dist_left_up, dist_right_down], dim=-1), dim=-1)[0]
682
+ return edge_dist
683
+
684
+ def get_weighting(self, h, w, Ly, Lx, device):
685
+ weighting = self.delta_border(h, w)
686
+ weighting = torch.clip(weighting, self.split_input_params["clip_min_weight"],
687
+ self.split_input_params["clip_max_weight"], )
688
+ weighting = weighting.view(1, h * w, 1).repeat(1, 1, Ly * Lx).to(device)
689
+
690
+ if self.split_input_params["tie_braker"]:
691
+ L_weighting = self.delta_border(Ly, Lx)
692
+ L_weighting = torch.clip(L_weighting,
693
+ self.split_input_params["clip_min_tie_weight"],
694
+ self.split_input_params["clip_max_tie_weight"])
695
+
696
+ L_weighting = L_weighting.view(1, 1, Ly * Lx).to(device)
697
+ weighting = weighting * L_weighting
698
+ return weighting
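+ # The weighting decays toward patch borders so that overlapping crops blend smoothly when folded back together.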
699
+
700
+ def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1): # TODO: load once, not every time; shorten code
701
+ """
702
+ :param x: img of size (bs, c, h, w)
703
+ :return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1])
704
+ """
705
+ bs, nc, h, w = x.shape
706
+
707
+ # number of crops in image
708
+ Ly = (h - kernel_size[0]) // stride[0] + 1
709
+ Lx = (w - kernel_size[1]) // stride[1] + 1
710
+
711
+ if uf == 1 and df == 1:
712
+ fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
713
+ unfold = torch.nn.Unfold(**fold_params)
714
+
715
+ fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params)
716
+
717
+ weighting = self.get_weighting(kernel_size[0], kernel_size[1], Ly, Lx, x.device).to(x.dtype)
718
+ normalization = fold(weighting).view(1, 1, h, w) # normalizes the overlap
719
+ weighting = weighting.view((1, 1, kernel_size[0], kernel_size[1], Ly * Lx))
720
+
721
+ elif uf > 1 and df == 1:
722
+ fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
723
+ unfold = torch.nn.Unfold(**fold_params)
724
+
725
+ fold_params2 = dict(kernel_size=(kernel_size[0] * uf, kernel_size[0] * uf),
726
+ dilation=1, padding=0,
727
+ stride=(stride[0] * uf, stride[1] * uf))
728
+ fold = torch.nn.Fold(output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2)
729
+
730
+ weighting = self.get_weighting(kernel_size[0] * uf, kernel_size[1] * uf, Ly, Lx, x.device).to(x.dtype)
731
+ normalization = fold(weighting).view(1, 1, h * uf, w * uf) # normalizes the overlap
732
+ weighting = weighting.view((1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx))
733
+
734
+ elif df > 1 and uf == 1:
735
+ fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
736
+ unfold = torch.nn.Unfold(**fold_params)
737
+
738
+ fold_params2 = dict(kernel_size=(kernel_size[0] // df, kernel_size[0] // df),
739
+ dilation=1, padding=0,
740
+ stride=(stride[0] // df, stride[1] // df))
741
+ fold = torch.nn.Fold(output_size=(x.shape[2] // df, x.shape[3] // df), **fold_params2)
742
+
743
+ weighting = self.get_weighting(kernel_size[0] // df, kernel_size[1] // df, Ly, Lx, x.device).to(x.dtype)
744
+ normalization = fold(weighting).view(1, 1, h // df, w // df) # normalizes the overlap
745
+ weighting = weighting.view((1, 1, kernel_size[0] // df, kernel_size[1] // df, Ly * Lx))
746
+
747
+ else:
748
+ raise NotImplementedError
749
+
750
+ return fold, unfold, normalization, weighting
751
+
752
+ @torch.no_grad()
753
+ def get_input(self, batch, k, return_first_stage_outputs=False, force_c_encode=False,
754
+ cond_key=None, return_original_cond=False, bs=None,get_mask=False,get_reference=False,get_inpaint=False, get_clean_ref=False, get_ref_rec=False,
755
+ get_changed_pixels=False):
756
+
757
+ x,inpaint,mask,reference, clean_reference, grid_transformed, changed_pixels = super().get_input(batch, k)
758
+ if bs is not None:
759
+ x = x[:bs]
760
+ inpaint = inpaint[:bs]
761
+ mask = mask[:bs]
762
+ reference = reference[:bs]
763
+ clean_reference = clean_reference[:bs]
764
+ grid_transformed = grid_transformed[:bs]
765
+ changed_pixels = changed_pixels[:bs]
766
+ x = x.to(self.device)
767
+ encoder_posterior = self.encode_first_stage(x)
768
+ z = self.get_first_stage_encoding(encoder_posterior).detach()
769
+ encoder_posterior_inpaint = self.encode_first_stage(inpaint)
770
+ z_inpaint = self.get_first_stage_encoding(encoder_posterior_inpaint).detach()
771
+
772
+ encoder_posterior_inpaint = self.encode_first_stage(clean_reference)
773
+ z_reference = self.get_first_stage_encoding(encoder_posterior_inpaint).detach()
774
+ # breakpoint()
775
+ mask_resize = Resize([z.shape[-1],z.shape[-1]])(mask)
776
+ grid_resized = Resize([z.shape[-1],z.shape[-1]])(grid_transformed)
777
+ z_new = torch.cat((z,z_inpaint,mask_resize, grid_resized),dim=1)
778
+ # z_new = torch.cat((z,z_inpaint,mask_resize, changed_pixels, grid_resized),dim=1)
779
+ # z_new = torch.cat((z,z_inpaint,mask_resize, grid_resized),dim=1)
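+ # z_new channel layout: 4 image latents, 4 inpaint latents, the resized mask, then the resized warp grid.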
780
+
781
+ if self.model.conditioning_key is not None:
782
+ if cond_key is None:
783
+ cond_key = self.cond_stage_key
784
+ if cond_key != self.first_stage_key:
785
+ if cond_key in ['txt','caption', 'coordinates_bbox']:
786
+ xc = batch[cond_key]
787
+ elif cond_key == 'image':
788
+ xc = reference
789
+ elif cond_key == 'class_label':
790
+ xc = batch
791
+ else:
792
+ xc = super().get_input(batch, cond_key).to(self.device)
793
+ else:
794
+ xc = x
795
+ if not self.cond_stage_trainable or force_c_encode:
796
+ if isinstance(xc, dict) or isinstance(xc, list):
797
+ # import pudb; pudb.set_trace()
798
+ c = self.get_learned_conditioning(xc)
799
+ else:
800
+ c = self.get_learned_conditioning(xc.to(self.device))
801
+ c = self.proj_out(c)
802
+ c = c.float()
803
+ else:
804
+ c = xc
805
+ if bs is not None:
806
+ c = c[:bs]
807
+
808
+ if self.use_positional_encodings:
809
+ pos_x, pos_y = self.compute_latent_shifts(batch)
810
+ ckey = __conditioning_keys__[self.model.conditioning_key]
811
+ c = {ckey: c, 'pos_x': pos_x, 'pos_y': pos_y}
812
+
813
+ else:
814
+ c = None
815
+ xc = None
816
+ if self.use_positional_encodings:
817
+ pos_x, pos_y = self.compute_latent_shifts(batch)
818
+ c = {'pos_x': pos_x, 'pos_y': pos_y}
819
+
820
+ # embed reference latent into cond
821
+ # c = [c, z_reference]
822
+ out = [z_new, c, z_reference]
823
+ if return_first_stage_outputs:
824
+ if self.first_stage_key=='inpaint':
825
+ xrec = self.decode_first_stage(z[:,:4,:,:])
826
+ else:
827
+ xrec = self.decode_first_stage(z)
828
+ out.extend([x, xrec])
829
+ if return_original_cond:
830
+ out.append(xc)
831
+ if get_mask:
832
+ out.append(mask)
833
+ if get_reference:
834
+ out.append(reference)
835
+ if get_inpaint:
836
+ out.append(inpaint)
837
+ if get_clean_ref:
838
+ out.append(clean_reference)
839
+ if get_ref_rec:
840
+ ref_rec = self.decode_first_stage(z_reference)
841
+ out.append(ref_rec)
842
+ if get_changed_pixels:
843
+ out.append(changed_pixels)
844
+ return out
845
+
846
+ @torch.no_grad()
847
+ def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
848
+ if predict_cids:
849
+ if z.dim() == 4:
850
+ z = torch.argmax(z.exp(), dim=1).long()
851
+ z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
852
+ z = rearrange(z, 'b h w c -> b c h w').contiguous()
853
+
854
+ z = 1. / self.scale_factor * z
855
+
856
+ if hasattr(self, "split_input_params"):
857
+ if self.split_input_params["patch_distributed_vq"]:
858
+ ks = self.split_input_params["ks"] # eg. (128, 128)
859
+ stride = self.split_input_params["stride"] # eg. (64, 64)
860
+ uf = self.split_input_params["vqf"]
861
+ bs, nc, h, w = z.shape
862
+ if ks[0] > h or ks[1] > w:
863
+ ks = (min(ks[0], h), min(ks[1], w))
864
+ print("reducing Kernel")
865
+
866
+ if stride[0] > h or stride[1] > w:
867
+ stride = (min(stride[0], h), min(stride[1], w))
868
+ print("reducing stride")
869
+
870
+ fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf)
871
+
872
+ z = unfold(z) # (bn, nc * prod(**ks), L)
873
+ # 1. Reshape to img shape
874
+ z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
875
+
876
+ # 2. apply model loop over last dim
877
+ if isinstance(self.first_stage_model, VQModelInterface):
878
+ output_list = [self.first_stage_model.decode(z[:, :, :, :, i],
879
+ force_not_quantize=predict_cids or force_not_quantize)
880
+ for i in range(z.shape[-1])]
881
+ else:
882
+
883
+ output_list = [self.first_stage_model.decode(z[:, :, :, :, i])
884
+ for i in range(z.shape[-1])]
885
+
886
+ o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L)
887
+ o = o * weighting
888
+ # Reverse 1. reshape to img shape
889
+ o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
890
+ # stitch crops together
891
+ decoded = fold(o)
892
+ decoded = decoded / normalization # norm is shape (1, 1, h, w)
893
+ return decoded
894
+ else:
895
+ if isinstance(self.first_stage_model, VQModelInterface):
896
+ return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
897
+ else:
898
+ return self.first_stage_model.decode(z)
899
+
900
+ else:
901
+ if isinstance(self.first_stage_model, VQModelInterface):
902
+ return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
903
+ else:
904
+ if self.first_stage_key=='inpaint':
905
+ return self.first_stage_model.decode(z[:,:4,:,:])
906
+ else:
907
+ return self.first_stage_model.decode(z)
908
+
909
+ # same as above but without decorator
910
+ def differentiable_decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
911
+ if predict_cids:
912
+ if z.dim() == 4:
913
+ z = torch.argmax(z.exp(), dim=1).long()
914
+ z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
915
+ z = rearrange(z, 'b h w c -> b c h w').contiguous()
916
+
917
+ z = 1. / self.scale_factor * z
918
+
919
+ if hasattr(self, "split_input_params"):
920
+ if self.split_input_params["patch_distributed_vq"]:
921
+ ks = self.split_input_params["ks"] # eg. (128, 128)
922
+ stride = self.split_input_params["stride"] # eg. (64, 64)
923
+ uf = self.split_input_params["vqf"]
924
+ bs, nc, h, w = z.shape
925
+ if ks[0] > h or ks[1] > w:
926
+ ks = (min(ks[0], h), min(ks[1], w))
927
+ print("reducing Kernel")
928
+
929
+ if stride[0] > h or stride[1] > w:
930
+ stride = (min(stride[0], h), min(stride[1], w))
931
+ print("reducing stride")
932
+
933
+ fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf)
934
+
935
+ z = unfold(z) # (bn, nc * prod(**ks), L)
936
+ # 1. Reshape to img shape
937
+ z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
938
+
939
+ # 2. apply model loop over last dim
940
+ if isinstance(self.first_stage_model, VQModelInterface):
941
+ output_list = [self.first_stage_model.decode(z[:, :, :, :, i],
942
+ force_not_quantize=predict_cids or force_not_quantize)
943
+ for i in range(z.shape[-1])]
944
+ else:
945
+
946
+ output_list = [self.first_stage_model.decode(z[:, :, :, :, i])
947
+ for i in range(z.shape[-1])]
948
+
949
+ o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L)
950
+ o = o * weighting
951
+ # Reverse 1. reshape to img shape
952
+ o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
953
+ # stitch crops together
954
+ decoded = fold(o)
955
+ decoded = decoded / normalization # norm is shape (1, 1, h, w)
956
+ return decoded
957
+ else:
958
+ if isinstance(self.first_stage_model, VQModelInterface):
959
+ return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
960
+ else:
961
+ return self.first_stage_model.decode(z)
962
+
963
+ else:
964
+ if isinstance(self.first_stage_model, VQModelInterface):
965
+ return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
966
+ else:
967
+ return self.first_stage_model.decode(z)
968
+
969
+ @torch.no_grad()
970
+ def encode_first_stage(self, x):
971
+ if hasattr(self, "split_input_params"):
972
+ if self.split_input_params["patch_distributed_vq"]:
973
+ ks = self.split_input_params["ks"] # eg. (128, 128)
974
+ stride = self.split_input_params["stride"] # eg. (64, 64)
975
+ df = self.split_input_params["vqf"]
976
+ self.split_input_params['original_image_size'] = x.shape[-2:]
977
+ bs, nc, h, w = x.shape
978
+ if ks[0] > h or ks[1] > w:
979
+ ks = (min(ks[0], h), min(ks[1], w))
980
+ print("reducing Kernel")
981
+
982
+ if stride[0] > h or stride[1] > w:
983
+ stride = (min(stride[0], h), min(stride[1], w))
984
+ print("reducing stride")
985
+
986
+ fold, unfold, normalization, weighting = self.get_fold_unfold(x, ks, stride, df=df)
987
+ z = unfold(x) # (bn, nc * prod(**ks), L)
988
+ # Reshape to img shape
989
+ z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
990
+
991
+ output_list = [self.first_stage_model.encode(z[:, :, :, :, i])
992
+ for i in range(z.shape[-1])]
993
+
994
+ o = torch.stack(output_list, axis=-1)
995
+ o = o * weighting
996
+
997
+ # Reverse reshape to img shape
998
+ o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
999
+ # stitch crops together
1000
+ decoded = fold(o)
1001
+ decoded = decoded / normalization
1002
+ return decoded
1003
+
1004
+ else:
1005
+ return self.first_stage_model.encode(x)
1006
+ else:
1007
+ return self.first_stage_model.encode(x)
1008
+
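Note: the patch-wise encode/decode paths above unfold the input into overlapping crops, run the first-stage model on each crop, and fold the results back while dividing by a normalization map. A minimal standalone sketch of that stitching step, assuming the per-crop function preserves spatial size and using plain averaging instead of the weighting map returned by get_fold_unfold (the name tiled_apply and its defaults are illustrative, not part of this codebase):

import torch

def tiled_apply(x, fn, ks=128, stride=64):
    # Run `fn` on overlapping ks x ks crops of x (B, C, H, W) and stitch the
    # outputs back together, averaging wherever crops overlap.
    B, C, H, W = x.shape
    unfold = torch.nn.Unfold(kernel_size=ks, stride=stride)
    fold = torch.nn.Fold(output_size=(H, W), kernel_size=ks, stride=stride)

    patches = unfold(x).view(B, C, ks, ks, -1)             # (B, C, ks, ks, L)
    outs = [fn(patches[..., i]) for i in range(patches.shape[-1])]
    o = torch.stack(outs, dim=-1).view(B, C * ks * ks, -1)

    counts = fold(unfold(torch.ones_like(x)))              # per-pixel overlap count
    return fold(o) / counts

With ks=128 and stride=64 this reproduces the overlapping-tile behaviour; the code above additionally multiplies each crop by a smooth weighting before folding.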
1009
+ def shared_step(self, batch, **kwargs):
1010
+ x, c, z_reference = self.get_input(batch, self.first_stage_key)
1011
+ loss = self(x, c, z_reference)
1012
+ return loss
1013
+
1014
+ def forward(self, x, c, z_reference, *args, **kwargs):
1015
+ t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
1016
+ self.u_cond_prop=random.uniform(0, 1)
1017
+ if self.model.conditioning_key is not None:
1018
+ assert c is not None
1019
+ if self.cond_stage_trainable:
1020
+ c = self.get_learned_conditioning(c)
1021
+ c = self.proj_out(c)
1022
+
1023
+ if self.shorten_cond_schedule: # pr_odo: drop this option
1024
+ tc = self.cond_ids[t].to(self.device)
1025
+ c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
1026
+
1027
+ if self.u_cond_prop<self.u_cond_percent:
1028
+ return self.p_losses(x, self.learnable_vector.repeat(x.shape[0],1,1), t, z_ref=z_reference, *args, **kwargs)
1029
+ else:
1030
+ return self.p_losses(x, c, t, z_ref=z_reference, *args, **kwargs)
1031
+
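Note: the `u_cond_prop < self.u_cond_percent` branch above is the usual classifier-free-guidance training trick: with some probability the real conditioning is swapped for a learned null vector so the network also learns an unconditional branch. A toy sketch of that decision, with hypothetical names (null_token, drop_prob) standing in for self.learnable_vector and self.u_cond_percent:

import torch

def maybe_drop_conditioning(cond, null_token, drop_prob=0.2):
    # With probability drop_prob, replace the batch's conditioning with a
    # learned null token (shape (1, n, d)) so the model also trains unconditionally.
    if torch.rand(()).item() < drop_prob:
        return null_token.repeat(cond.shape[0], 1, 1)
    return cond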
1032
+ def _rescale_annotations(self, bboxes, crop_coordinates): # pr_odo: move to dataset
1033
+ def rescale_bbox(bbox):
1034
+ x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2])
1035
+ y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3])
1036
+ w = min(bbox[2] / crop_coordinates[2], 1 - x0)
1037
+ h = min(bbox[3] / crop_coordinates[3], 1 - y0)
1038
+ return x0, y0, w, h
1039
+
1040
+ return [rescale_bbox(b) for b in bboxes]
1041
+
1042
+ def apply_model(self, x_noisy, t, cond, z_ref, return_ids=False):
1043
+
1044
+ if isinstance(cond, dict):
1045
+ # hybrid case, cond is expected to be a dict
1046
+ pass
1047
+ else:
1048
+ if not isinstance(cond, list):
1049
+ cond = [cond]
1050
+ key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn'
1051
+ cond = {key: cond}
1052
+
1053
+ if hasattr(self, "split_input_params"):
1054
+ raise ValueError('attempting to split input')
1055
+ # assert len(cond) == 1 # pr_odo can only deal with one conditioning atm
1056
+ # assert not return_ids
1057
+ # ks = self.split_input_params["ks"] # eg. (128, 128)
1058
+ # stride = self.split_input_params["stride"] # eg. (64, 64)
1059
+
1060
+ # h, w = x_noisy.shape[-2:]
1061
+
1062
+ # fold, unfold, normalization, weighting = self.get_fold_unfold(x_noisy, ks, stride)
1063
+
1064
+ # z = unfold(x_noisy) # (bn, nc * prod(**ks), L)
1065
+ # # Reshape to img shape
1066
+ # z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
1067
+ # z_list = [z[:, :, :, :, i] for i in range(z.shape[-1])]
1068
+
1069
+ # if self.cond_stage_key in ["image", "LR_image", "segmentation",
1070
+ # 'bbox_img'] and self.model.conditioning_key: # pr_odo check for completeness
1071
+ # c_key = next(iter(cond.keys())) # get key
1072
+ # c = next(iter(cond.values())) # get value
1073
+ # assert (len(c) == 1) # pr_odo extend to list with more than one elem
1074
+ # c = c[0] # get element
1075
+
1076
+ # c = unfold(c)
1077
+ # c = c.view((c.shape[0], -1, ks[0], ks[1], c.shape[-1])) # (bn, nc, ks[0], ks[1], L )
1078
+
1079
+ # cond_list = [{c_key: [c[:, :, :, :, i]]} for i in range(c.shape[-1])]
1080
+
1081
+ # elif self.cond_stage_key == 'coordinates_bbox':
1082
+ # assert 'original_image_size' in self.split_input_params, 'BoudingBoxRescaling is missing original_image_size'
1083
+
1084
+ # # assuming padding of unfold is always 0 and its dilation is always 1
1085
+ # n_patches_per_row = int((w - ks[0]) / stride[0] + 1)
1086
+ # full_img_h, full_img_w = self.split_input_params['original_image_size']
1087
+ # # as we are operating on latents, we need the factor from the original image size to the
1088
+ # # spatial latent size to properly rescale the crops for regenerating the bbox annotations
1089
+ # num_downs = self.first_stage_model.encoder.num_resolutions - 1
1090
+ # rescale_latent = 2 ** (num_downs)
1091
+
1092
+ # # get top left positions of patches as conforming for the bbbox tokenizer, therefore we
1093
+ # # need to rescale the tl patch coordinates to be in between (0,1)
1094
+ # tl_patch_coordinates = [(rescale_latent * stride[0] * (patch_nr % n_patches_per_row) / full_img_w,
1095
+ # rescale_latent * stride[1] * (patch_nr // n_patches_per_row) / full_img_h)
1096
+ # for patch_nr in range(z.shape[-1])]
1097
+
1098
+ # # patch_limits are tl_coord, width and height coordinates as (x_tl, y_tl, h, w)
1099
+ # patch_limits = [(x_tl, y_tl,
1100
+ # rescale_latent * ks[0] / full_img_w,
1101
+ # rescale_latent * ks[1] / full_img_h) for x_tl, y_tl in tl_patch_coordinates]
1102
+ # # patch_values = [(np.arange(x_tl,min(x_tl+ks, 1.)),np.arange(y_tl,min(y_tl+ks, 1.))) for x_tl, y_tl in tl_patch_coordinates]
1103
+
1104
+ # # tokenize crop coordinates for the bounding boxes of the respective patches
1105
+ # patch_limits_tknzd = [torch.LongTensor(self.bbox_tokenizer._crop_encoder(bbox))[None].to(self.device)
1106
+ # for bbox in patch_limits] # list of length l with tensors of shape (1, 2)
1107
+ # print(patch_limits_tknzd[0].shape)
1108
+ # # cut tknzd crop position from conditioning
1109
+ # assert isinstance(cond, dict), 'cond must be dict to be fed into model'
1110
+ # cut_cond = cond['c_crossattn'][0][..., :-2].to(self.device)
1111
+
1112
+ # adapted_cond = torch.stack([torch.cat([cut_cond, p], dim=1) for p in patch_limits_tknzd])
1113
+ # adapted_cond = rearrange(adapted_cond, 'l b n -> (l b) n')
1114
+ # adapted_cond = self.get_learned_conditioning(adapted_cond)
1115
+ # adapted_cond = rearrange(adapted_cond, '(l b) n d -> l b n d', l=z.shape[-1])
1116
+
1117
+ # cond_list = [{'c_crossattn': [e]} for e in adapted_cond]
1118
+
1119
+ # else:
1120
+ # cond_list = [cond for i in range(z.shape[-1])] # pr_odo make this more efficient
1121
+
1122
+ # # apply model by loop over crops
1123
+ # output_list = [self.model(z_list[i], t, **cond_list[i]) for i in range(z.shape[-1])]
1124
+ # assert not isinstance(output_list[0],
1125
+ # tuple) # pr_odo cant deal with multiple model outputs check this never happens
1126
+
1127
+ # o = torch.stack(output_list, axis=-1)
1128
+ # o = o * weighting
1129
+ # # Reverse reshape to img shape
1130
+ # o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
1131
+ # # stitch crops together
1132
+ # x_recon = fold(o) / normalization
1133
+
1134
+ else:
1135
+ # TODO address passing ref
1136
+ zeroed_out_warped_latent = x_noisy.clone()
1137
+ if self.remove_warped_latent:
1138
+ zeroed_out_warped_latent[:,4:8] *= 0.0
1139
+ x_recon = self.model(zeroed_out_warped_latent, t, z_ref=z_ref, **cond)
1140
+
1141
+ if isinstance(x_recon, tuple) and not return_ids:
1142
+ return x_recon[0]
1143
+ else:
1144
+ return x_recon
1145
+
1146
+ def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
1147
+ return (extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart) / \
1148
+ extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
1149
+
1150
+ def _prior_bpd(self, x_start):
1151
+ """
1152
+ Get the prior KL term for the variational lower-bound, measured in
1153
+ bits-per-dim.
1154
+ This term can't be optimized, as it only depends on the encoder.
1155
+ :param x_start: the [N x C x ...] tensor of inputs.
1156
+ :return: a batch of [N] KL values (in bits), one per batch element.
1157
+ """
1158
+ batch_size = x_start.shape[0]
1159
+ t = torch.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
1160
+ qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
1161
+ kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0)
1162
+ return mean_flat(kl_prior) / np.log(2.0)
1163
+
1164
+ def p_losses(self, x_start, cond, t, z_ref, noise=None):
1165
+ if self.first_stage_key == 'inpaint':
1166
+ # x_start=x_start[:,:4,:,:]
1167
+ latents = x_start[:,:4,:,:]
1168
+ latents_warped = x_start[:,4:8,:,:]
1169
+ noise = default(noise, lambda: torch.randn_like(x_start[:,:4,:,:]))
1170
+ # offset noise
1171
+ # noise += 0.05 * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)
1172
+ # TODO address the reference latent
1173
+ # warped_mask = t > self.sd_edit_step
1174
+
1175
+ x_noisy = self.q_sample(x_start=latents, t=t, noise=noise)
1176
+ # warped_noisy = self.q_sample(x_start=latents_warped, t=t, noise=noise)
1177
+ # x_noisy[warped_mask] = warped_noisy[warped_mask]
1178
+
1179
+ # TODO add here
1180
+ remove_latent_prob=random.uniform(0, 1)
1181
+
1182
+ if remove_latent_prob < self.dropping_warped_latent_prob:
1183
+ modified_x_start = x_start.clone()
1184
+ # dropping warped latent and mask
1185
+ modified_x_start[:, 4:9] *= 0.0
1186
+
1187
+ # print('using modified x start')
1188
+ x_noisy = torch.cat((x_noisy,modified_x_start[:,4:,:,:]),dim=1)
1189
+ else:
1190
+ x_noisy = torch.cat((x_noisy,x_start[:,4:,:,:]),dim=1)
1191
+ else:
1192
+ noise = default(noise, lambda: torch.randn_like(x_start))
1193
+ x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
1194
+ model_output = self.apply_model(x_noisy, t, cond, z_ref)
1195
+
1196
+ loss_dict = {}
1197
+ prefix = 'train' if self.training else 'val'
1198
+
1199
+ if self.parameterization == "x0":
1200
+ target = x_start
1201
+ elif self.parameterization == "eps":
1202
+ target = noise
1203
+ else:
1204
+ raise NotImplementedError()
1205
+
1206
+ loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3])
1207
+ loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()})
1208
+
1209
+ self.logvar = self.logvar.to(self.device)
1210
+ logvar_t = self.logvar[t].to(self.device)
1211
+ loss = loss_simple / torch.exp(logvar_t) + logvar_t
1212
+ # loss = loss_simple / torch.exp(self.logvar) + self.logvar
1213
+ if self.learn_logvar:
1214
+ loss_dict.update({f'{prefix}/loss_gamma': loss.mean()})
1215
+ loss_dict.update({'logvar': self.logvar.data.mean()})
1216
+
1217
+ loss = self.l_simple_weight * loss.mean()
1218
+
1219
+ loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3))
1220
+ loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean()
1221
+ loss_dict.update({f'{prefix}/loss_vlb': loss_vlb})
1222
+ loss += (self.original_elbo_weight * loss_vlb)
1223
+ loss_dict.update({f'{prefix}/loss': loss})
1224
+
1225
+ return loss, loss_dict
1226
+
1227
+ def p_mean_variance(self, x, c, t, clip_denoised: bool, return_codebook_ids=False, quantize_denoised=False,
1228
+ return_x0=False, score_corrector=None, corrector_kwargs=None, z_ref=None):
1229
+ t_in = t
1230
+ #TODO pass reference
1231
+ model_out = self.apply_model(x, t_in, c, z_ref=z_ref, return_ids=return_codebook_ids)
1232
+
1233
+ if score_corrector is not None:
1234
+ assert self.parameterization == "eps"
1235
+ model_out = score_corrector.modify_score(self, model_out, x, t, c, **corrector_kwargs)
1236
+
1237
+ if return_codebook_ids:
1238
+ model_out, logits = model_out
1239
+
1240
+ if self.parameterization == "eps":
1241
+ x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
1242
+ elif self.parameterization == "x0":
1243
+ x_recon = model_out
1244
+ else:
1245
+ raise NotImplementedError()
1246
+
1247
+ if clip_denoised:
1248
+ x_recon.clamp_(-1., 1.)
1249
+ if quantize_denoised:
1250
+ x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon)
1251
+ model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
1252
+ if return_codebook_ids:
1253
+ return model_mean, posterior_variance, posterior_log_variance, logits
1254
+ elif return_x0:
1255
+ return model_mean, posterior_variance, posterior_log_variance, x_recon
1256
+ else:
1257
+ return model_mean, posterior_variance, posterior_log_variance
1258
+
1259
+ @torch.no_grad()
1260
+ def p_sample(self, x, c, t, z_ref=None, clip_denoised=False, repeat_noise=False,
1261
+ return_codebook_ids=False, quantize_denoised=False, return_x0=False,
1262
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None):
1263
+ b, *_, device = *x.shape, x.device
1264
+ outputs = self.p_mean_variance(x=x, c=c, t=t, z_ref=z_ref, clip_denoised=clip_denoised,
1265
+ return_codebook_ids=return_codebook_ids,
1266
+ quantize_denoised=quantize_denoised,
1267
+ return_x0=return_x0,
1268
+ score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
1269
+ if return_codebook_ids:
1270
+ raise DeprecationWarning("Support dropped.")
1271
+ model_mean, _, model_log_variance, logits = outputs
1272
+ elif return_x0:
1273
+ model_mean, _, model_log_variance, x0 = outputs
1274
+ else:
1275
+ model_mean, _, model_log_variance = outputs
1276
+
1277
+ noise = noise_like(x.shape, device, repeat_noise) * temperature
1278
+ if noise_dropout > 0.:
1279
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
1280
+ # no noise when t == 0
1281
+ nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
1282
+
1283
+ if return_codebook_ids:
1284
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, logits.argmax(dim=1)
1285
+ if return_x0:
1286
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, x0
1287
+ else:
1288
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
1289
+
1290
+ @torch.no_grad()
1291
+ def progressive_denoising(self, cond, shape, z_ref=None, verbose=True, callback=None, quantize_denoised=False,
1292
+ img_callback=None, mask=None, x0=None, temperature=1., noise_dropout=0.,
1293
+ score_corrector=None, corrector_kwargs=None, batch_size=None, x_T=None, start_T=None,
1294
+ log_every_t=None):
1295
+ if not log_every_t:
1296
+ log_every_t = self.log_every_t
1297
+ timesteps = self.num_timesteps
1298
+ if batch_size is not None:
1299
+ b = batch_size if batch_size is not None else shape[0]
1300
+ shape = [batch_size] + list(shape)
1301
+ else:
1302
+ b = batch_size = shape[0]
1303
+ if x_T is None:
1304
+ img = torch.randn(shape, device=self.device)
1305
+ else:
1306
+ img = x_T
1307
+ intermediates = []
1308
+ if cond is not None:
1309
+ if isinstance(cond, dict):
1310
+ cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
1311
+ list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
1312
+ else:
1313
+ cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
1314
+
1315
+ if start_T is not None:
1316
+ timesteps = min(timesteps, start_T)
1317
+ iterator = tqdm(reversed(range(0, timesteps)), desc='Progressive Generation',
1318
+ total=timesteps) if verbose else reversed(
1319
+ range(0, timesteps))
1320
+ if type(temperature) == float:
1321
+ temperature = [temperature] * timesteps
1322
+
1323
+ for i in iterator:
1324
+ ts = torch.full((b,), i, device=self.device, dtype=torch.long)
1325
+ if self.shorten_cond_schedule:
1326
+ assert self.model.conditioning_key != 'hybrid'
1327
+ tc = self.cond_ids[ts].to(cond.device)
1328
+ cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
1329
+
1330
+ img, x0_partial = self.p_sample(img, cond, ts, z_ref=z_ref,
1331
+ clip_denoised=self.clip_denoised,
1332
+ quantize_denoised=quantize_denoised, return_x0=True,
1333
+ temperature=temperature[i], noise_dropout=noise_dropout,
1334
+ score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
1335
+ if mask is not None:
1336
+ assert x0 is not None
1337
+ img_orig = self.q_sample(x0, ts)
1338
+ img = img_orig * mask + (1. - mask) * img
1339
+
1340
+ if i % log_every_t == 0 or i == timesteps - 1:
1341
+ intermediates.append(x0_partial)
1342
+ if callback: callback(i)
1343
+ if img_callback: img_callback(img, i)
1344
+ return img, intermediates
1345
+
1346
+ @torch.no_grad()
1347
+ def p_sample_loop(self, cond, shape, z_ref=None, return_intermediates=False,
1348
+ x_T=None, verbose=True, callback=None, timesteps=None, quantize_denoised=False,
1349
+ mask=None, x0=None, img_callback=None, start_T=None,
1350
+ log_every_t=None):
1351
+
1352
+ if not log_every_t:
1353
+ log_every_t = self.log_every_t
1354
+ device = self.betas.device
1355
+ b = shape[0]
1356
+ if x_T is None:
1357
+ img = torch.randn(shape, device=device)
1358
+ else:
1359
+ img = x_T
1360
+
1361
+ intermediates = [img]
1362
+ if timesteps is None:
1363
+ timesteps = self.num_timesteps
1364
+
1365
+ if start_T is not None:
1366
+ timesteps = min(timesteps, start_T)
1367
+ iterator = tqdm(reversed(range(0, timesteps)), desc='Sampling t', total=timesteps) if verbose else reversed(
1368
+ range(0, timesteps))
1369
+
1370
+ if mask is not None:
1371
+ assert x0 is not None
1372
+ assert x0.shape[2:3] == mask.shape[2:3] # spatial size has to match
1373
+
1374
+ for i in iterator:
1375
+ ts = torch.full((b,), i, device=device, dtype=torch.long)
1376
+ if self.shorten_cond_schedule:
1377
+ assert self.model.conditioning_key != 'hybrid'
1378
+ tc = self.cond_ids[ts].to(cond.device)
1379
+ cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
1380
+
1381
+ img = self.p_sample(img, cond, ts, z_ref=z_ref,
1382
+ clip_denoised=self.clip_denoised,
1383
+ quantize_denoised=quantize_denoised)
1384
+ if mask is not None:
1385
+ img_orig = self.q_sample(x0, ts)
1386
+ img = img_orig * mask + (1. - mask) * img
1387
+
1388
+ if i % log_every_t == 0 or i == timesteps - 1:
1389
+ intermediates.append(img)
1390
+ if callback: callback(i)
1391
+ if img_callback: img_callback(img, i)
1392
+
1393
+ if return_intermediates:
1394
+ return img, intermediates
1395
+ return img
1396
+
1397
+ @torch.no_grad()
1398
+ def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None,
1399
+ verbose=True, timesteps=None, quantize_denoised=False,
1400
+ mask=None, x0=None, shape=None,**kwargs):
1401
+ if shape is None:
1402
+ shape = (batch_size, self.channels, self.image_size, self.image_size)
1403
+ if cond is not None:
1404
+ if isinstance(cond, dict):
1405
+ cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
1406
+ list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
1407
+ else:
1408
+ cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
1409
+ return self.p_sample_loop(cond,
1410
+ shape,
1411
+ return_intermediates=return_intermediates, x_T=x_T,
1412
+ verbose=verbose, timesteps=timesteps, quantize_denoised=quantize_denoised,
1413
+ mask=mask, x0=x0)
1414
+
1415
+ @torch.no_grad()
1416
+ def sample_log(self,cond,batch_size,ddim, ddim_steps, z_ref=None, full_z=None,**kwargs):
1417
+
1418
+ if ddim:
1419
+ ddim_sampler = DDIMSampler(self)
1420
+ shape = (self.channels, self.image_size, self.image_size)
1421
+ z_inpaint = full_z[:,4:8]
1422
+ step=1
1423
+
1424
+
1425
+ samples, intermediates =ddim_sampler.sample(ddim_steps,batch_size,
1426
+ shape,cond, z_ref=z_ref,verbose=False, x0=z_inpaint,
1427
+ x0_step=step,**kwargs)
1428
+
1429
+ else:
1430
+ samples, intermediates = self.sample(cond=cond, batch_size=batch_size,
1431
+ return_intermediates=True,**kwargs)
1432
+
1433
+ return samples, intermediates
1434
+
1435
+
1436
+ @torch.no_grad()
1437
+ def log_images(self, batch, N=4, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,
1438
+ quantize_denoised=True, inpaint=False, plot_denoise_rows=False, plot_progressive_rows=False,
1439
+ plot_diffusion_rows=True, **kwargs):
1440
+
1441
+ use_ddim = ddim_steps is not None
1442
+
1443
+ log = dict()
1444
+
1445
+ z, c, z_ref, x, xrec, xc, mask, reference, inpaint_img, clean_ref, ref_rec, changed_pixels = self.get_input(batch, self.first_stage_key,
1446
+ return_first_stage_outputs=True,
1447
+ force_c_encode=True,
1448
+ return_original_cond=True,
1449
+ get_mask=True,
1450
+ get_reference=True,
1451
+ get_inpaint=True,
1452
+ bs=N,
1453
+ get_clean_ref=True,
1454
+ get_ref_rec=True,
1455
+ get_changed_pixels=True)
1456
+
1457
+ N = min(x.shape[0], N)
1458
+ n_row = min(x.shape[0], n_row)
1459
+ log["inputs"] = x
1460
+ log["reconstruction"] = xrec
1461
+ log["mask"]=mask
1462
+ log['changed_pixels'] = changed_pixels
1463
+ log["warped"]=inpaint_img
1464
+ log["original"] = clean_ref
1465
+ log["ref_rec"] = ref_rec
1466
+ # log["reference"]=reference
1467
+ if self.model.conditioning_key is not None:
1468
+ if hasattr(self.cond_stage_model, "decode"):
1469
+ xc = self.cond_stage_model.decode(c)
1470
+ log["conditioning"] = xc
1471
+ elif self.cond_stage_key in ["caption","txt"]:
1472
+ xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key])
1473
+ log["conditioning"] = xc
1474
+ elif self.cond_stage_key == 'class_label':
1475
+ xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"])
1476
+ log['conditioning'] = xc
1477
+ elif isimage(xc):
1478
+ log["conditioning"] = xc
1479
+ if ismap(xc):
1480
+ log["original_conditioning"] = self.to_rgb(xc)
1481
+
1482
+ if plot_diffusion_rows:
1483
+ # get diffusion row
1484
+ diffusion_row = list()
1485
+ z_start = z[:n_row]
1486
+ for t in range(self.num_timesteps):
1487
+ if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
1488
+ t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
1489
+ t = t.to(self.device).long()
1490
+ noise = torch.randn_like(z_start)
1491
+ z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
1492
+ diffusion_row.append(self.decode_first_stage(z_noisy))
1493
+
1494
+ diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W
1495
+ diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
1496
+ diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
1497
+ diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
1498
+ log["diffusion_row"] = diffusion_grid
1499
+
1500
+ if sample:
1501
+ # get denoise row
1502
+ with self.ema_scope("Plotting"):
1503
+ if self.first_stage_key=='inpaint':
1504
+ samples, z_denoise_row = self.sample_log(cond=c, z_ref=z_ref,batch_size=N,ddim=use_ddim, full_z=z,
1505
+ ddim_steps=ddim_steps,eta=ddim_eta,rest=z[:,4:,:,:])
1506
+ else:
1507
+ samples, z_denoise_row = self.sample_log(cond=c, z_ref=z_ref,batch_size=N,ddim=use_ddim,
1508
+ ddim_steps=ddim_steps,eta=ddim_eta)
1509
+ # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
1510
+ x_samples = self.decode_first_stage(samples)
1511
+ log["samples"] = x_samples
1512
+ if plot_denoise_rows:
1513
+ denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
1514
+ log["denoise_row"] = denoise_grid
1515
+
1516
+ if quantize_denoised and not isinstance(self.first_stage_model, AutoencoderKL) and not isinstance(
1517
+ self.first_stage_model, IdentityFirstStage):
1518
+ # also display when quantizing x0 while sampling
1519
+ with self.ema_scope("Plotting Quantized Denoised"):
1520
+ samples, z_denoise_row = self.sample_log(cond=c, z_ref=z_ref, batch_size=N,ddim=use_ddim,
1521
+ ddim_steps=ddim_steps,eta=ddim_eta,
1522
+ quantize_denoised=True)
1523
+ # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True,
1524
+ # quantize_denoised=True)
1525
+ x_samples = self.decode_first_stage(samples.to(self.device))
1526
+ log["samples_x0_quantized"] = x_samples
1527
+
1528
+ if inpaint:
1529
+ # make a simple center square
1530
+ b, h, w = z.shape[0], z.shape[2], z.shape[3]
1531
+ mask = torch.ones(N, h, w).to(self.device)
1532
+ # zeros will be filled in
1533
+ mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0.
1534
+ mask = mask[:, None, ...]
1535
+ with self.ema_scope("Plotting Inpaint"):
1536
+
1537
+ samples, _ = self.sample_log(cond=c, z_ref=z_ref,batch_size=N,ddim=use_ddim, eta=ddim_eta,
1538
+ ddim_steps=ddim_steps, x0=z[:N,:4], mask=mask)
1539
+ x_samples = self.decode_first_stage(samples.to(self.device))
1540
+ log["samples_inpainting"] = x_samples
1541
+ log["mask"] = mask
1542
+
1543
+ # outpaint
1544
+ with self.ema_scope("Plotting Outpaint"):
1545
+ samples, _ = self.sample_log(cond=c, z_ref=z_ref, batch_size=N, ddim=use_ddim,eta=ddim_eta,
1546
+ ddim_steps=ddim_steps, x0=z[:N], mask=mask)
1547
+ x_samples = self.decode_first_stage(samples.to(self.device))
1548
+ log["samples_outpainting"] = x_samples
1549
+
1550
+ if plot_progressive_rows:
1551
+ with self.ema_scope("Plotting Progressives"):
1552
+ img, progressives = self.progressive_denoising(c,
1553
+ z_ref=z_ref,
1554
+ shape=(self.channels, self.image_size, self.image_size),
1555
+ batch_size=N)
1556
+ prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation")
1557
+ log["progressive_row"] = prog_row
1558
+
1559
+ if return_keys:
1560
+ if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
1561
+ return log
1562
+ else:
1563
+ return {key: log[key] for key in return_keys}
1564
+ return log
1565
+
1566
+ def configure_optimizers(self):
1567
+ lr = self.learning_rate
1568
+ params = list(self.model.parameters())
1569
+
1570
+
1571
+
1572
+ if self.cond_stage_trainable:
1573
+ print(f"{self.__class__.__name__}: Also optimizing conditioner params!")
1574
+ # need to add final_ln.parameters() TODO
1575
+ params = params + list(self.cond_stage_model.final_ln.parameters())+list(self.cond_stage_model.mapper.parameters())+list(self.proj_out.parameters())
1576
+ if self.learn_logvar:
1577
+ print('Diffusion model optimizing logvar')
1578
+ params.append(self.logvar)
1579
+ params.append(self.learnable_vector)
1580
+ opt = torch.optim.AdamW(params, lr=lr)
1581
+ if self.use_scheduler:
1582
+ assert 'target' in self.scheduler_config
1583
+ scheduler = instantiate_from_config(self.scheduler_config)
1584
+
1585
+ print("Setting up LambdaLR scheduler...")
1586
+ scheduler = [
1587
+ {
1588
+ 'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule),
1589
+ 'interval': 'step',
1590
+ 'frequency': 1
1591
+ }]
1592
+ return [opt], scheduler
1593
+ return opt
1594
+
1595
+ @torch.no_grad()
1596
+ def to_rgb(self, x):
1597
+ x = x.float()
1598
+ if not hasattr(self, "colorize"):
1599
+ self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x)
1600
+ x = nn.functional.conv2d(x, weight=self.colorize)
1601
+ x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
1602
+ return x
1603
+
1604
+
1605
+ class DiffusionWrapper(pl.LightningModule):
1606
+ def __init__(self, diff_model_config, conditioning_key, sqrt_alphas_cumprod=None, sqrt_one_minus_alphas_cumprod=None, ddpm_parent=None):
1607
+ super().__init__()
1608
+ self.diffusion_model = instantiate_from_config(diff_model_config)
1609
+ self.conditioning_key = conditioning_key
1610
+ assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm', 'crossref', 'rewarp', 'rewarp_grid']
1611
+ # self.save_folder = '/mnt/localssd/collage_latents_lovely_new_data'
1612
+ # self.save_counter = 0
1613
+ # self.save_subfolder = None
1614
+
1615
+ # os.makedirs(self.save_folder, exist_ok=True)
1616
+ self.sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod
1617
+ self.sqrt_alphas_cumprod = sqrt_alphas_cumprod
1618
+ self.og_grid = None
1619
+ self.transformed_grid = None
1620
+ if self.conditioning_key == 'crossref' or 'rewarp' in self.conditioning_key:
1621
+ self.reference_model = copy.deepcopy(self.diffusion_model)
1622
+
1623
+
1624
+ def get_grid(self, size, batch_size):
1625
+ # raise ValueError TODO Fix
1626
+ y = np.repeat(np.arange(size)[None, ...], size)
1627
+ y = y.reshape(size, size)
1628
+ x = y.transpose()
1629
+ out = np.stack([y,x], 0)
1630
+ out = torch.tensor(out)
1631
+ out = out.unsqueeze(0)
1632
+ out = out.repeat(batch_size, 1, 1, 1)
1633
+ return out
1634
+
1635
+ def compute_correspondences(self, grid_transformed, masks, original_size=512, add_grids=False):
1636
+ # create the correspondence map for all the needed sizes
1637
+ corresp_indices = {}
1638
+ batch_size = grid_transformed.shape[0]
1639
+
1640
+ if self.og_grid is None:
1641
+ grid_og = self.get_grid(original_size, batch_size).to(grid_transformed.device) / float(original_size)
1642
+ else:
1643
+ grid_og = self.og_grid
1644
+
1645
+
1646
+ for d in [8, 16, 32, 64]:
1647
+ resized_grid_1 = torchvision.transforms.functional.resize(grid_og, size=(d,d))
1648
+ resized_grid_2 = torchvision.transforms.functional.resize(grid_transformed, size=(d,d))
1649
+ # the mask is at 64x64. 1 means exist in image. 0 is missing (needs inpainting)
1650
+ resized_mask = torchvision.transforms.functional.resize(masks, size=(d,d))
1651
+
1652
+ missing_mask = resized_mask.squeeze(1) < 0.7 #torch.sum(resized_grid_2, dim=1) < 0.1
1653
+
1654
+ src_grid = resized_grid_1.permute(0,2,3,1) # B x 2 x d x d
1655
+ guide_grid = resized_grid_2.permute(0,2,3,1)
1656
+
1657
+ src1_flat = src_grid.reshape(batch_size, d**2, 2)
1658
+ src2_flat = guide_grid.reshape(batch_size, d**2, 2)
1659
+ missing_flat = missing_mask.reshape(batch_size, d**2)
1660
+
1661
+ torch_dist = torch.cdist(src2_flat.float(), src1_flat.float())
1662
+ # print('torch dist shape for d', d, torch_dist.shape)
1663
+
1664
+ # missing_masks[d] = missing_flat
1665
+ min_indices = torch.argmin(torch_dist, dim=-1)
1666
+ # min_indices.requires_grad = False
1667
+ # missing_flat.requires_grad = False
1668
+ if add_grids:
1669
+ corresp_indices[d] = (min_indices, missing_flat, resized_grid_1, resized_grid_2)
1670
+ else:
1671
+ corresp_indices[d] = (min_indices, missing_flat)
1672
+ return corresp_indices #, missing_masks
1673
+
1674
+ def q_sample(self, x_start, t, noise=None):
1675
+ noise = default(noise, lambda: torch.randn_like(x_start))
1676
+ sqrt_alphas_cumprod = self.sqrt_alphas_cumprod.to(x_start.device)
1677
+ sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod.to(x_start.device)
1678
+ return (extract_into_tensor(sqrt_alphas_cumprod, t, x_start.shape) * x_start +
1679
+ extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
1680
+
1681
+ def forward(self, x, t, c_concat: list = None, c_crossattn: list = None, z_ref = None):
1682
+ num_ch = x.shape[1]
1683
+ # print(num_ch)
1684
+ if num_ch >= 11:
1685
+ self.transformed_grid = x[:, -2:]
1686
+ x = x[:, :-2]
1687
+ # else:
1688
+ # grid_transformed = None
1689
+
1690
+ if self.conditioning_key is None:
1691
+ out = self.diffusion_model(x, t)
1692
+ elif self.conditioning_key == 'concat':
1693
+ xc = torch.cat([x] + c_concat, dim=1)
1694
+ out = self.diffusion_model(xc, t)
1695
+ elif self.conditioning_key == 'crossattn':
1696
+ cc = torch.cat(c_crossattn, 1)
1697
+ out = self.diffusion_model(x, t, context=cc)
1698
+
1699
+ # self.save_subfolder = f'{self.save_folder}/saved_{time.time()}'
1700
+ # os.makedirs(self.save_subfolder, exist_ok=True)
1701
+ # # just for saving purposes
1702
+ # assert z_ref is not None
1703
+ # noisy_z_ref = self.q_sample(z_ref, t)
1704
+ # # z_new = torch.cat((z,z_inpaint,mask_resize),dim=1)
1705
+
1706
+ # mask = x[:, -1:]
1707
+ # z_ref_concat = torch.cat([noisy_z_ref, z_ref, mask], dim=1)
1708
+
1709
+ # correspondeces = self.compute_correspondences(self.transformed_grid, mask, original_size=512, add_grids=True)
1710
+
1711
+ # if self.save_counter < 50:
1712
+ # torch.save(x.cpu(), f'{self.save_subfolder}/z_collage_concat.pt' )
1713
+ # torch.save(z_ref_concat.cpu(), f'{self.save_subfolder}/z_ref_concat.pt')
1714
+ # torch.save(correspondeces, f'{self.save_subfolder}/corresps.pt')
1715
+ # self.save_counter += 1
1716
+
1717
+
1718
+ elif self.conditioning_key == 'hybrid':
1719
+ xc = torch.cat([x] + c_concat, dim=1)
1720
+ cc = torch.cat(c_crossattn, 1)
1721
+ out = self.diffusion_model(xc, t, context=cc)
1722
+ elif self.conditioning_key == 'adm':
1723
+ cc = c_crossattn[0]
1724
+ out = self.diffusion_model(x, t, y=cc)
1725
+ # elif self.conditioning_key == 'crossref':
1726
+ # cc = torch.cat(c_crossattn, 1)
1727
+ # # qsample z_ref by t to add noise
1728
+ # # so have noisy z_ref + z_ref + mask
1729
+ # # compute contexts
1730
+ # assert z_ref is not None
1731
+ # noisy_z_ref = self.q_sample(z_ref, t)
1732
+ # # z_new = torch.cat((z,z_inpaint,mask_resize),dim=1)
1733
+ # mask = x[:, -1:]
1734
+ # z_ref_concat = torch.cat([noisy_z_ref, z_ref, mask], dim=1)
1735
+
1736
+
1737
+ # # compute contexts
1738
+ # _, contexts = self.reference_model(z_ref_concat, t, context=cc, get_contexts=True)
1739
+
1740
+ # # input diffusion model with contexts
1741
+ # out = self.diffusion_model(x, t, context=cc, passed_contexts=contexts)
1742
+
1743
+ elif self.conditioning_key == 'rewarp' or self.conditioning_key == 'crossref': # also include the crossref for now
1744
+ cc = torch.cat(c_crossattn, 1)
1745
+ # qsample z_ref by t to add noise
1746
+ # so have noisy z_ref + z_ref + mask
1747
+ # compute contexts
1748
+ if self.conditioning_key == 'crossref':
1749
+ raise ValueError('currently not implemented properly. please fix attention')
1750
+ assert z_ref is not None
1751
+ noisy_z_ref = self.q_sample(z_ref, t)
1752
+ # z_new = torch.cat((z,z_inpaint,mask_resize),dim=1)
1753
+
1754
+ # mask = x[:, -2:-1] # mask and new regions
1755
+ # changed_pixels = x[:, -1:]
1756
+ # z_ref_concat = torch.cat([noisy_z_ref, z_ref, mask, changed_pixels], dim=1)
1757
+ mask = x[:, -1:] # mask and new regions
1758
+ z_ref_concat = torch.cat([noisy_z_ref, z_ref, mask], dim=1)
1759
+
1760
+
1761
+ init_corresp_time = time.time()
1762
+ correspondeces = self.compute_correspondences(self.transformed_grid, mask, original_size=512) ## TODO make input dependent
1763
+ final_corresp_time = time.time()
1764
+
1765
+ # compute contexts
1766
+ _, contexts = self.reference_model(z_ref_concat, t, context=cc, get_contexts=True)
1767
+ # input diffusion model with contexts
1768
+ out = self.diffusion_model(x, t, context=cc, passed_contexts=contexts, corresp=correspondeces)
1769
+
1770
+ elif self.conditioning_key == 'rewarp_grid':
1771
+ grid_og = self.get_grid(64, batch_size=x.shape[0]).to(x.device) / 64.0
1772
+ cc = torch.cat(c_crossattn, 1)
1773
+ # qsample z_ref by t to add noise
1774
+ # so have noisy z_ref + z_ref + mask
1775
+ # compute contexts
1776
+
1777
+ assert z_ref is not None
1778
+ noisy_z_ref = self.q_sample(z_ref, t)
1779
+ # z_new = torch.cat((z,z_inpaint,mask_resize),dim=1)
1780
+
1781
+ # mask = x[:, -2:-1] # mask and new regions
1782
+ # changed_pixels = x[:, -1:]
1783
+ # z_ref_concat = torch.cat([noisy_z_ref, z_ref, mask, changed_pixels], dim=1)
1784
+ mask = x[:, -1:] # mask and new regions
1785
+ z_ref_concat = torch.cat([noisy_z_ref, z_ref, mask, grid_og], dim=1)
1786
+ x = torch.cat([x, grid_og], dim=1)
1787
+
1788
+ correspondeces = self.compute_correspondences(self.transformed_grid, mask, original_size=512) ## TODO make input dependent
1789
+
1790
+ # compute contexts
1791
+ _, contexts = self.reference_model(z_ref_concat, t, context=cc, get_contexts=True)
1792
+ # input diffusion model with contexts
1793
+ out = self.diffusion_model(x, t, context=cc, passed_contexts=contexts, corresp=correspondeces)
1794
+
1795
+ else:
1796
+ raise NotImplementedError()
1797
+
1798
+ return out
1799
+
1800
+
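Note: compute_correspondences above builds, at several feature resolutions, a nearest-neighbour map from each cell of the warped coordinate grid back to the identity grid, plus a mask of cells the collage does not cover. A condensed single-resolution sketch of the same idea (the 0.7 coverage threshold mirrors the code; the function name and argument layout are illustrative):

import torch
import torch.nn.functional as F

def nearest_correspondences(grid_transformed, mask, d=16):
    # grid_transformed: (B, 2, H, W) warped (y, x) coordinates in [0, 1]
    # mask:             (B, 1, H, W) coverage mask in [0, 1]
    # Returns nearest-source indices and a "missing" flag, both (B, d*d).
    B = grid_transformed.shape[0]
    device = grid_transformed.device

    ys, xs = torch.meshgrid(torch.arange(d, device=device),
                            torch.arange(d, device=device), indexing="ij")
    src = torch.stack([ys, xs], dim=-1).float().div(d).reshape(1, d * d, 2).repeat(B, 1, 1)

    tgt = F.interpolate(grid_transformed, size=(d, d)).permute(0, 2, 3, 1).reshape(B, d * d, 2)
    missing = F.interpolate(mask, size=(d, d)).reshape(B, d * d) < 0.7  # True = needs inpainting

    indices = torch.cdist(tgt.float(), src).argmin(dim=-1)  # nearest identity-grid cell
    return indices, missing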
1801
+ class Layout2ImgDiffusion(LatentDiffusion):
1802
+ # pr_odo: move all layout-specific hacks to this class
1803
+ def __init__(self, cond_stage_key, *args, **kwargs):
1804
+ assert cond_stage_key == 'coordinates_bbox', 'Layout2ImgDiffusion only for cond_stage_key="coordinates_bbox"'
1805
+ super().__init__(cond_stage_key=cond_stage_key, *args, **kwargs)
1806
+
1807
+ def log_images(self, batch, N=8, *args, **kwargs):
1808
+ logs = super().log_images(batch=batch, N=N, *args, **kwargs)
1809
+
1810
+ key = 'train' if self.training else 'validation'
1811
+ dset = self.trainer.datamodule.datasets[key]
1812
+ mapper = dset.conditional_builders[self.cond_stage_key]
1813
+
1814
+ bbox_imgs = []
1815
+ map_fn = lambda catno: dset.get_textual_label(dset.get_category_id(catno))
1816
+ for tknzd_bbox in batch[self.cond_stage_key][:N]:
1817
+ bboximg = mapper.plot(tknzd_bbox.detach().cpu(), map_fn, (256, 256))
1818
+ bbox_imgs.append(bboximg)
1819
+
1820
+ cond_img = torch.stack(bbox_imgs, dim=0)
1821
+ logs['bbox_image'] = cond_img
1822
+ return logs
1823
+
1824
+ class LatentInpaintDiffusion(LatentDiffusion):
1825
+ def __init__(
1826
+ self,
1827
+ concat_keys=("mask", "masked_image"),
1828
+ masked_image_key="masked_image",
1829
+ finetune_keys=None,
1830
+ *args,
1831
+ **kwargs,
1832
+ ):
1833
+ super().__init__(*args, **kwargs)
1834
+ self.masked_image_key = masked_image_key
1835
+ assert self.masked_image_key in concat_keys
1836
+ self.concat_keys = concat_keys
1837
+
1838
+
1839
+ @torch.no_grad()
1840
+ def get_input(
1841
+ self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False
1842
+ ):
1843
+ # note: restricted to non-trainable encoders currently
1844
+ assert (
1845
+ not self.cond_stage_trainable
1846
+ ), "trainable cond stages not yet supported for inpainting"
1847
+ z, c, x, xrec, xc = super().get_input(
1848
+ batch,
1849
+ self.first_stage_key,
1850
+ return_first_stage_outputs=True,
1851
+ force_c_encode=True,
1852
+ return_original_cond=True,
1853
+ bs=bs,
1854
+ )
1855
+
1856
+ assert exists(self.concat_keys)
1857
+ c_cat = list()
1858
+ for ck in self.concat_keys:
1859
+ cc = (
1860
+ rearrange(batch[ck], "b h w c -> b c h w")
1861
+ .to(memory_format=torch.contiguous_format)
1862
+ .float()
1863
+ )
1864
+ if bs is not None:
1865
+ cc = cc[:bs]
1866
+ cc = cc.to(self.device)
1867
+ bchw = z.shape
1868
+ if ck != self.masked_image_key:
1869
+ cc = torch.nn.functional.interpolate(cc, size=bchw[-2:])
1870
+ else:
1871
+ cc = self.get_first_stage_encoding(self.encode_first_stage(cc))
1872
+ c_cat.append(cc)
1873
+ c_cat = torch.cat(c_cat, dim=1)
1874
+ all_conds = {"c_concat": [c_cat], "c_crossattn": [c]}
1875
+ if return_first_stage_outputs:
1876
+ return z, all_conds, x, xrec, xc
1877
+ return z, all_conds
ldm/models/diffusion/plms.py ADDED
@@ -0,0 +1,251 @@
1
+ # This code is built from the Stable Diffusion repository: https://github.com/CompVis/stable-diffusion, and
2
+ # Paint-by-Example repo https://github.com/Fantasy-Studio/Paint-by-Example
3
+ # Copyright (c) 2022 Robin Rombach and Patrick Esser and contributors.
4
+ # CreativeML Open RAIL-M
5
+ #
6
+ # ==========================================================================================
7
+ #
8
+ # Adobe’s modifications are Copyright 2024 Adobe Research. All rights reserved.
9
+ # Adobe’s modifications are licensed under the Adobe Research License. To view a copy of the license, visit
10
+ # LICENSE.md.
11
+ #
12
+ # ==========================================================================================
13
+
14
+ """SAMPLING ONLY."""
15
+
16
+ import torch
17
+ import numpy as np
18
+ from tqdm import tqdm
19
+ from functools import partial
20
+
21
+ from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
22
+
23
+
24
+ class PLMSSampler(object):
25
+ def __init__(self, model, schedule="linear", **kwargs):
26
+ super().__init__()
27
+ self.model = model
28
+ self.ddpm_num_timesteps = model.num_timesteps
29
+ self.schedule = schedule
30
+
31
+ def register_buffer(self, name, attr):
32
+ if type(attr) == torch.Tensor:
33
+ if attr.device != torch.device("cuda"):
34
+ attr = attr.to(torch.device("cuda"))
35
+ setattr(self, name, attr)
36
+
37
+ def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
38
+ if ddim_eta != 0:
39
+ raise ValueError('ddim_eta must be 0 for PLMS')
40
+ self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
41
+ num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
42
+ alphas_cumprod = self.model.alphas_cumprod
43
+ assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
44
+ to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
45
+
46
+ self.register_buffer('betas', to_torch(self.model.betas))
47
+ self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
48
+ self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
49
+
50
+ # calculations for diffusion q(x_t | x_{t-1}) and others
51
+ self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
52
+ self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
53
+ self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
54
+ self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
55
+ self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
56
+
57
+ # ddim sampling parameters
58
+ ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
59
+ ddim_timesteps=self.ddim_timesteps,
60
+ eta=ddim_eta,verbose=verbose)
61
+ self.register_buffer('ddim_sigmas', ddim_sigmas)
62
+ self.register_buffer('ddim_alphas', ddim_alphas)
63
+ self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
64
+ self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
65
+ sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
66
+ (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
67
+ 1 - self.alphas_cumprod / self.alphas_cumprod_prev))
68
+ self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
69
+
70
+ @torch.no_grad()
71
+ def sample(self,
72
+ S,
73
+ batch_size,
74
+ shape,
75
+ conditioning=None,
76
+ callback=None,
77
+ normals_sequence=None,
78
+ img_callback=None,
79
+ quantize_x0=False,
80
+ eta=0.,
81
+ mask=None,
82
+ x0=None,
83
+ temperature=1.,
84
+ noise_dropout=0.,
85
+ score_corrector=None,
86
+ corrector_kwargs=None,
87
+ verbose=True,
88
+ x_T=None,
89
+ log_every_t=100,
90
+ unconditional_guidance_scale=1.,
91
+ unconditional_conditioning=None,
92
+ # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
93
+ **kwargs
94
+ ):
95
+ if conditioning is not None:
96
+ if isinstance(conditioning, dict):
97
+ cbs = conditioning[list(conditioning.keys())[0]].shape[0]
98
+ if cbs != batch_size:
99
+ print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
100
+ else:
101
+ if conditioning.shape[0] != batch_size:
102
+ print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
103
+
104
+ self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
105
+ # sampling
106
+ C, H, W = shape
107
+ size = (batch_size, C, H, W)
108
+ print(f'Data shape for PLMS sampling is {size}')
109
+
110
+ samples, intermediates = self.plms_sampling(conditioning, size,
111
+ callback=callback,
112
+ img_callback=img_callback,
113
+ quantize_denoised=quantize_x0,
114
+ mask=mask, x0=x0,
115
+ ddim_use_original_steps=False,
116
+ noise_dropout=noise_dropout,
117
+ temperature=temperature,
118
+ score_corrector=score_corrector,
119
+ corrector_kwargs=corrector_kwargs,
120
+ x_T=x_T,
121
+ log_every_t=log_every_t,
122
+ unconditional_guidance_scale=unconditional_guidance_scale,
123
+ unconditional_conditioning=unconditional_conditioning,
124
+ **kwargs
125
+ )
126
+ return samples, intermediates
127
+
128
+ @torch.no_grad()
129
+ def plms_sampling(self, cond, shape,
130
+ x_T=None, ddim_use_original_steps=False,
131
+ callback=None, timesteps=None, quantize_denoised=False,
132
+ mask=None, x0=None, img_callback=None, log_every_t=100,
133
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
134
+ unconditional_guidance_scale=1., unconditional_conditioning=None,**kwargs):
135
+ device = self.model.betas.device
136
+ b = shape[0]
137
+ if x_T is None:
138
+ img = torch.randn(shape, device=device)
139
+ else:
140
+ img = x_T
141
+
142
+ if timesteps is None:
143
+ timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
144
+ elif timesteps is not None and not ddim_use_original_steps:
145
+ subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
146
+ timesteps = self.ddim_timesteps[:subset_end]
147
+
148
+ intermediates = {'x_inter': [img], 'pred_x0': [img]}
149
+ time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps)
150
+ total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
151
+ print(f"Running PLMS Sampling with {total_steps} timesteps")
152
+
153
+ iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps)
154
+ old_eps = []
155
+
156
+ for i, step in enumerate(iterator):
157
+ index = total_steps - i - 1
158
+ ts = torch.full((b,), step, device=device, dtype=torch.long)
159
+ ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long)
160
+
161
+ if mask is not None:
162
+ assert x0 is not None
163
+ img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
164
+ img = img_orig * mask + (1. - mask) * img
165
+
166
+ outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
167
+ quantize_denoised=quantize_denoised, temperature=temperature,
168
+ noise_dropout=noise_dropout, score_corrector=score_corrector,
169
+ corrector_kwargs=corrector_kwargs,
170
+ unconditional_guidance_scale=unconditional_guidance_scale,
171
+ unconditional_conditioning=unconditional_conditioning,
172
+ old_eps=old_eps, t_next=ts_next,**kwargs)
173
+ img, pred_x0, e_t = outs
174
+ old_eps.append(e_t)
175
+ if len(old_eps) >= 4:
176
+ old_eps.pop(0)
177
+ if callback: callback(i)
178
+ if img_callback: img_callback(pred_x0, i)
179
+
180
+ if index % log_every_t == 0 or index == total_steps - 1:
181
+ intermediates['x_inter'].append(img)
182
+ intermediates['pred_x0'].append(pred_x0)
183
+
184
+ return img, intermediates
185
+
186
+ @torch.no_grad()
187
+ def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
188
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
189
+ unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None,**kwargs):
190
+ b, *_, device = *x.shape, x.device
191
+ def get_model_output(x, t):
192
+ if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
193
+ e_t = self.model.apply_model(x, t, c)
194
+ else:
195
+ x_in = torch.cat([x] * 2)
196
+ t_in = torch.cat([t] * 2)
197
+ c_in = torch.cat([unconditional_conditioning, c])
198
+ e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
199
+ e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
200
+
201
+ if score_corrector is not None:
202
+ assert self.model.parameterization == "eps"
203
+ e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
204
+
205
+ return e_t
206
+
207
+ alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
208
+ alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
209
+ sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
210
+ sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
211
+
212
+ def get_x_prev_and_pred_x0(e_t, index):
213
+ # select parameters corresponding to the currently considered timestep
214
+ a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
215
+ a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
216
+ sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
217
+ sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
218
+
219
+ # current prediction for x_0
220
+ pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
221
+ if quantize_denoised:
222
+ pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
223
+ # direction pointing to x_t
224
+ dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
225
+ noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
226
+ if noise_dropout > 0.:
227
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
228
+ x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
229
+ return x_prev, pred_x0
230
+ kwargs=kwargs['test_model_kwargs']
231
+ x_new=torch.cat([x,kwargs['inpaint_image'],kwargs['inpaint_mask']],dim=1)
232
+ e_t = get_model_output(x_new, t)
233
+ if len(old_eps) == 0:
234
+ # Pseudo Improved Euler (2nd order)
235
+ x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
236
+ x_prev_new=torch.cat([x_prev,kwargs['inpaint_image'],kwargs['inpaint_mask']],dim=1)
237
+ e_t_next = get_model_output(x_prev_new, t_next)
238
+ e_t_prime = (e_t + e_t_next) / 2
239
+ elif len(old_eps) == 1:
240
+ # 2nd order Pseudo Linear Multistep (Adams-Bashforth)
241
+ e_t_prime = (3 * e_t - old_eps[-1]) / 2
242
+ elif len(old_eps) == 2:
243
+ # 3rd order Pseudo Linear Multistep (Adams-Bashforth)
244
+ e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
245
+ elif len(old_eps) >= 3:
246
+ # 4th order Pseudo Linear Multistep (Adams-Bashforth)
247
+ e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
248
+
249
+ x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
250
+
251
+ return x_prev, pred_x0, e_t
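Note: the coefficient ladder in p_sample_plms is the standard Adams-Bashforth family: a Heun-style warm start, then 2nd-, 3rd- and 4th-order linear multistep combinations of the stored noise predictions. A small sketch of just that combination step (the function name is illustrative):

def plms_eps(e_t, old_eps, e_t_next=None):
    # Pseudo linear multistep (Adams-Bashforth) combination of noise predictions,
    # mirroring the coefficient schedule used in p_sample_plms above.
    if len(old_eps) == 0:
        # warm start: pseudo improved Euler needs one extra evaluation at t_next
        return (e_t + e_t_next) / 2
    if len(old_eps) == 1:
        return (3 * e_t - old_eps[-1]) / 2
    if len(old_eps) == 2:
        return (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
    return (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24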
ldm/modules/attention.py ADDED
@@ -0,0 +1,372 @@
1
+ # This code is built from the Stable Diffusion repository: https://github.com/CompVis/stable-diffusion, and
2
+ # Paint-by-Example repo https://github.com/Fantasy-Studio/Paint-by-Example
3
+ # Copyright (c) 2022 Robin Rombach and Patrick Esser and contributors.
4
+ # CreativeML Open RAIL-M
5
+ #
6
+ # ==========================================================================================
7
+ #
8
+ # Adobe’s modifications are Copyright 2024 Adobe Research. All rights reserved.
9
+ # Adobe’s modifications are licensed under the Adobe Research License. To view a copy of the license, visit
10
+ # LICENSE.md.
11
+ #
12
+ # ==========================================================================================
13
+
14
+ from inspect import isfunction
15
+ import math
16
+ import numpy as np
17
+ import torch
18
+ import torch.nn.functional as F
19
+ from torch import nn, einsum
20
+ from einops import rearrange, repeat
21
+ import glob
22
+
23
+ from ldm.modules.diffusionmodules.util import checkpoint
24
+
25
+
26
+ def exists(val):
27
+ return val is not None
28
+
29
+
30
+ def uniq(arr):
31
+ return{el: True for el in arr}.keys()
32
+
33
+
34
+ def default(val, d):
35
+ if exists(val):
36
+ return val
37
+ return d() if isfunction(d) else d
38
+
39
+
40
+ def max_neg_value(t):
41
+ return -torch.finfo(t.dtype).max
42
+
43
+
44
+ def init_(tensor):
45
+ dim = tensor.shape[-1]
46
+ std = 1 / math.sqrt(dim)
47
+ tensor.uniform_(-std, std)
48
+ return tensor
49
+
50
+
51
+ # feedforward
52
+ class GEGLU(nn.Module):
53
+ def __init__(self, dim_in, dim_out):
54
+ super().__init__()
55
+ self.proj = nn.Linear(dim_in, dim_out * 2)
56
+
57
+ def forward(self, x):
58
+ x, gate = self.proj(x).chunk(2, dim=-1)
59
+ return x * F.gelu(gate)
60
+
61
+
62
+ class FeedForward(nn.Module):
63
+ def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
64
+ super().__init__()
65
+ inner_dim = int(dim * mult)
66
+ dim_out = default(dim_out, dim)
67
+ project_in = nn.Sequential(
68
+ nn.Linear(dim, inner_dim),
69
+ nn.GELU()
70
+ ) if not glu else GEGLU(dim, inner_dim)
71
+
72
+ self.net = nn.Sequential(
73
+ project_in,
74
+ nn.Dropout(dropout),
75
+ nn.Linear(inner_dim, dim_out)
76
+ )
77
+
78
+ def forward(self, x):
79
+ return self.net(x)
80
+
81
+
82
+ def zero_module(module):
83
+ """
84
+ Zero out the parameters of a module and return it.
85
+ """
86
+ for p in module.parameters():
87
+ p.detach().zero_()
88
+ return module
89
+
90
+
91
+ def Normalize(in_channels):
92
+ return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
93
+
94
+
95
+ class LinearAttention(nn.Module):
96
+ def __init__(self, dim, heads=4, dim_head=32):
97
+ super().__init__()
98
+ self.heads = heads
99
+ hidden_dim = dim_head * heads
100
+ self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
101
+ self.to_out = nn.Conv2d(hidden_dim, dim, 1)
102
+
103
+ def forward(self, x):
104
+ b, c, h, w = x.shape
105
+ qkv = self.to_qkv(x)
106
+ q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3)
107
+ k = k.softmax(dim=-1)
108
+ context = torch.einsum('bhdn,bhen->bhde', k, v)
109
+ out = torch.einsum('bhde,bhdn->bhen', context, q)
110
+ out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
111
+ return self.to_out(out)
112
+
113
+
114
+ class SpatialSelfAttention(nn.Module):
115
+ def __init__(self, in_channels):
116
+ super().__init__()
117
+ self.in_channels = in_channels
118
+
119
+ self.norm = Normalize(in_channels)
120
+ self.q = torch.nn.Conv2d(in_channels,
121
+ in_channels,
122
+ kernel_size=1,
123
+ stride=1,
124
+ padding=0)
125
+ self.k = torch.nn.Conv2d(in_channels,
126
+ in_channels,
127
+ kernel_size=1,
128
+ stride=1,
129
+ padding=0)
130
+ self.v = torch.nn.Conv2d(in_channels,
131
+ in_channels,
132
+ kernel_size=1,
133
+ stride=1,
134
+ padding=0)
135
+ self.proj_out = torch.nn.Conv2d(in_channels,
136
+ in_channels,
137
+ kernel_size=1,
138
+ stride=1,
139
+ padding=0)
140
+
141
+ def forward(self, x):
142
+ h_ = x
143
+ h_ = self.norm(h_)
144
+ q = self.q(h_)
145
+ k = self.k(h_)
146
+ v = self.v(h_)
147
+
148
+ # compute attention
149
+ b,c,h,w = q.shape
150
+ q = rearrange(q, 'b c h w -> b (h w) c')
151
+ k = rearrange(k, 'b c h w -> b c (h w)')
152
+ w_ = torch.einsum('bij,bjk->bik', q, k)
153
+
154
+ w_ = w_ * (int(c)**(-0.5))
155
+ w_ = torch.nn.functional.softmax(w_, dim=2)
156
+
157
+ # attend to values
158
+ v = rearrange(v, 'b c h w -> b c (h w)')
159
+ w_ = rearrange(w_, 'b i j -> b j i')
160
+ h_ = torch.einsum('bij,bjk->bik', v, w_)
161
+ h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h)
162
+ h_ = self.proj_out(h_)
163
+
164
+ return x+h_
165
+
166
+
167
+ class CrossAttention(nn.Module):
168
+ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., only_crossref=False):
169
+ super().__init__()
170
+ inner_dim = dim_head * heads
171
+ # forcing attention to only attend on vectors of same size
172
+ # breaking the image2text attention
173
+ context_dim = default(context_dim, query_dim)
174
+
175
+ # print('creating cross attention. Query dim', query_dim, ' context dim', context_dim)
176
+
177
+ self.scale = dim_head ** -0.5
178
+ self.heads = heads
179
+
180
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
181
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
182
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
183
+
184
+ self.to_out = nn.Sequential(
185
+ nn.Linear(inner_dim, query_dim),
186
+ nn.Dropout(dropout)
187
+ )
188
+
189
+ self.only_crossref = only_crossref
190
+ if only_crossref:
191
+ self.merge_attentions = zero_module(nn.Conv2d(self.heads * 2,
192
+ self.heads,
193
+ kernel_size=1,
194
+ stride=1,
195
+ padding=0))
196
+ else:
197
+ self.merge_attentions = zero_module(nn.Conv2d(self.heads * 3,
198
+ self.heads,
199
+ kernel_size=1,
200
+ stride=1,
201
+ padding=0))
202
+
203
+
204
+ self.merge_attentions_missing = zero_module(nn.Conv2d(self.heads * 2,
205
+ self.heads,
206
+ kernel_size=1,
207
+ stride=1,
208
+ padding=0))
209
+
210
+
211
+ def forward(self, x, context=None, mask=None, passed_qkv=None, masks=None, corresp=None, missing_region=None):
212
+ is_self_attention = context is None
213
+
214
+ # if masks is not None:
215
+ # print(is_self_attention, masks.keys())
216
+
217
+ h = self.heads
218
+
219
+ # if passed_qkv is not None:
220
+ # assert context is None
221
+
222
+ # _,_,_,_, x_features = passed_qkv
223
+ # assert x_features is not None
224
+
225
+ # # print('x shape', x.shape, 'x features', x_features.shape)
226
+ # # breakpoint()
227
+ # x = torch.concat([x, x_features], dim=1)
228
+
229
+ q = self.to_q(x)
230
+ context = default(context, x)
231
+ k = self.to_k(context)
232
+ v = self.to_v(context)
233
+
234
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
235
+
236
+ sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
237
+
238
+ if exists(mask):
239
+ assert False
240
+ mask = rearrange(mask, 'b ... -> b (...)')
241
+ max_neg_value = -torch.finfo(sim.dtype).max
242
+ mask = repeat(mask, 'b j -> (b h) () j', h=h)
243
+ sim.masked_fill_(~mask, max_neg_value)
244
+
245
+ # attention, what we cannot get enough of
246
+ attn = sim.softmax(dim=-1)
247
+ out = einsum('b i j, b j d -> b i d', attn, v)
248
+ inter_out = rearrange(out, '(b h) n d -> b h n d', h=h)
249
+
250
+ combined_attention = inter_out
251
+ out = rearrange(combined_attention, 'b h n d -> b n (h d)', h=h)
252
+
253
+ final_out = self.to_out(out)
254
+
255
+ if is_self_attention:
256
+ return final_out, q, k, v, inter_out #TODO add attn out
257
+ else:
258
+ return final_out
259
+
260
+
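The core of CrossAttention.forward above is standard scaled dot-product attention written with einsum; a minimal standalone sketch (toy shapes, single head, no Linear projections or head merging) looks like this:

import torch
from torch import einsum

# Illustrative sketch of the q/k/v einsum pattern; shapes are arbitrary examples.
dim_head = 64
q = torch.randn(2, 10, dim_head)                      # (batch*heads, queries, dim_head)
k = torch.randn(2, 7, dim_head)                       # (batch*heads, keys, dim_head)
v = torch.randn(2, 7, dim_head)
scale = dim_head ** -0.5
sim = einsum('b i d, b j d -> b i j', q, k) * scale   # similarity logits
attn = sim.softmax(dim=-1)                            # normalize over the key axis
out = einsum('b i j, b j d -> b i d', attn, v)        # weighted sum of values
assert out.shape == (2, 10, dim_head)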
261
+ class BasicTransformerBlock(nn.Module):
262
+ def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True):
263
+ super().__init__()
264
+ self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout) # is a self-attention
265
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
266
+ self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
267
+ heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none
268
+ self.attn3 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout)
269
+ self.norm1 = nn.LayerNorm(dim)
270
+ self.norm2 = nn.LayerNorm(dim)
271
+ self.norm3 = nn.LayerNorm(dim)
272
+ self.checkpoint = checkpoint
273
+
274
+ # TODO add attn in
275
+ def forward(self, x, context=None, passed_qkv=None, masks=None, corresp=None):
276
+ if passed_qkv is None:
277
+ return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
278
+ else:
279
+ q, k, v, attn, x_features = passed_qkv
280
+ d = int(np.sqrt(q.shape[1]))
281
+ current_mask = masks[d]
282
+ if corresp:
283
+ current_corresp, missing_region = corresp[d]
284
+ current_corresp = current_corresp.float()
285
+ missing_region = missing_region.float()
286
+ else:
287
+ raise ValueError('cannot have empty corresp')
288
+ current_corresp = None
289
+ missing_region = current_mask.float()
290
+ # breakpoint()
291
+ stuff = [q, k, v, attn, x_features, current_mask, current_corresp, missing_region]
292
+ for element in stuff:
293
+ assert element is not None
294
+ return checkpoint(self._forward, (x, context, q, k, v, attn, x_features, current_mask, current_corresp, missing_region), self.parameters(), self.checkpoint)
295
+
296
+ # TODO add attn in
297
+ def _forward(self, x, context=None, q=None, k=None, v=None, attn=None, passed_x=None, masks=None, corresp=None, missing_region=None):
298
+ if q is not None:
299
+ passed_qkv = (q, k, v, attn, passed_x)
300
+ else:
301
+ passed_qkv = None
302
+ x_features = self.norm1(x)
303
+ attended_x, q, k, v, attn = self.attn1(x_features, passed_qkv=passed_qkv, masks=masks, corresp=corresp, missing_region=missing_region)
304
+ x = attended_x + x
305
+ # killing CLIP features
306
+
307
+ if passed_x is not None:
308
+ normed_x = self.norm2(x)
309
+ attn_out = self.attn3(normed_x, context=passed_x)
310
+ x = attn_out + x
311
+ # then use y + x
312
+ # print('y shape', y.shape, ' x shape', x.shape)
313
+
314
+ x = self.ff(self.norm3(x)) + x
315
+ return x, q, k, v, attn, x_features
316
+
317
+
318
+ class SpatialTransformer(nn.Module):
319
+ """
320
+ Transformer block for image-like data.
321
+ First, project the input (aka embedding)
322
+ and reshape to b, t, d.
323
+ Then apply standard transformer action.
324
+ Finally, reshape to image
325
+ """
326
+ def __init__(self, in_channels, n_heads, d_head,
327
+ depth=1, dropout=0., context_dim=None):
328
+ super().__init__()
329
+ self.in_channels = in_channels
330
+ inner_dim = n_heads * d_head
331
+ self.norm = Normalize(in_channels)
332
+
333
+ # print('creating spatial transformer')
334
+ # print('in channels', in_channels, 'inner dim', inner_dim)
335
+
336
+ self.proj_in = nn.Conv2d(in_channels,
337
+ inner_dim,
338
+ kernel_size=1,
339
+ stride=1,
340
+ padding=0)
341
+
342
+ self.transformer_blocks = nn.ModuleList(
343
+ [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim)
344
+ for d in range(depth)]
345
+ )
346
+
347
+ self.proj_out = zero_module(nn.Conv2d(inner_dim,
348
+ in_channels,
349
+ kernel_size=1,
350
+ stride=1,
351
+ padding=0))
352
+
353
+ # TODO add attn in and corresp
354
+ def forward(self, x, context=None, passed_qkv=None, masks=None, corresp=None):
355
+ # note: if no context is given, cross-attention defaults to self-attention
356
+ b, c, h, w = x.shape
357
+ # print('spatial transformer x shape given', x.shape)
358
+ # if context is not None:
359
+ # print('also context was provided with shape ', context.shape)
360
+ x_in = x
361
+ x = self.norm(x)
362
+ x = self.proj_in(x)
363
+ x = rearrange(x, 'b c h w -> b (h w) c')
364
+
365
+ qkvs = []
366
+ for block in self.transformer_blocks:
367
+ x, q, k, v, attn, x_features = block(x, context=context, passed_qkv=passed_qkv, masks=masks, corresp=corresp)
368
+ qkv = (q,k,v,attn, x_features)
369
+ qkvs.append(qkv)
370
+ x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
371
+ x = self.proj_out(x)
372
+ return x + x_in, qkvs
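SpatialTransformer treats every spatial position as a token; a small sketch of the flattening and its inverse that bracket the transformer blocks (einops only, toy tensor):

import torch
from einops import rearrange

# Illustrative sketch of the b c h w <-> b (h w) c round trip used above.
x = torch.randn(1, 320, 32, 32)                     # (b, c, h, w) feature map
tokens = rearrange(x, 'b c h w -> b (h w) c')       # 32*32 = 1024 tokens of width 320
assert tokens.shape == (1, 1024, 320)
x_back = rearrange(tokens, 'b (h w) c -> b c h w', h=32, w=32)
assert torch.equal(x, x_back)                       # lossless round trip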
ldm/modules/diffusionmodules/__init__.py ADDED
File without changes
ldm/modules/diffusionmodules/model.py ADDED
@@ -0,0 +1,848 @@
1
+ # This code is built from the Stable Diffusion repository: https://github.com/CompVis/stable-diffusion, and
2
+ # Paint-by-Example repo https://github.com/Fantasy-Studio/Paint-by-Example
3
+ # Copyright (c) 2022 Robin Rombach and Patrick Esser and contributors.
4
+ # CreativeML Open RAIL-M
5
+ #
6
+ # ==========================================================================================
7
+ #
8
+ # Adobe’s modifications are Copyright 2024 Adobe Research. All rights reserved.
9
+ # Adobe’s modifications are licensed under the Adobe Research License. To view a copy of the license, visit
10
+ # LICENSE.md.
11
+ #
12
+ # ==========================================================================================
13
+
14
+ # pytorch_diffusion + derived encoder decoder
15
+ import math
16
+ import torch
17
+ import torch.nn as nn
18
+ import numpy as np
19
+ from einops import rearrange
20
+
21
+ from ldm.util import instantiate_from_config
22
+ from ldm.modules.attention import LinearAttention
+ from ldm.modules.distributions.distributions import DiagonalGaussianDistribution  # used by FirstStagePostProcessor.encode_with_pretrained
23
+
24
+
25
+ def get_timestep_embedding(timesteps, embedding_dim):
26
+ """
27
+ Build sinusoidal timestep embeddings, as used in Denoising Diffusion
28
+ Probabilistic Models (ported from the fairseq implementation).
29
+
30
+ This matches the implementation in tensor2tensor, but differs slightly
31
+ from the description in Section 3.5 of "Attention Is All You Need".
32
+ """
33
+ assert len(timesteps.shape) == 1
34
+
35
+ half_dim = embedding_dim // 2
36
+ emb = math.log(10000) / (half_dim - 1)
37
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
38
+ emb = emb.to(device=timesteps.device)
39
+ emb = timesteps.float()[:, None] * emb[None, :]
40
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
41
+ if embedding_dim % 2 == 1: # zero pad
42
+ emb = torch.nn.functional.pad(emb, (0,1,0,0))
43
+ return emb
44
+
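A quick sanity check of the embedding above: with embedding_dim=128, the first 64 channels are sines and the last 64 are cosines of geometrically spaced frequencies, so a batch of timesteps maps to (batch, embedding_dim).

import torch

# Illustrative usage of get_timestep_embedding as defined above.
t = torch.tensor([0, 10, 999])           # a batch of diffusion timesteps
emb = get_timestep_embedding(t, 128)     # first half sin, second half cos
assert emb.shape == (3, 128)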
45
+
46
+ def nonlinearity(x):
47
+ # swish
48
+ return x*torch.sigmoid(x)
49
+
50
+
51
+ def Normalize(in_channels, num_groups=32):
52
+ return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
53
+
54
+
55
+ class Upsample(nn.Module):
56
+ def __init__(self, in_channels, with_conv):
57
+ super().__init__()
58
+ self.with_conv = with_conv
59
+ if self.with_conv:
60
+ self.conv = torch.nn.Conv2d(in_channels,
61
+ in_channels,
62
+ kernel_size=3,
63
+ stride=1,
64
+ padding=1)
65
+
66
+ def forward(self, x):
67
+ x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
68
+ if self.with_conv:
69
+ x = self.conv(x)
70
+ return x
71
+
72
+
73
+ class Downsample(nn.Module):
74
+ def __init__(self, in_channels, with_conv):
75
+ super().__init__()
76
+ self.with_conv = with_conv
77
+ if self.with_conv:
78
+ # no asymmetric padding in torch conv, must do it ourselves
79
+ self.conv = torch.nn.Conv2d(in_channels,
80
+ in_channels,
81
+ kernel_size=3,
82
+ stride=2,
83
+ padding=0)
84
+
85
+ def forward(self, x):
86
+ if self.with_conv:
87
+ pad = (0,1,0,1)
88
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
89
+ x = self.conv(x)
90
+ else:
91
+ x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
92
+ return x
93
+
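The conv path above pads only on the right and bottom, (0,1,0,1), before a stride-2, kernel-3 convolution, so an even-sized map is exactly halved; a small sketch:

import torch
import torch.nn.functional as F

# Illustrative check of the asymmetric padding + stride-2 conv used by Downsample.
x = torch.randn(1, 8, 32, 32)
x = F.pad(x, (0, 1, 0, 1), mode="constant", value=0)    # pad right and bottom only -> 33x33
conv = torch.nn.Conv2d(8, 8, kernel_size=3, stride=2, padding=0)
y = conv(x)                                             # floor((33 - 3) / 2) + 1 = 16
assert y.shape == (1, 8, 16, 16)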
94
+
95
+ class ResnetBlock(nn.Module):
96
+ def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
97
+ dropout, temb_channels=512):
98
+ super().__init__()
99
+ self.in_channels = in_channels
100
+ out_channels = in_channels if out_channels is None else out_channels
101
+ self.out_channels = out_channels
102
+ self.use_conv_shortcut = conv_shortcut
103
+
104
+ self.norm1 = Normalize(in_channels)
105
+ self.conv1 = torch.nn.Conv2d(in_channels,
106
+ out_channels,
107
+ kernel_size=3,
108
+ stride=1,
109
+ padding=1)
110
+ if temb_channels > 0:
111
+ self.temb_proj = torch.nn.Linear(temb_channels,
112
+ out_channels)
113
+ self.norm2 = Normalize(out_channels)
114
+ self.dropout = torch.nn.Dropout(dropout)
115
+ self.conv2 = torch.nn.Conv2d(out_channels,
116
+ out_channels,
117
+ kernel_size=3,
118
+ stride=1,
119
+ padding=1)
120
+ if self.in_channels != self.out_channels:
121
+ if self.use_conv_shortcut:
122
+ self.conv_shortcut = torch.nn.Conv2d(in_channels,
123
+ out_channels,
124
+ kernel_size=3,
125
+ stride=1,
126
+ padding=1)
127
+ else:
128
+ self.nin_shortcut = torch.nn.Conv2d(in_channels,
129
+ out_channels,
130
+ kernel_size=1,
131
+ stride=1,
132
+ padding=0)
133
+
134
+ def forward(self, x, temb):
135
+ h = x
136
+ h = self.norm1(h)
137
+ h = nonlinearity(h)
138
+ h = self.conv1(h)
139
+
140
+ if temb is not None:
141
+ h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None]
142
+
143
+ h = self.norm2(h)
144
+ h = nonlinearity(h)
145
+ h = self.dropout(h)
146
+ h = self.conv2(h)
147
+
148
+ if self.in_channels != self.out_channels:
149
+ if self.use_conv_shortcut:
150
+ x = self.conv_shortcut(x)
151
+ else:
152
+ x = self.nin_shortcut(x)
153
+
154
+ return x+h
155
+
156
+
157
+ class LinAttnBlock(LinearAttention):
158
+ """to match AttnBlock usage"""
159
+ def __init__(self, in_channels):
160
+ super().__init__(dim=in_channels, heads=1, dim_head=in_channels)
161
+
162
+
163
+ class AttnBlock(nn.Module):
164
+ def __init__(self, in_channels):
165
+ super().__init__()
166
+ self.in_channels = in_channels
167
+
168
+ self.norm = Normalize(in_channels)
169
+ self.q = torch.nn.Conv2d(in_channels,
170
+ in_channels,
171
+ kernel_size=1,
172
+ stride=1,
173
+ padding=0)
174
+ self.k = torch.nn.Conv2d(in_channels,
175
+ in_channels,
176
+ kernel_size=1,
177
+ stride=1,
178
+ padding=0)
179
+ self.v = torch.nn.Conv2d(in_channels,
180
+ in_channels,
181
+ kernel_size=1,
182
+ stride=1,
183
+ padding=0)
184
+ self.proj_out = torch.nn.Conv2d(in_channels,
185
+ in_channels,
186
+ kernel_size=1,
187
+ stride=1,
188
+ padding=0)
189
+
190
+
191
+ def forward(self, x):
192
+ h_ = x
193
+ h_ = self.norm(h_)
194
+ q = self.q(h_)
195
+ k = self.k(h_)
196
+ v = self.v(h_)
197
+
198
+ # compute attention
199
+ b,c,h,w = q.shape
200
+ q = q.reshape(b,c,h*w)
201
+ q = q.permute(0,2,1) # b,hw,c
202
+ k = k.reshape(b,c,h*w) # b,c,hw
203
+ w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
204
+ w_ = w_ * (int(c)**(-0.5))
205
+ w_ = torch.nn.functional.softmax(w_, dim=2)
206
+
207
+ # attend to values
208
+ v = v.reshape(b,c,h*w)
209
+ w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q)
210
+ h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
211
+ h_ = h_.reshape(b,c,h,w)
212
+
213
+ h_ = self.proj_out(h_)
214
+
215
+ return x+h_
216
+
217
+
218
+ def make_attn(in_channels, attn_type="vanilla"):
219
+ assert attn_type in ["vanilla", "linear", "none"], f'attn_type {attn_type} unknown'
220
+ print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
221
+ if attn_type == "vanilla":
222
+ return AttnBlock(in_channels)
223
+ elif attn_type == "none":
224
+ return nn.Identity(in_channels)
225
+ else:
226
+ return LinAttnBlock(in_channels)
227
+
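make_attn is a small factory; as a hedged example of the "vanilla" path, AttnBlock is a residual self-attention over the h*w spatial positions and preserves the input shape (in_channels must be divisible by the 32 GroupNorm groups).

import torch

# Illustrative sketch: the vanilla attention block keeps (b, c, h, w) unchanged.
attn = make_attn(64, attn_type="vanilla")   # returns AttnBlock(64) as defined above
x = torch.randn(1, 64, 16, 16)
y = attn(x)                                 # residual: x + proj_out(attention over 16*16 positions)
assert y.shape == x.shape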
228
+
229
+ class Model(nn.Module):
230
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
231
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
232
+ resolution, use_timestep=True, use_linear_attn=False, attn_type="vanilla"):
233
+ super().__init__()
234
+ if use_linear_attn: attn_type = "linear"
235
+ self.ch = ch
236
+ self.temb_ch = self.ch*4
237
+ self.num_resolutions = len(ch_mult)
238
+ self.num_res_blocks = num_res_blocks
239
+ self.resolution = resolution
240
+ self.in_channels = in_channels
241
+
242
+ self.use_timestep = use_timestep
243
+ if self.use_timestep:
244
+ # timestep embedding
245
+ self.temb = nn.Module()
246
+ self.temb.dense = nn.ModuleList([
247
+ torch.nn.Linear(self.ch,
248
+ self.temb_ch),
249
+ torch.nn.Linear(self.temb_ch,
250
+ self.temb_ch),
251
+ ])
252
+
253
+ # downsampling
254
+ self.conv_in = torch.nn.Conv2d(in_channels,
255
+ self.ch,
256
+ kernel_size=3,
257
+ stride=1,
258
+ padding=1)
259
+
260
+ curr_res = resolution
261
+ in_ch_mult = (1,)+tuple(ch_mult)
262
+ self.down = nn.ModuleList()
263
+ for i_level in range(self.num_resolutions):
264
+ block = nn.ModuleList()
265
+ attn = nn.ModuleList()
266
+ block_in = ch*in_ch_mult[i_level]
267
+ block_out = ch*ch_mult[i_level]
268
+ for i_block in range(self.num_res_blocks):
269
+ block.append(ResnetBlock(in_channels=block_in,
270
+ out_channels=block_out,
271
+ temb_channels=self.temb_ch,
272
+ dropout=dropout))
273
+ block_in = block_out
274
+ if curr_res in attn_resolutions:
275
+ attn.append(make_attn(block_in, attn_type=attn_type))
276
+ down = nn.Module()
277
+ down.block = block
278
+ down.attn = attn
279
+ if i_level != self.num_resolutions-1:
280
+ down.downsample = Downsample(block_in, resamp_with_conv)
281
+ curr_res = curr_res // 2
282
+ self.down.append(down)
283
+
284
+ # middle
285
+ self.mid = nn.Module()
286
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
287
+ out_channels=block_in,
288
+ temb_channels=self.temb_ch,
289
+ dropout=dropout)
290
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
291
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
292
+ out_channels=block_in,
293
+ temb_channels=self.temb_ch,
294
+ dropout=dropout)
295
+
296
+ # upsampling
297
+ self.up = nn.ModuleList()
298
+ for i_level in reversed(range(self.num_resolutions)):
299
+ block = nn.ModuleList()
300
+ attn = nn.ModuleList()
301
+ block_out = ch*ch_mult[i_level]
302
+ skip_in = ch*ch_mult[i_level]
303
+ for i_block in range(self.num_res_blocks+1):
304
+ if i_block == self.num_res_blocks:
305
+ skip_in = ch*in_ch_mult[i_level]
306
+ block.append(ResnetBlock(in_channels=block_in+skip_in,
307
+ out_channels=block_out,
308
+ temb_channels=self.temb_ch,
309
+ dropout=dropout))
310
+ block_in = block_out
311
+ if curr_res in attn_resolutions:
312
+ attn.append(make_attn(block_in, attn_type=attn_type))
313
+ up = nn.Module()
314
+ up.block = block
315
+ up.attn = attn
316
+ if i_level != 0:
317
+ up.upsample = Upsample(block_in, resamp_with_conv)
318
+ curr_res = curr_res * 2
319
+ self.up.insert(0, up) # prepend to get consistent order
320
+
321
+ # end
322
+ self.norm_out = Normalize(block_in)
323
+ self.conv_out = torch.nn.Conv2d(block_in,
324
+ out_ch,
325
+ kernel_size=3,
326
+ stride=1,
327
+ padding=1)
328
+
329
+ def forward(self, x, t=None, context=None):
330
+ #assert x.shape[2] == x.shape[3] == self.resolution
331
+ if context is not None:
332
+ # assume aligned context, cat along channel axis
333
+ x = torch.cat((x, context), dim=1)
334
+ if self.use_timestep:
335
+ # timestep embedding
336
+ assert t is not None
337
+ temb = get_timestep_embedding(t, self.ch)
338
+ temb = self.temb.dense[0](temb)
339
+ temb = nonlinearity(temb)
340
+ temb = self.temb.dense[1](temb)
341
+ else:
342
+ temb = None
343
+
344
+ # downsampling
345
+ hs = [self.conv_in(x)]
346
+ for i_level in range(self.num_resolutions):
347
+ for i_block in range(self.num_res_blocks):
348
+ h = self.down[i_level].block[i_block](hs[-1], temb)
349
+ if len(self.down[i_level].attn) > 0:
350
+ h = self.down[i_level].attn[i_block](h)
351
+ hs.append(h)
352
+ if i_level != self.num_resolutions-1:
353
+ hs.append(self.down[i_level].downsample(hs[-1]))
354
+
355
+ # middle
356
+ h = hs[-1]
357
+ h = self.mid.block_1(h, temb)
358
+ h = self.mid.attn_1(h)
359
+ h = self.mid.block_2(h, temb)
360
+
361
+ # upsampling
362
+ for i_level in reversed(range(self.num_resolutions)):
363
+ for i_block in range(self.num_res_blocks+1):
364
+ h = self.up[i_level].block[i_block](
365
+ torch.cat([h, hs.pop()], dim=1), temb)
366
+ if len(self.up[i_level].attn) > 0:
367
+ h = self.up[i_level].attn[i_block](h)
368
+ if i_level != 0:
369
+ h = self.up[i_level].upsample(h)
370
+
371
+ # end
372
+ h = self.norm_out(h)
373
+ h = nonlinearity(h)
374
+ h = self.conv_out(h)
375
+ return h
376
+
377
+ def get_last_layer(self):
378
+ return self.conv_out.weight
379
+
380
+
381
+ class Encoder(nn.Module):
382
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
383
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
384
+ resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla",
385
+ **ignore_kwargs):
386
+ super().__init__()
387
+ if use_linear_attn: attn_type = "linear"
388
+ self.ch = ch
389
+ self.temb_ch = 0
390
+ self.num_resolutions = len(ch_mult)
391
+ self.num_res_blocks = num_res_blocks
392
+ self.resolution = resolution
393
+ self.in_channels = in_channels
394
+
395
+ # downsampling
396
+ self.conv_in = torch.nn.Conv2d(in_channels,
397
+ self.ch,
398
+ kernel_size=3,
399
+ stride=1,
400
+ padding=1)
401
+
402
+ curr_res = resolution
403
+ in_ch_mult = (1,)+tuple(ch_mult)
404
+ self.in_ch_mult = in_ch_mult
405
+ self.down = nn.ModuleList()
406
+ for i_level in range(self.num_resolutions):
407
+ block = nn.ModuleList()
408
+ attn = nn.ModuleList()
409
+ block_in = ch*in_ch_mult[i_level]
410
+ block_out = ch*ch_mult[i_level]
411
+ for i_block in range(self.num_res_blocks):
412
+ block.append(ResnetBlock(in_channels=block_in,
413
+ out_channels=block_out,
414
+ temb_channels=self.temb_ch,
415
+ dropout=dropout))
416
+ block_in = block_out
417
+ if curr_res in attn_resolutions:
418
+ attn.append(make_attn(block_in, attn_type=attn_type))
419
+ down = nn.Module()
420
+ down.block = block
421
+ down.attn = attn
422
+ if i_level != self.num_resolutions-1:
423
+ down.downsample = Downsample(block_in, resamp_with_conv)
424
+ curr_res = curr_res // 2
425
+ self.down.append(down)
426
+
427
+ # middle
428
+ self.mid = nn.Module()
429
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
430
+ out_channels=block_in,
431
+ temb_channels=self.temb_ch,
432
+ dropout=dropout)
433
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
434
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
435
+ out_channels=block_in,
436
+ temb_channels=self.temb_ch,
437
+ dropout=dropout)
438
+
439
+ # end
440
+ self.norm_out = Normalize(block_in)
441
+ self.conv_out = torch.nn.Conv2d(block_in,
442
+ 2*z_channels if double_z else z_channels,
443
+ kernel_size=3,
444
+ stride=1,
445
+ padding=1)
446
+
447
+ def forward(self, x):
448
+ # timestep embedding
449
+ temb = None
450
+
451
+ # downsampling
452
+ hs = [self.conv_in(x)]
453
+ for i_level in range(self.num_resolutions):
454
+ for i_block in range(self.num_res_blocks):
455
+ h = self.down[i_level].block[i_block](hs[-1], temb)
456
+ if len(self.down[i_level].attn) > 0:
457
+ h = self.down[i_level].attn[i_block](h)
458
+ hs.append(h)
459
+ if i_level != self.num_resolutions-1:
460
+ hs.append(self.down[i_level].downsample(hs[-1]))
461
+
462
+ # middle
463
+ h = hs[-1]
464
+ h = self.mid.block_1(h, temb)
465
+ h = self.mid.attn_1(h)
466
+ h = self.mid.block_2(h, temb)
467
+
468
+ # end
469
+ h = self.norm_out(h)
470
+ h = nonlinearity(h)
471
+ h = self.conv_out(h)
472
+ return h
473
+
474
+
475
+ class Decoder(nn.Module):
476
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
477
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
478
+ resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
479
+ attn_type="vanilla", **ignorekwargs):
480
+ super().__init__()
481
+ if use_linear_attn: attn_type = "linear"
482
+ self.ch = ch
483
+ self.temb_ch = 0
484
+ self.num_resolutions = len(ch_mult)
485
+ self.num_res_blocks = num_res_blocks
486
+ self.resolution = resolution
487
+ self.in_channels = in_channels
488
+ self.give_pre_end = give_pre_end
489
+ self.tanh_out = tanh_out
490
+
491
+ # compute in_ch_mult, block_in and curr_res at lowest res
492
+ in_ch_mult = (1,)+tuple(ch_mult)
493
+ block_in = ch*ch_mult[self.num_resolutions-1]
494
+ curr_res = resolution // 2**(self.num_resolutions-1)
495
+ self.z_shape = (1,z_channels,curr_res,curr_res)
496
+ print("Working with z of shape {} = {} dimensions.".format(
497
+ self.z_shape, np.prod(self.z_shape)))
498
+
499
+ # z to block_in
500
+ self.conv_in = torch.nn.Conv2d(z_channels,
501
+ block_in,
502
+ kernel_size=3,
503
+ stride=1,
504
+ padding=1)
505
+
506
+ # middle
507
+ self.mid = nn.Module()
508
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
509
+ out_channels=block_in,
510
+ temb_channels=self.temb_ch,
511
+ dropout=dropout)
512
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
513
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
514
+ out_channels=block_in,
515
+ temb_channels=self.temb_ch,
516
+ dropout=dropout)
517
+
518
+ # upsampling
519
+ self.up = nn.ModuleList()
520
+ for i_level in reversed(range(self.num_resolutions)):
521
+ block = nn.ModuleList()
522
+ attn = nn.ModuleList()
523
+ block_out = ch*ch_mult[i_level]
524
+ for i_block in range(self.num_res_blocks+1):
525
+ block.append(ResnetBlock(in_channels=block_in,
526
+ out_channels=block_out,
527
+ temb_channels=self.temb_ch,
528
+ dropout=dropout))
529
+ block_in = block_out
530
+ if curr_res in attn_resolutions:
531
+ attn.append(make_attn(block_in, attn_type=attn_type))
532
+ up = nn.Module()
533
+ up.block = block
534
+ up.attn = attn
535
+ if i_level != 0:
536
+ up.upsample = Upsample(block_in, resamp_with_conv)
537
+ curr_res = curr_res * 2
538
+ self.up.insert(0, up) # prepend to get consistent order
539
+
540
+ # end
541
+ self.norm_out = Normalize(block_in)
542
+ self.conv_out = torch.nn.Conv2d(block_in,
543
+ out_ch,
544
+ kernel_size=3,
545
+ stride=1,
546
+ padding=1)
547
+
548
+ def forward(self, z):
549
+ #assert z.shape[1:] == self.z_shape[1:]
550
+ self.last_z_shape = z.shape
551
+
552
+ # timestep embedding
553
+ temb = None
554
+
555
+ # z to block_in
556
+ h = self.conv_in(z)
557
+
558
+ # middle
559
+ h = self.mid.block_1(h, temb)
560
+ h = self.mid.attn_1(h)
561
+ h = self.mid.block_2(h, temb)
562
+
563
+ # upsampling
564
+ for i_level in reversed(range(self.num_resolutions)):
565
+ for i_block in range(self.num_res_blocks+1):
566
+ h = self.up[i_level].block[i_block](h, temb)
567
+ if len(self.up[i_level].attn) > 0:
568
+ h = self.up[i_level].attn[i_block](h)
569
+ if i_level != 0:
570
+ h = self.up[i_level].upsample(h)
571
+
572
+ # end
573
+ if self.give_pre_end:
574
+ return h
575
+
576
+ h = self.norm_out(h)
577
+ h = nonlinearity(h)
578
+ h = self.conv_out(h)
579
+ if self.tanh_out:
580
+ h = torch.tanh(h)
581
+ return h
582
+
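As a hedged illustration of the Encoder/Decoder geometry (toy hyperparameters, not the values used by this repo's configs): with a ch_mult of length 4 the encoder downsamples by 2**(4-1) = 8, and double_z=True makes conv_out emit 2*z_channels channels (mean and log-variance of the latent Gaussian).

import torch

# Illustrative round trip with small, made-up hyperparameters.
enc = Encoder(ch=32, out_ch=3, ch_mult=(1, 2, 4, 4), num_res_blocks=1,
              attn_resolutions=[], in_channels=3, resolution=64,
              z_channels=4, double_z=True)
dec = Decoder(ch=32, out_ch=3, ch_mult=(1, 2, 4, 4), num_res_blocks=1,
              attn_resolutions=[], in_channels=3, resolution=64,
              z_channels=4)
x = torch.randn(1, 3, 64, 64)
moments = enc(x)                     # (1, 2*z_channels, 64/8, 64/8) = (1, 8, 8, 8)
assert moments.shape == (1, 8, 8, 8)
z = moments[:, :4]                   # take the "mean" half as a stand-in latent
rec = dec(z)
assert rec.shape == (1, 3, 64, 64)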
583
+
584
+ class SimpleDecoder(nn.Module):
585
+ def __init__(self, in_channels, out_channels, *args, **kwargs):
586
+ super().__init__()
587
+ self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1),
588
+ ResnetBlock(in_channels=in_channels,
589
+ out_channels=2 * in_channels,
590
+ temb_channels=0, dropout=0.0),
591
+ ResnetBlock(in_channels=2 * in_channels,
592
+ out_channels=4 * in_channels,
593
+ temb_channels=0, dropout=0.0),
594
+ ResnetBlock(in_channels=4 * in_channels,
595
+ out_channels=2 * in_channels,
596
+ temb_channels=0, dropout=0.0),
597
+ nn.Conv2d(2*in_channels, in_channels, 1),
598
+ Upsample(in_channels, with_conv=True)])
599
+ # end
600
+ self.norm_out = Normalize(in_channels)
601
+ self.conv_out = torch.nn.Conv2d(in_channels,
602
+ out_channels,
603
+ kernel_size=3,
604
+ stride=1,
605
+ padding=1)
606
+
607
+ def forward(self, x):
608
+ for i, layer in enumerate(self.model):
609
+ if i in [1,2,3]:
610
+ x = layer(x, None)
611
+ else:
612
+ x = layer(x)
613
+
614
+ h = self.norm_out(x)
615
+ h = nonlinearity(h)
616
+ x = self.conv_out(h)
617
+ return x
618
+
619
+
620
+ class UpsampleDecoder(nn.Module):
621
+ def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution,
622
+ ch_mult=(2,2), dropout=0.0):
623
+ super().__init__()
624
+ # upsampling
625
+ self.temb_ch = 0
626
+ self.num_resolutions = len(ch_mult)
627
+ self.num_res_blocks = num_res_blocks
628
+ block_in = in_channels
629
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
630
+ self.res_blocks = nn.ModuleList()
631
+ self.upsample_blocks = nn.ModuleList()
632
+ for i_level in range(self.num_resolutions):
633
+ res_block = []
634
+ block_out = ch * ch_mult[i_level]
635
+ for i_block in range(self.num_res_blocks + 1):
636
+ res_block.append(ResnetBlock(in_channels=block_in,
637
+ out_channels=block_out,
638
+ temb_channels=self.temb_ch,
639
+ dropout=dropout))
640
+ block_in = block_out
641
+ self.res_blocks.append(nn.ModuleList(res_block))
642
+ if i_level != self.num_resolutions - 1:
643
+ self.upsample_blocks.append(Upsample(block_in, True))
644
+ curr_res = curr_res * 2
645
+
646
+ # end
647
+ self.norm_out = Normalize(block_in)
648
+ self.conv_out = torch.nn.Conv2d(block_in,
649
+ out_channels,
650
+ kernel_size=3,
651
+ stride=1,
652
+ padding=1)
653
+
654
+ def forward(self, x):
655
+ # upsampling
656
+ h = x
657
+ for k, i_level in enumerate(range(self.num_resolutions)):
658
+ for i_block in range(self.num_res_blocks + 1):
659
+ h = self.res_blocks[i_level][i_block](h, None)
660
+ if i_level != self.num_resolutions - 1:
661
+ h = self.upsample_blocks[k](h)
662
+ h = self.norm_out(h)
663
+ h = nonlinearity(h)
664
+ h = self.conv_out(h)
665
+ return h
666
+
667
+
668
+ class LatentRescaler(nn.Module):
669
+ def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2):
670
+ super().__init__()
671
+ # residual block, interpolate, residual block
672
+ self.factor = factor
673
+ self.conv_in = nn.Conv2d(in_channels,
674
+ mid_channels,
675
+ kernel_size=3,
676
+ stride=1,
677
+ padding=1)
678
+ self.res_block1 = nn.ModuleList([ResnetBlock(in_channels=mid_channels,
679
+ out_channels=mid_channels,
680
+ temb_channels=0,
681
+ dropout=0.0) for _ in range(depth)])
682
+ self.attn = AttnBlock(mid_channels)
683
+ self.res_block2 = nn.ModuleList([ResnetBlock(in_channels=mid_channels,
684
+ out_channels=mid_channels,
685
+ temb_channels=0,
686
+ dropout=0.0) for _ in range(depth)])
687
+
688
+ self.conv_out = nn.Conv2d(mid_channels,
689
+ out_channels,
690
+ kernel_size=1,
691
+ )
692
+
693
+ def forward(self, x):
694
+ x = self.conv_in(x)
695
+ for block in self.res_block1:
696
+ x = block(x, None)
697
+ x = torch.nn.functional.interpolate(x, size=(int(round(x.shape[2]*self.factor)), int(round(x.shape[3]*self.factor))))
698
+ x = self.attn(x)
699
+ for block in self.res_block2:
700
+ x = block(x, None)
701
+ x = self.conv_out(x)
702
+ return x
703
+
704
+
705
+ class MergedRescaleEncoder(nn.Module):
706
+ def __init__(self, in_channels, ch, resolution, out_ch, num_res_blocks,
707
+ attn_resolutions, dropout=0.0, resamp_with_conv=True,
708
+ ch_mult=(1,2,4,8), rescale_factor=1.0, rescale_module_depth=1):
709
+ super().__init__()
710
+ intermediate_chn = ch * ch_mult[-1]
711
+ self.encoder = Encoder(in_channels=in_channels, num_res_blocks=num_res_blocks, ch=ch, ch_mult=ch_mult,
712
+ z_channels=intermediate_chn, double_z=False, resolution=resolution,
713
+ attn_resolutions=attn_resolutions, dropout=dropout, resamp_with_conv=resamp_with_conv,
714
+ out_ch=None)
715
+ self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=intermediate_chn,
716
+ mid_channels=intermediate_chn, out_channels=out_ch, depth=rescale_module_depth)
717
+
718
+ def forward(self, x):
719
+ x = self.encoder(x)
720
+ x = self.rescaler(x)
721
+ return x
722
+
723
+
724
+ class MergedRescaleDecoder(nn.Module):
725
+ def __init__(self, z_channels, out_ch, resolution, num_res_blocks, attn_resolutions, ch, ch_mult=(1,2,4,8),
726
+ dropout=0.0, resamp_with_conv=True, rescale_factor=1.0, rescale_module_depth=1):
727
+ super().__init__()
728
+ tmp_chn = z_channels*ch_mult[-1]
729
+ self.decoder = Decoder(out_ch=out_ch, z_channels=tmp_chn, attn_resolutions=attn_resolutions, dropout=dropout,
730
+ resamp_with_conv=resamp_with_conv, in_channels=None, num_res_blocks=num_res_blocks,
731
+ ch_mult=ch_mult, resolution=resolution, ch=ch)
732
+ self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=z_channels, mid_channels=tmp_chn,
733
+ out_channels=tmp_chn, depth=rescale_module_depth)
734
+
735
+ def forward(self, x):
736
+ x = self.rescaler(x)
737
+ x = self.decoder(x)
738
+ return x
739
+
740
+
741
+ class Upsampler(nn.Module):
742
+ def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2):
743
+ super().__init__()
744
+ assert out_size >= in_size
745
+ num_blocks = int(np.log2(out_size//in_size))+1
746
+ factor_up = 1.+ (out_size % in_size)
747
+ print(f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}")
748
+ self.rescaler = LatentRescaler(factor=factor_up, in_channels=in_channels, mid_channels=2*in_channels,
749
+ out_channels=in_channels)
750
+ self.decoder = Decoder(out_ch=out_channels, resolution=out_size, z_channels=in_channels, num_res_blocks=2,
751
+ attn_resolutions=[], in_channels=None, ch=in_channels,
752
+ ch_mult=[ch_mult for _ in range(num_blocks)])
753
+
754
+ def forward(self, x):
755
+ x = self.rescaler(x)
756
+ x = self.decoder(x)
757
+ return x
758
+
759
+
760
+ class Resize(nn.Module):
761
+ def __init__(self, in_channels=None, learned=False, mode="bilinear"):
762
+ super().__init__()
763
+ self.with_conv = learned
764
+ self.mode = mode
765
+ if self.with_conv:
766
+ print(f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode")
767
+ raise NotImplementedError()
768
+ assert in_channels is not None
769
+ # no asymmetric padding in torch conv, must do it ourselves
770
+ self.conv = torch.nn.Conv2d(in_channels,
771
+ in_channels,
772
+ kernel_size=4,
773
+ stride=2,
774
+ padding=1)
775
+
776
+ def forward(self, x, scale_factor=1.0):
777
+ if scale_factor==1.0:
778
+ return x
779
+ else:
780
+ x = torch.nn.functional.interpolate(x, mode=self.mode, align_corners=False, scale_factor=scale_factor)
781
+ return x
782
+
783
+ class FirstStagePostProcessor(nn.Module):
784
+
785
+ def __init__(self, ch_mult:list, in_channels,
786
+ pretrained_model:nn.Module=None,
787
+ reshape=False,
788
+ n_channels=None,
789
+ dropout=0.,
790
+ pretrained_config=None):
791
+ super().__init__()
792
+ if pretrained_config is None:
793
+ assert pretrained_model is not None, 'Either "pretrained_model" or "pretrained_config" must not be None'
794
+ self.pretrained_model = pretrained_model
795
+ else:
796
+ assert pretrained_config is not None, 'Either "pretrained_model" or "pretrained_config" must not be None'
797
+ self.instantiate_pretrained(pretrained_config)
798
+
799
+ self.do_reshape = reshape
800
+
801
+ if n_channels is None:
802
+ n_channels = self.pretrained_model.encoder.ch
803
+
804
+ self.proj_norm = Normalize(in_channels,num_groups=in_channels//2)
805
+ self.proj = nn.Conv2d(in_channels,n_channels,kernel_size=3,
806
+ stride=1,padding=1)
807
+
808
+ blocks = []
809
+ downs = []
810
+ ch_in = n_channels
811
+ for m in ch_mult:
812
+ blocks.append(ResnetBlock(in_channels=ch_in,out_channels=m*n_channels,dropout=dropout))
813
+ ch_in = m * n_channels
814
+ downs.append(Downsample(ch_in, with_conv=False))
815
+
816
+ self.model = nn.ModuleList(blocks)
817
+ self.downsampler = nn.ModuleList(downs)
818
+
819
+
820
+ def instantiate_pretrained(self, config):
821
+ model = instantiate_from_config(config)
822
+ self.pretrained_model = model.eval()
823
+ # self.pretrained_model.train = False
824
+ for param in self.pretrained_model.parameters():
825
+ param.requires_grad = False
826
+
827
+
828
+ @torch.no_grad()
829
+ def encode_with_pretrained(self,x):
830
+ c = self.pretrained_model.encode(x)
831
+ if isinstance(c, DiagonalGaussianDistribution):
832
+ c = c.mode()
833
+ return c
834
+
835
+ def forward(self,x):
836
+ z_fs = self.encode_with_pretrained(x)
837
+ z = self.proj_norm(z_fs)
838
+ z = self.proj(z)
839
+ z = nonlinearity(z)
840
+
841
+ for submodel, downmodel in zip(self.model,self.downsampler):
842
+ z = submodel(z,temb=None)
843
+ z = downmodel(z)
844
+
845
+ if self.do_reshape:
846
+ z = rearrange(z,'b c h w -> b (h w) c')
847
+ return z
848
+
ldm/modules/diffusionmodules/openaimodel.py ADDED
@@ -0,0 +1,1225 @@
1
+ # This code is built from the Stable Diffusion repository: https://github.com/CompVis/stable-diffusion, and
2
+ # Paint-by-Example repo https://github.com/Fantasy-Studio/Paint-by-Example
3
+ # Copyright (c) 2022 Robin Rombach and Patrick Esser and contributors.
4
+ # CreativeML Open RAIL-M
5
+ #
6
+ # ==========================================================================================
7
+ #
8
+ # Adobe’s modifications are Copyright 2024 Adobe Research. All rights reserved.
9
+ # Adobe’s modifications are licensed under the Adobe Research License. To view a copy of the license, visit
10
+ # LICENSE.md.
11
+ #
12
+ # ==========================================================================================
13
+
14
+ from abc import abstractmethod
15
+ from functools import partial
16
+ import math
17
+ from typing import Iterable
18
+ from collections import deque
19
+
20
+ import numpy as np
21
+ import torch as th
22
+ import torch.nn as nn
23
+ import torch.nn.functional as F
24
+ import glob
25
+ import os
26
+
27
+ import torchvision
28
+
29
+ from ldm.modules.diffusionmodules.util import (
30
+ checkpoint,
31
+ conv_nd,
32
+ linear,
33
+ avg_pool_nd,
34
+ zero_module,
35
+ normalization,
36
+ timestep_embedding,
37
+ )
38
+ from ldm.modules.attention import SpatialTransformer
39
+
40
+
41
+ # dummy replace
42
+ def convert_module_to_f16(x):
43
+ pass
44
+
45
+ def convert_module_to_f32(x):
46
+ pass
47
+
48
+
49
+ ## go
50
+ class AttentionPool2d(nn.Module):
51
+ """
52
+ Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
53
+ """
54
+
55
+ def __init__(
56
+ self,
57
+ spacial_dim: int,
58
+ embed_dim: int,
59
+ num_heads_channels: int,
60
+ output_dim: int = None,
61
+ ):
62
+ super().__init__()
63
+ self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5)
64
+ self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
65
+ self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
66
+ self.num_heads = embed_dim // num_heads_channels
67
+ self.attention = QKVAttention(self.num_heads)
68
+
69
+ def forward(self, x):
70
+ b, c, *_spatial = x.shape
71
+ x = x.reshape(b, c, -1) # NC(HW)
72
+ x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1)
73
+ x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1)
74
+ x = self.qkv_proj(x)
75
+ x = self.attention(x)
76
+ x = self.c_proj(x)
77
+ return x[:, :, 0]
78
+
79
+
80
+ class TimestepBlock(nn.Module):
81
+ """
82
+ Any module where forward() takes timestep embeddings as a second argument.
83
+ """
84
+
85
+ @abstractmethod
86
+ def forward(self, x, emb):
87
+ """
88
+ Apply the module to `x` given `emb` timestep embeddings.
89
+ """
90
+
91
+
92
+ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
93
+ """
94
+ A sequential module that passes timestep embeddings to the children that
95
+ support it as an extra input.
96
+ """
97
+
98
+ def forward(self, x, emb, context=None, passed_kqv=None, kqv_idx=None, masks=None, corresp=None):
99
+ attention_vals = []
100
+ # print('processing a layer')
101
+ # print('idx', kqv_idx)
102
+ for layer in self:
103
+ # print('processing a layer', layer.__class__.__name__)
104
+ if isinstance(layer, TimestepBlock):
105
+ x = layer(x, emb)
106
+ elif isinstance(layer, SpatialTransformer):
107
+ if passed_kqv is not None:
108
+ assert kqv_idx is not None
109
+ passed_item = passed_kqv[kqv_idx]
110
+ # print('pre passed item len', len(passed_item))
111
+ if len(passed_item) == 1:
112
+ passed_item = passed_item[0][0]
113
+ # print('success passed item', len(passed_item))
114
+ else:
115
+ passed_item = None
116
+ x, kqv = layer(x, context, passed_item, masks=masks, corresp=corresp)
117
+ attention_vals.append(kqv)
118
+ else:
119
+ x = layer(x)
120
+ # print('length of attn vals', len(attention_vals))
121
+ return x, attention_vals
122
+
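TimestepEmbedSequential dispatches on layer type: ResBlocks receive the timestep embedding, SpatialTransformers receive the conditioning context (plus, in this fork, the optionally passed q/k/v from a reference pass), and every other layer is called on x alone. A simplified sketch of just that dispatch (omitting the passed_kqv bookkeeping):

import torch.nn as nn

# Illustrative simplification; assumes TimestepBlock and SpatialTransformer from this module.
class TinyTimestepSequential(nn.Sequential):
    def forward(self, x, emb, context=None):
        for layer in self:
            if isinstance(layer, TimestepBlock):          # e.g. ResBlock: needs the timestep embedding
                x = layer(x, emb)
            elif isinstance(layer, SpatialTransformer):   # attention: needs the conditioning context
                x, _qkvs = layer(x, context)
            else:                                         # plain convs, pooling, etc.
                x = layer(x)
        return x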
123
+
124
+ class Upsample(nn.Module):
125
+ """
126
+ An upsampling layer with an optional convolution.
127
+ :param channels: channels in the inputs and outputs.
128
+ :param use_conv: a bool determining if a convolution is applied.
129
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
130
+ upsampling occurs in the inner-two dimensions.
131
+ """
132
+
133
+ def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
134
+ super().__init__()
135
+ self.channels = channels
136
+ self.out_channels = out_channels or channels
137
+ self.use_conv = use_conv
138
+ self.dims = dims
139
+ if use_conv:
140
+ self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)
141
+
142
+ def forward(self, x):
143
+ assert x.shape[1] == self.channels
144
+ if self.dims == 3:
145
+ x = F.interpolate(
146
+ x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
147
+ )
148
+ else:
149
+ x = F.interpolate(x, scale_factor=2, mode="nearest")
150
+ if self.use_conv:
151
+ x = self.conv(x)
152
+ return x
153
+
154
+ class TransposedUpsample(nn.Module):
155
+ 'Learned 2x upsampling without padding'
156
+ def __init__(self, channels, out_channels=None, ks=5):
157
+ super().__init__()
158
+ self.channels = channels
159
+ self.out_channels = out_channels or channels
160
+
161
+ self.up = nn.ConvTranspose2d(self.channels,self.out_channels,kernel_size=ks,stride=2)
162
+
163
+ def forward(self,x):
164
+ return self.up(x)
165
+
166
+
167
+ class Downsample(nn.Module):
168
+ """
169
+ A downsampling layer with an optional convolution.
170
+ :param channels: channels in the inputs and outputs.
171
+ :param use_conv: a bool determining if a convolution is applied.
172
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
173
+ downsampling occurs in the inner-two dimensions.
174
+ """
175
+
176
+ def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1):
177
+ super().__init__()
178
+ self.channels = channels
179
+ self.out_channels = out_channels or channels
180
+ self.use_conv = use_conv
181
+ self.dims = dims
182
+ stride = 2 if dims != 3 else (1, 2, 2)
183
+ if use_conv:
184
+ self.op = conv_nd(
185
+ dims, self.channels, self.out_channels, 3, stride=stride, padding=padding
186
+ )
187
+ else:
188
+ assert self.channels == self.out_channels
189
+ self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
190
+
191
+ def forward(self, x):
192
+ assert x.shape[1] == self.channels
193
+ return self.op(x)
194
+
195
+
196
+ class ResBlock(TimestepBlock):
197
+ """
198
+ A residual block that can optionally change the number of channels.
199
+ :param channels: the number of input channels.
200
+ :param emb_channels: the number of timestep embedding channels.
201
+ :param dropout: the rate of dropout.
202
+ :param out_channels: if specified, the number of out channels.
203
+ :param use_conv: if True and out_channels is specified, use a spatial
204
+ convolution instead of a smaller 1x1 convolution to change the
205
+ channels in the skip connection.
206
+ :param dims: determines if the signal is 1D, 2D, or 3D.
207
+ :param use_checkpoint: if True, use gradient checkpointing on this module.
208
+ :param up: if True, use this block for upsampling.
209
+ :param down: if True, use this block for downsampling.
210
+ """
211
+
212
+ def __init__(
213
+ self,
214
+ channels,
215
+ emb_channels,
216
+ dropout,
217
+ out_channels=None,
218
+ use_conv=False,
219
+ use_scale_shift_norm=False,
220
+ dims=2,
221
+ use_checkpoint=False,
222
+ up=False,
223
+ down=False,
224
+ ):
225
+ super().__init__()
226
+ self.channels = channels
227
+ self.emb_channels = emb_channels
228
+ self.dropout = dropout
229
+ self.out_channels = out_channels or channels
230
+ self.use_conv = use_conv
231
+ self.use_checkpoint = use_checkpoint
232
+ self.use_scale_shift_norm = use_scale_shift_norm
233
+
234
+ self.in_layers = nn.Sequential(
235
+ normalization(channels),
236
+ nn.SiLU(),
237
+ conv_nd(dims, channels, self.out_channels, 3, padding=1),
238
+ )
239
+
240
+ self.updown = up or down
241
+
242
+ if up:
243
+ self.h_upd = Upsample(channels, False, dims)
244
+ self.x_upd = Upsample(channels, False, dims)
245
+ elif down:
246
+ self.h_upd = Downsample(channels, False, dims)
247
+ self.x_upd = Downsample(channels, False, dims)
248
+ else:
249
+ self.h_upd = self.x_upd = nn.Identity()
250
+
251
+ self.emb_layers = nn.Sequential(
252
+ nn.SiLU(),
253
+ linear(
254
+ emb_channels,
255
+ 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
256
+ ),
257
+ )
258
+ self.out_layers = nn.Sequential(
259
+ normalization(self.out_channels),
260
+ nn.SiLU(),
261
+ nn.Dropout(p=dropout),
262
+ zero_module(
263
+ conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
264
+ ),
265
+ )
266
+
267
+ if self.out_channels == channels:
268
+ self.skip_connection = nn.Identity()
269
+ elif use_conv:
270
+ self.skip_connection = conv_nd(
271
+ dims, channels, self.out_channels, 3, padding=1
272
+ )
273
+ else:
274
+ self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
275
+
276
+ def forward(self, x, emb):
277
+ """
278
+ Apply the block to a Tensor, conditioned on a timestep embedding.
279
+ :param x: an [N x C x ...] Tensor of features.
280
+ :param emb: an [N x emb_channels] Tensor of timestep embeddings.
281
+ :return: an [N x C x ...] Tensor of outputs.
282
+ """
283
+ return checkpoint(
284
+ self._forward, (x, emb), self.parameters(), self.use_checkpoint
285
+ )
286
+
287
+
288
+ def _forward(self, x, emb):
289
+ if self.updown:
290
+ in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
291
+ h = in_rest(x)
292
+ h = self.h_upd(h)
293
+ x = self.x_upd(x)
294
+ h = in_conv(h)
295
+ else:
296
+ h = self.in_layers(x)
297
+ emb_out = self.emb_layers(emb).type(h.dtype)
298
+ while len(emb_out.shape) < len(h.shape):
299
+ emb_out = emb_out[..., None]
300
+ if self.use_scale_shift_norm:
301
+ out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
302
+ scale, shift = th.chunk(emb_out, 2, dim=1)
303
+ h = out_norm(h) * (1 + scale) + shift
304
+ h = out_rest(h)
305
+ else:
306
+ h = h + emb_out
307
+ h = self.out_layers(h)
308
+ return self.skip_connection(x) + h
309
+
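When use_scale_shift_norm is enabled, the timestep embedding conditions the block FiLM-style instead of being added: emb_layers projects it to 2*out_channels, which is split into scale and shift and applied as norm(h) * (1 + scale) + shift. A minimal sketch of just that step:

import torch

# Illustrative sketch of the scale-shift conditioning in ResBlock._forward (toy shapes).
h = torch.randn(2, 64, 16, 16)              # feature map, out_channels=64
emb_out = torch.randn(2, 128)               # emb_layers output: 2 * out_channels
emb_out = emb_out[..., None, None]          # broadcast over the spatial dims
scale, shift = torch.chunk(emb_out, 2, dim=1)
norm = torch.nn.GroupNorm(32, 64)
h = norm(h) * (1 + scale) + shift           # FiLM-like modulation
assert h.shape == (2, 64, 16, 16)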
310
+
311
+ class My_ResBlock(TimestepBlock):
312
+ """
313
+ A residual block that can optionally change the number of channels.
314
+ :param channels: the number of input channels.
315
+ :param emb_channels: the number of timestep embedding channels.
316
+ :param dropout: the rate of dropout.
317
+ :param out_channels: if specified, the number of out channels.
318
+ :param use_conv: if True and out_channels is specified, use a spatial
319
+ convolution instead of a smaller 1x1 convolution to change the
320
+ channels in the skip connection.
321
+ :param dims: determines if the signal is 1D, 2D, or 3D.
322
+ :param use_checkpoint: if True, use gradient checkpointing on this module.
323
+ :param up: if True, use this block for upsampling.
324
+ :param down: if True, use this block for downsampling.
325
+ """
326
+
327
+ def __init__(
328
+ self,
329
+ channels,
330
+ emb_channels,
331
+ dropout,
332
+ out_channels=None,
333
+ use_conv=False,
334
+ use_scale_shift_norm=False,
335
+ dims=2,
336
+ use_checkpoint=False,
337
+ up=False,
338
+ down=False,
339
+ ):
340
+ super().__init__()
341
+ self.channels = channels
342
+ self.emb_channels = emb_channels
343
+ self.dropout = dropout
344
+ self.out_channels = out_channels or channels
345
+ self.use_conv = use_conv
346
+ self.use_checkpoint = use_checkpoint
347
+ self.use_scale_shift_norm = use_scale_shift_norm
348
+
349
+ self.in_layers = nn.Sequential(
350
+ normalization(channels),
351
+ nn.SiLU(),
352
+ conv_nd(dims, channels, self.out_channels, 3, padding=1),
353
+ )
354
+
355
+ self.updown = up or down
356
+
357
+ if up:
358
+ self.h_upd = Upsample(channels, False, dims)
359
+ self.x_upd = Upsample(channels, False, dims)
360
+ elif down:
361
+ self.h_upd = Downsample(channels, False, dims)
362
+ self.x_upd = Downsample(channels, False, dims)
363
+ else:
364
+ self.h_upd = self.x_upd = nn.Identity()
365
+
366
+ self.emb_layers = nn.Sequential(
367
+ nn.SiLU(),
368
+ linear(
369
+ emb_channels,
370
+ 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
371
+ ),
372
+ )
373
+ self.out_layers = nn.Sequential(
374
+ normalization(self.out_channels),
375
+ nn.SiLU(),
376
+ nn.Dropout(p=dropout),
377
+ zero_module(
378
+ conv_nd(dims, self.out_channels, 4, 3, padding=1)
379
+ ),
380
+ )
381
+
382
+ if self.out_channels == channels:
383
+ self.skip_connection = nn.Identity()
384
+ elif use_conv:
385
+ self.skip_connection = conv_nd(
386
+ dims, channels, self.out_channels, 3, padding=1
387
+ )
388
+ else:
389
+ self.skip_connection = conv_nd(dims, channels, 4, 1)
390
+
391
+ def forward(self, x, emb):
392
+ """
393
+ Apply the block to a Tensor, conditioned on a timestep embedding.
394
+ :param x: an [N x C x ...] Tensor of features.
395
+ :param emb: an [N x emb_channels] Tensor of timestep embeddings.
396
+ :return: an [N x C x ...] Tensor of outputs.
397
+ """
398
+ return checkpoint(
399
+ self._forward, (x, emb), self.parameters(), self.use_checkpoint
400
+ )
401
+
402
+
403
+ def _forward(self, x, emb):
404
+ if self.updown:
405
+ in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
406
+ h = in_rest(x)
407
+ h = self.h_upd(h)
408
+ x = self.x_upd(x)
409
+ h = in_conv(h)
410
+ else:
411
+ h = self.in_layers(x)
412
+ emb_out = self.emb_layers(emb).type(h.dtype)
413
+ while len(emb_out.shape) < len(h.shape):
414
+ emb_out = emb_out[..., None]
415
+ if self.use_scale_shift_norm:
416
+ out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
417
+ scale, shift = th.chunk(emb_out, 2, dim=1)
418
+ h = out_norm(h) * (1 + scale) + shift
419
+ h = out_rest(h)
420
+ else:
421
+ h = h + emb_out
422
+ h = self.out_layers(h)
423
+ return h
424
+
425
+
426
+ class AttentionBlock(nn.Module):
427
+ """
428
+ An attention block that allows spatial positions to attend to each other.
429
+ Originally ported from here, but adapted to the N-d case.
430
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
431
+ """
432
+
433
+ def __init__(
434
+ self,
435
+ channels,
436
+ num_heads=1,
437
+ num_head_channels=-1,
438
+ use_checkpoint=False,
439
+ use_new_attention_order=False,
440
+ ):
441
+ super().__init__()
442
+ self.channels = channels
443
+ if num_head_channels == -1:
444
+ self.num_heads = num_heads
445
+ else:
446
+ assert (
447
+ channels % num_head_channels == 0
448
+ ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
449
+ self.num_heads = channels // num_head_channels
450
+ self.use_checkpoint = use_checkpoint
451
+ self.norm = normalization(channels)
452
+ self.qkv = conv_nd(1, channels, channels * 3, 1)
453
+ if use_new_attention_order:
454
+ # split qkv before split heads
455
+ self.attention = QKVAttention(self.num_heads)
456
+ else:
457
+ # split heads before split qkv
458
+ self.attention = QKVAttentionLegacy(self.num_heads)
459
+
460
+ self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
461
+
462
+ def forward(self, x):
463
+ return checkpoint(self._forward, (x,), self.parameters(), True) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
464
+ #return pt_checkpoint(self._forward, x) # pytorch
465
+
466
+ def _forward(self, x):
467
+ b, c, *spatial = x.shape
468
+ x = x.reshape(b, c, -1)
469
+ qkv = self.qkv(self.norm(x))
470
+ h = self.attention(qkv)
471
+ h = self.proj_out(h)
472
+ return (x + h).reshape(b, c, *spatial)
473
+
474
+
475
+ def count_flops_attn(model, _x, y):
476
+ """
477
+ A counter for the `thop` package to count the operations in an
478
+ attention operation.
479
+ Meant to be used like:
480
+ macs, params = thop.profile(
481
+ model,
482
+ inputs=(inputs, timestamps),
483
+ custom_ops={QKVAttention: QKVAttention.count_flops},
484
+ )
485
+ """
486
+ b, c, *spatial = y[0].shape
487
+ num_spatial = int(np.prod(spatial))
488
+ # We perform two matmuls with the same number of ops.
489
+ # The first computes the weight matrix, the second computes
490
+ # the combination of the value vectors.
491
+ matmul_ops = 2 * b * (num_spatial ** 2) * c
492
+ model.total_ops += th.DoubleTensor([matmul_ops])
493
+
494
+
495
+ class QKVAttentionLegacy(nn.Module):
496
+ """
497
+ A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping
498
+ """
499
+
500
+ def __init__(self, n_heads):
501
+ super().__init__()
502
+ self.n_heads = n_heads
503
+
504
+ def forward(self, qkv):
505
+ """
506
+ Apply QKV attention.
507
+ :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
508
+ :return: an [N x (H * C) x T] tensor after attention.
509
+ """
510
+ bs, width, length = qkv.shape
511
+ assert width % (3 * self.n_heads) == 0
512
+ ch = width // (3 * self.n_heads)
513
+ q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
514
+ scale = 1 / math.sqrt(math.sqrt(ch))
515
+ weight = th.einsum(
516
+ "bct,bcs->bts", q * scale, k * scale
517
+ ) # More stable with f16 than dividing afterwards
518
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
519
+ a = th.einsum("bts,bcs->bct", weight, v)
520
+ return a.reshape(bs, -1, length)
521
+
522
+ @staticmethod
523
+ def count_flops(model, _x, y):
524
+ return count_flops_attn(model, _x, y)
525
+
526
+
527
+ class QKVAttention(nn.Module):
528
+ """
529
+ A module which performs QKV attention and splits in a different order.
530
+ """
531
+
532
+ def __init__(self, n_heads):
533
+ super().__init__()
534
+ self.n_heads = n_heads
535
+
536
+ def forward(self, qkv):
537
+ """
538
+ Apply QKV attention.
539
+ :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
540
+ :return: an [N x (H * C) x T] tensor after attention.
541
+ """
542
+ bs, width, length = qkv.shape
543
+ assert width % (3 * self.n_heads) == 0
544
+ ch = width // (3 * self.n_heads)
545
+ q, k, v = qkv.chunk(3, dim=1)
546
+ scale = 1 / math.sqrt(math.sqrt(ch))
547
+ weight = th.einsum(
548
+ "bct,bcs->bts",
549
+ (q * scale).view(bs * self.n_heads, ch, length),
550
+ (k * scale).view(bs * self.n_heads, ch, length),
551
+ ) # More stable with f16 than dividing afterwards
552
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
553
+ a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
554
+ return a.reshape(bs, -1, length)
555
+
556
+ @staticmethod
557
+ def count_flops(model, _x, y):
558
+ return count_flops_attn(model, _x, y)
559
+
560
+
561
+ class UNetModel(nn.Module):
562
+ """
563
+ The full UNet model with attention and timestep embedding.
564
+ :param in_channels: channels in the input Tensor.
565
+ :param model_channels: base channel count for the model.
566
+ :param out_channels: channels in the output Tensor.
567
+ :param num_res_blocks: number of residual blocks per downsample.
568
+ :param attention_resolutions: a collection of downsample rates at which
569
+ attention will take place. May be a set, list, or tuple.
570
+ For example, if this contains 4, then at 4x downsampling, attention
571
+ will be used.
572
+ :param dropout: the dropout probability.
573
+ :param channel_mult: channel multiplier for each level of the UNet.
574
+ :param conv_resample: if True, use learned convolutions for upsampling and
575
+ downsampling.
576
+ :param dims: determines if the signal is 1D, 2D, or 3D.
577
+ :param num_classes: if specified (as an int), then this model will be
578
+ class-conditional with `num_classes` classes.
579
+ :param use_checkpoint: use gradient checkpointing to reduce memory usage.
580
+ :param num_heads: the number of attention heads in each attention layer.
581
+ :param num_head_channels: if specified, ignore num_heads and instead use
582
+ a fixed channel width per attention head.
583
+ :param num_heads_upsample: works with num_heads to set a different number
584
+ of heads for upsampling. Deprecated.
585
+ :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
586
+ :param resblock_updown: use residual blocks for up/downsampling.
587
+ :param use_new_attention_order: use a different attention pattern for potentially
588
+ increased efficiency.
589
+ """
590
+
591
+ def __init__(
592
+ self,
593
+ image_size,
594
+ in_channels,
595
+ model_channels,
596
+ out_channels,
597
+ num_res_blocks,
598
+ attention_resolutions,
599
+ dropout=0,
600
+ channel_mult=(1, 2, 4, 8),
601
+ conv_resample=True,
602
+ dims=2,
603
+ num_classes=None,
604
+ use_checkpoint=False,
605
+ use_fp16=False,
606
+ num_heads=-1,
607
+ num_head_channels=-1,
608
+ num_heads_upsample=-1,
609
+ use_scale_shift_norm=False,
610
+ resblock_updown=False,
611
+ use_new_attention_order=False,
612
+ use_spatial_transformer=False, # custom transformer support
613
+ transformer_depth=1, # custom transformer support
614
+ context_dim=None, # custom transformer support
615
+ n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
616
+ legacy=True,
617
+ add_conv_in_front_of_unet=False,
618
+ ):
619
+ super().__init__()
620
+ if use_spatial_transformer:
621
+ assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
622
+
623
+ if context_dim is not None:
624
+ assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
625
+ from omegaconf.listconfig import ListConfig
626
+ if type(context_dim) == ListConfig:
627
+ context_dim = list(context_dim)
628
+
629
+ if num_heads_upsample == -1:
630
+ num_heads_upsample = num_heads
631
+
632
+ if num_heads == -1:
633
+ assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
634
+
635
+ if num_head_channels == -1:
636
+ assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
637
+
638
+ self.image_size = image_size
639
+ self.in_channels = in_channels
640
+ self.model_channels = model_channels
641
+ self.out_channels = out_channels
642
+ self.num_res_blocks = num_res_blocks
643
+ self.attention_resolutions = attention_resolutions
644
+ self.dropout = dropout
645
+ self.channel_mult = channel_mult
646
+ self.conv_resample = conv_resample
647
+ self.num_classes = num_classes
648
+ self.use_checkpoint = use_checkpoint
649
+ self.dtype = th.float16 if use_fp16 else th.float32
650
+ self.num_heads = num_heads
651
+ self.num_head_channels = num_head_channels
652
+ self.num_heads_upsample = num_heads_upsample
653
+ self.predict_codebook_ids = n_embed is not None
654
+ self.add_conv_in_front_of_unet = add_conv_in_front_of_unet
655
+
656
+
657
+ # save contexts
658
+ self.save_contexts = False
659
+ self.use_contexts = False
660
+ self.contexts = deque([])
661
+
662
+ time_embed_dim = model_channels * 4
663
+ self.time_embed = nn.Sequential(
664
+ linear(model_channels, time_embed_dim),
665
+ nn.SiLU(),
666
+ linear(time_embed_dim, time_embed_dim),
667
+ )
668
+
669
+ if self.num_classes is not None:
670
+ self.label_emb = nn.Embedding(num_classes, time_embed_dim)
671
+
672
+
673
+ if self.add_conv_in_front_of_unet:
674
+ self.add_resbolck = nn.ModuleList(
675
+ [
676
+ TimestepEmbedSequential(
677
+ conv_nd(dims, 9, model_channels, 3, padding=1)
678
+ )
679
+ ]
680
+ )
681
+
682
+ add_layers = [
683
+ My_ResBlock(
684
+ model_channels,
685
+ time_embed_dim,
686
+ dropout,
687
+ out_channels=model_channels,
688
+ dims=dims,
689
+ use_checkpoint=use_checkpoint,
690
+ use_scale_shift_norm=use_scale_shift_norm,
691
+ )
692
+ ]
693
+
694
+ self.add_resbolck.append(TimestepEmbedSequential(*add_layers))
695
+
696
+
697
+ self.input_blocks = nn.ModuleList(
698
+ [
699
+ TimestepEmbedSequential(
700
+ conv_nd(dims, in_channels, model_channels, 3, padding=1)
701
+ )
702
+ ]
703
+ )
704
+ self._feature_size = model_channels
705
+ input_block_chans = [model_channels]
706
+ ch = model_channels
707
+ ds = 1
708
+ for level, mult in enumerate(channel_mult):
709
+ for _ in range(num_res_blocks):
710
+ layers = [
711
+ ResBlock(
712
+ ch,
713
+ time_embed_dim,
714
+ dropout,
715
+ out_channels=mult * model_channels,
716
+ dims=dims,
717
+ use_checkpoint=use_checkpoint,
718
+ use_scale_shift_norm=use_scale_shift_norm,
719
+ )
720
+ ]
721
+ ch = mult * model_channels
722
+ if ds in attention_resolutions:
723
+ if num_head_channels == -1:
724
+ dim_head = ch // num_heads
725
+ else:
726
+ num_heads = ch // num_head_channels
727
+ dim_head = num_head_channels
728
+ if legacy:
729
+ #num_heads = 1
730
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
731
+ layers.append(
732
+ AttentionBlock(
733
+ ch,
734
+ use_checkpoint=use_checkpoint,
735
+ num_heads=num_heads,
736
+ num_head_channels=dim_head,
737
+ use_new_attention_order=use_new_attention_order,
738
+ ) if not use_spatial_transformer else SpatialTransformer(
739
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
740
+ )
741
+ )
742
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
743
+ self._feature_size += ch
744
+ input_block_chans.append(ch)
745
+ if level != len(channel_mult) - 1:
746
+ out_ch = ch
747
+ self.input_blocks.append(
748
+ TimestepEmbedSequential(
749
+ ResBlock(
750
+ ch,
751
+ time_embed_dim,
752
+ dropout,
753
+ out_channels=out_ch,
754
+ dims=dims,
755
+ use_checkpoint=use_checkpoint,
756
+ use_scale_shift_norm=use_scale_shift_norm,
757
+ down=True,
758
+ )
759
+ if resblock_updown
760
+ else Downsample(
761
+ ch, conv_resample, dims=dims, out_channels=out_ch
762
+ )
763
+ )
764
+ )
765
+ ch = out_ch
766
+ input_block_chans.append(ch)
767
+ ds *= 2
768
+ self._feature_size += ch
769
+
770
+ if num_head_channels == -1:
771
+ dim_head = ch // num_heads
772
+ else:
773
+ num_heads = ch // num_head_channels
774
+ dim_head = num_head_channels
775
+ if legacy:
776
+ #num_heads = 1
777
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
778
+ self.middle_block = TimestepEmbedSequential(
779
+ ResBlock(
780
+ ch,
781
+ time_embed_dim,
782
+ dropout,
783
+ dims=dims,
784
+ use_checkpoint=use_checkpoint,
785
+ use_scale_shift_norm=use_scale_shift_norm,
786
+ ),
787
+ AttentionBlock(
788
+ ch,
789
+ use_checkpoint=use_checkpoint,
790
+ num_heads=num_heads,
791
+ num_head_channels=dim_head,
792
+ use_new_attention_order=use_new_attention_order,
793
+ ) if not use_spatial_transformer else SpatialTransformer(
794
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
795
+ ),
796
+ ResBlock(
797
+ ch,
798
+ time_embed_dim,
799
+ dropout,
800
+ dims=dims,
801
+ use_checkpoint=use_checkpoint,
802
+ use_scale_shift_norm=use_scale_shift_norm,
803
+ ),
804
+ )
805
+ self._feature_size += ch
806
+
807
+ self.output_blocks = nn.ModuleList([])
808
+ for level, mult in list(enumerate(channel_mult))[::-1]:
809
+ for i in range(num_res_blocks + 1):
810
+ ich = input_block_chans.pop()
811
+ layers = [
812
+ ResBlock(
813
+ ch + ich,
814
+ time_embed_dim,
815
+ dropout,
816
+ out_channels=model_channels * mult,
817
+ dims=dims,
818
+ use_checkpoint=use_checkpoint,
819
+ use_scale_shift_norm=use_scale_shift_norm,
820
+ )
821
+ ]
822
+ ch = model_channels * mult
823
+ if ds in attention_resolutions:
824
+ if num_head_channels == -1:
825
+ dim_head = ch // num_heads
826
+ else:
827
+ num_heads = ch // num_head_channels
828
+ dim_head = num_head_channels
829
+ if legacy:
830
+ #num_heads = 1
831
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
832
+ layers.append(
833
+ AttentionBlock(
834
+ ch,
835
+ use_checkpoint=use_checkpoint,
836
+ num_heads=num_heads_upsample,
837
+ num_head_channels=dim_head,
838
+ use_new_attention_order=use_new_attention_order,
839
+ ) if not use_spatial_transformer else SpatialTransformer(
840
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
841
+ )
842
+ )
843
+ if level and i == num_res_blocks:
844
+ out_ch = ch
845
+ layers.append(
846
+ ResBlock(
847
+ ch,
848
+ time_embed_dim,
849
+ dropout,
850
+ out_channels=out_ch,
851
+ dims=dims,
852
+ use_checkpoint=use_checkpoint,
853
+ use_scale_shift_norm=use_scale_shift_norm,
854
+ up=True,
855
+ )
856
+ if resblock_updown
857
+ else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
858
+ )
859
+ ds //= 2
860
+ self.output_blocks.append(TimestepEmbedSequential(*layers))
861
+ self._feature_size += ch
862
+
863
+ self.out = nn.Sequential(
864
+ normalization(ch),
865
+ nn.SiLU(),
866
+ zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
867
+ )
868
+ if self.predict_codebook_ids:
869
+ self.id_predictor = nn.Sequential(
870
+ normalization(ch),
871
+ conv_nd(dims, model_channels, n_embed, 1),
872
+ #nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
873
+ )
874
+
875
+ def convert_to_fp16(self):
876
+ """
877
+ Convert the torso of the model to float16.
878
+ """
879
+ self.input_blocks.apply(convert_module_to_f16)
880
+ self.middle_block.apply(convert_module_to_f16)
881
+ self.output_blocks.apply(convert_module_to_f16)
882
+
883
+ def convert_to_fp32(self):
884
+ """
885
+ Convert the torso of the model to float32.
886
+ """
887
+ self.input_blocks.apply(convert_module_to_f32)
888
+ self.middle_block.apply(convert_module_to_f32)
889
+ self.output_blocks.apply(convert_module_to_f32)
890
+
891
+ def forward(self, x, timesteps=None, context=None, y=None, get_contexts=False, passed_contexts=None, corresp=None, **kwargs):
892
+ """
893
+ Apply the model to an input batch.
894
+ :param x: an [N x C x ...] Tensor of inputs.
895
+ :param timesteps: a 1-D batch of timesteps.
896
+ :param context: conditioning plugged in via crossattn
897
+ :param y: an [N] Tensor of labels, if class-conditional.
898
+ :return: an [N x C x ...] Tensor of outputs.
899
+ """
900
+ assert (y is not None) == (
901
+ self.num_classes is not None
902
+ ), "must specify y if and only if the model is class-conditional"
903
+ hs = []
904
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
905
+ emb = self.time_embed(t_emb)
906
+
907
+ ds = [8, 16, 32, 64]
908
+
909
+ # cur_step = len(glob.glob('/dev/shm/dumpster/steps/*'))
910
+ # os.makedirs(f'/dev/shm/dumpster/steps/{cur_step:04d}', exist_ok=False)
911
+
912
+ og_mask = x[:, -1:] # Bx1x64x64
913
+ batch_size = og_mask.shape[0]
914
+ masks = dict()
915
+
916
+ for d in ds:
917
+ resized_mask = torchvision.transforms.functional.resize(og_mask, size=(d, d))
918
+
919
+ mask = resized_mask.reshape(batch_size, -1)
920
+ masks[d] = mask
921
+
922
+ # if self.use_contexts:
923
+ # passed_contexts = self.contexts.popleft()
924
+
925
+ all_kqvs = []
926
+
927
+ if self.num_classes is not None:
928
+ assert y.shape == (x.shape[0],)
929
+ emb = emb + self.label_emb(y)
930
+
931
+ h = x.type(self.dtype)
932
+
933
+ if self.add_conv_in_front_of_unet:
934
+ for module in self.add_resbolck:
935
+ h, kqv = module(h, emb, context, passed_contexts, len(all_kqvs), masks=masks, corresp=corresp)
936
+ all_kqvs.append(kqv)
937
+
938
+ for module in self.input_blocks:
939
+ h, kqv = module(h, emb, context, passed_contexts, len(all_kqvs), masks=masks, corresp=corresp)
940
+ hs.append(h)
941
+ all_kqvs.append(kqv)
942
+
943
+ h, kqv = self.middle_block(h, emb, context, passed_contexts, len(all_kqvs), masks=masks, corresp=corresp)
944
+ all_kqvs.append(kqv)
945
+ for module in self.output_blocks:
946
+ h = th.cat([h, hs.pop()], dim=1)
947
+ h, kqv = module(h, emb, context, passed_contexts, len(all_kqvs), masks=masks, corresp=corresp)
948
+ all_kqvs.append(kqv)
949
+
950
+ h = h.type(x.dtype)
951
+
952
+ # print(all_kqvs)
953
+ # for i in range(len(all_kqvs)):
954
+ # print('len of contexts at ', i, 'is ', len(all_kqvs[i]))
955
+ # for j in range(len(all_kqvs[i])):
956
+ # print('len of contexts at ', i, j, 'is ', len(all_kqvs[i][j]))
957
+ # for k in range(len(all_kqvs[i][j])):
958
+ # print(all_kqvs[i][j][k])
959
+
960
+
961
+
962
+ if self.predict_codebook_ids:
963
+ out = self.id_predictor(h)
964
+ else:
965
+ out = self.out(h)
966
+
967
+ if self.save_contexts:
968
+ self.contexts.append(all_kqvs)
969
+
970
+ if get_contexts:
971
+ return out, all_kqvs
972
+ else:
973
+ return out
974
+
975
+ def get_contexts(self, x, timesteps=None, context=None, y=None, **kwargs):
976
+ """
977
+ Same as forward, but used when saving self-attention contexts.
978
+ """
979
+ assert (y is not None) == (
980
+ self.num_classes is not None
981
+ ), "must specify y if and only if the model is class-conditional"
982
+ hs = []
983
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
984
+ emb = self.time_embed(t_emb)
985
+
986
+ if self.num_classes is not None:
987
+ assert y.shape == (x.shape[0],)
988
+ emb = emb + self.label_emb(y)
989
+
990
+ h = x.type(self.dtype)
991
+
992
+ if self.add_conv_in_front_of_unet:
993
+ for module in self.add_resbolck:
994
+ h = module(h, emb, context)
995
+
996
+ for module in self.input_blocks:
997
+ h = module(h, emb, context)
998
+ hs.append(h)
999
+ h = self.middle_block(h, emb, context)
1000
+ for module in self.output_blocks:
1001
+ h = th.cat([h, hs.pop()], dim=1)
1002
+ h = module(h, emb, context)
1003
+ h = h.type(x.dtype)
1004
+ if self.predict_codebook_ids:
1005
+ return self.id_predictor(h)
1006
+ else:
1007
+ return self.out(h)
1008
+
1009
+ class EncoderUNetModel(nn.Module):
1010
+ """
1011
+ The half UNet model with attention and timestep embedding.
1012
+ For usage, see UNet.
1013
+ """
1014
+
1015
+ def __init__(
1016
+ self,
1017
+ image_size,
1018
+ in_channels,
1019
+ model_channels,
1020
+ out_channels,
1021
+ num_res_blocks,
1022
+ attention_resolutions,
1023
+ dropout=0,
1024
+ channel_mult=(1, 2, 4, 8),
1025
+ conv_resample=True,
1026
+ dims=2,
1027
+ use_checkpoint=False,
1028
+ use_fp16=False,
1029
+ num_heads=1,
1030
+ num_head_channels=-1,
1031
+ num_heads_upsample=-1,
1032
+ use_scale_shift_norm=False,
1033
+ resblock_updown=False,
1034
+ use_new_attention_order=False,
1035
+ pool="adaptive",
1036
+ *args,
1037
+ **kwargs
1038
+ ):
1039
+ super().__init__()
1040
+
1041
+ if num_heads_upsample == -1:
1042
+ num_heads_upsample = num_heads
1043
+
1044
+ self.in_channels = in_channels
1045
+ self.model_channels = model_channels
1046
+ self.out_channels = out_channels
1047
+ self.num_res_blocks = num_res_blocks
1048
+ self.attention_resolutions = attention_resolutions
1049
+ self.dropout = dropout
1050
+ self.channel_mult = channel_mult
1051
+ self.conv_resample = conv_resample
1052
+ self.use_checkpoint = use_checkpoint
1053
+ self.dtype = th.float16 if use_fp16 else th.float32
1054
+ self.num_heads = num_heads
1055
+ self.num_head_channels = num_head_channels
1056
+ self.num_heads_upsample = num_heads_upsample
1057
+
1058
+ time_embed_dim = model_channels * 4
1059
+ self.time_embed = nn.Sequential(
1060
+ linear(model_channels, time_embed_dim),
1061
+ nn.SiLU(),
1062
+ linear(time_embed_dim, time_embed_dim),
1063
+ )
1064
+
1065
+ self.input_blocks = nn.ModuleList(
1066
+ [
1067
+ TimestepEmbedSequential(
1068
+ conv_nd(dims, in_channels, model_channels, 3, padding=1)
1069
+ )
1070
+ ]
1071
+ )
1072
+ self._feature_size = model_channels
1073
+ input_block_chans = [model_channels]
1074
+ ch = model_channels
1075
+ ds = 1
1076
+ for level, mult in enumerate(channel_mult):
1077
+ for _ in range(num_res_blocks):
1078
+ layers = [
1079
+ ResBlock(
1080
+ ch,
1081
+ time_embed_dim,
1082
+ dropout,
1083
+ out_channels=mult * model_channels,
1084
+ dims=dims,
1085
+ use_checkpoint=use_checkpoint,
1086
+ use_scale_shift_norm=use_scale_shift_norm,
1087
+ )
1088
+ ]
1089
+ ch = mult * model_channels
1090
+ if ds in attention_resolutions:
1091
+ layers.append(
1092
+ AttentionBlock(
1093
+ ch,
1094
+ use_checkpoint=use_checkpoint,
1095
+ num_heads=num_heads,
1096
+ num_head_channels=num_head_channels,
1097
+ use_new_attention_order=use_new_attention_order,
1098
+ )
1099
+ )
1100
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
1101
+ self._feature_size += ch
1102
+ input_block_chans.append(ch)
1103
+ if level != len(channel_mult) - 1:
1104
+ out_ch = ch
1105
+ self.input_blocks.append(
1106
+ TimestepEmbedSequential(
1107
+ ResBlock(
1108
+ ch,
1109
+ time_embed_dim,
1110
+ dropout,
1111
+ out_channels=out_ch,
1112
+ dims=dims,
1113
+ use_checkpoint=use_checkpoint,
1114
+ use_scale_shift_norm=use_scale_shift_norm,
1115
+ down=True,
1116
+ )
1117
+ if resblock_updown
1118
+ else Downsample(
1119
+ ch, conv_resample, dims=dims, out_channels=out_ch
1120
+ )
1121
+ )
1122
+ )
1123
+ ch = out_ch
1124
+ input_block_chans.append(ch)
1125
+ ds *= 2
1126
+ self._feature_size += ch
1127
+
1128
+ self.middle_block = TimestepEmbedSequential(
1129
+ ResBlock(
1130
+ ch,
1131
+ time_embed_dim,
1132
+ dropout,
1133
+ dims=dims,
1134
+ use_checkpoint=use_checkpoint,
1135
+ use_scale_shift_norm=use_scale_shift_norm,
1136
+ ),
1137
+ AttentionBlock(
1138
+ ch,
1139
+ use_checkpoint=use_checkpoint,
1140
+ num_heads=num_heads,
1141
+ num_head_channels=num_head_channels,
1142
+ use_new_attention_order=use_new_attention_order,
1143
+ ),
1144
+ ResBlock(
1145
+ ch,
1146
+ time_embed_dim,
1147
+ dropout,
1148
+ dims=dims,
1149
+ use_checkpoint=use_checkpoint,
1150
+ use_scale_shift_norm=use_scale_shift_norm,
1151
+ ),
1152
+ )
1153
+ self._feature_size += ch
1154
+ self.pool = pool
1155
+ if pool == "adaptive":
1156
+ self.out = nn.Sequential(
1157
+ normalization(ch),
1158
+ nn.SiLU(),
1159
+ nn.AdaptiveAvgPool2d((1, 1)),
1160
+ zero_module(conv_nd(dims, ch, out_channels, 1)),
1161
+ nn.Flatten(),
1162
+ )
1163
+ elif pool == "attention":
1164
+ assert num_head_channels != -1
1165
+ self.out = nn.Sequential(
1166
+ normalization(ch),
1167
+ nn.SiLU(),
1168
+ AttentionPool2d(
1169
+ (image_size // ds), ch, num_head_channels, out_channels
1170
+ ),
1171
+ )
1172
+ elif pool == "spatial":
1173
+ self.out = nn.Sequential(
1174
+ nn.Linear(self._feature_size, 2048),
1175
+ nn.ReLU(),
1176
+ nn.Linear(2048, self.out_channels),
1177
+ )
1178
+ elif pool == "spatial_v2":
1179
+ self.out = nn.Sequential(
1180
+ nn.Linear(self._feature_size, 2048),
1181
+ normalization(2048),
1182
+ nn.SiLU(),
1183
+ nn.Linear(2048, self.out_channels),
1184
+ )
1185
+ else:
1186
+ raise NotImplementedError(f"Unexpected {pool} pooling")
1187
+
1188
+ def convert_to_fp16(self):
1189
+ """
1190
+ Convert the torso of the model to float16.
1191
+ """
1192
+ self.input_blocks.apply(convert_module_to_f16)
1193
+ self.middle_block.apply(convert_module_to_f16)
1194
+
1195
+ def convert_to_fp32(self):
1196
+ """
1197
+ Convert the torso of the model to float32.
1198
+ """
1199
+ self.input_blocks.apply(convert_module_to_f32)
1200
+ self.middle_block.apply(convert_module_to_f32)
1201
+
1202
+ def forward(self, x, timesteps):
1203
+ """
1204
+ Apply the model to an input batch.
1205
+ :param x: an [N x C x ...] Tensor of inputs.
1206
+ :param timesteps: a 1-D batch of timesteps.
1207
+ :return: an [N x K] Tensor of outputs.
1208
+ """
1209
+ emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
1210
+
1211
+ results = []
1212
+ h = x.type(self.dtype)
1213
+ for module in self.input_blocks:
1214
+ h = module(h, emb)
1215
+ if self.pool.startswith("spatial"):
1216
+ results.append(h.type(x.dtype).mean(dim=(2, 3)))
1217
+ h = self.middle_block(h, emb)
1218
+ if self.pool.startswith("spatial"):
1219
+ results.append(h.type(x.dtype).mean(dim=(2, 3)))
1220
+ h = th.cat(results, axis=-1)
1221
+ return self.out(h)
1222
+ else:
1223
+ h = h.type(x.dtype)
1224
+ return self.out(h)
1225
+
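QKVAttentionLegacy above packs queries, keys, and values along the channel dimension and splits the 1/sqrt(d) scaling between q and k, which is more stable in fp16 than dividing the logits afterwards. A self-contained sketch of the same einsum math on random tensors (shapes are illustrative):

import math
import torch as th

def qkv_attention_legacy(qkv, n_heads):
    # qkv: [N x (H * 3 * C) x T], packed the way AttentionBlock's 1x1 conv emits it
    bs, width, length = qkv.shape
    ch = width // (3 * n_heads)
    q, k, v = qkv.reshape(bs * n_heads, ch * 3, length).split(ch, dim=1)
    scale = 1 / math.sqrt(math.sqrt(ch))              # split 1/sqrt(ch) across q and k
    weight = th.einsum("bct,bcs->bts", q * scale, k * scale)
    weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
    return th.einsum("bts,bcs->bct", weight, v).reshape(bs, -1, length)

x = th.randn(2, 4 * 3 * 16, 64)                       # N=2, H=4 heads, C=16 channels, T=64 positions
print(qkv_attention_legacy(x, n_heads=4).shape)       # torch.Size([2, 64, 64]) == [N, H*C, T]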
ldm/modules/diffusionmodules/util.py ADDED
@@ -0,0 +1,285 @@
1
+ # This code is built from the Stable Diffusion repository: https://github.com/CompVis/stable-diffusion, and
2
+ # Paint-by-Example repo https://github.com/Fantasy-Studio/Paint-by-Example
3
+ # Copyright (c) 2022 Robin Rombach and Patrick Esser and contributors.
4
+ # CreativeML Open RAIL-M
5
+ #
6
+ # ==========================================================================================
7
+ #
8
+ # Adobe’s modifications are Copyright 2024 Adobe Research. All rights reserved.
9
+ # Adobe’s modifications are licensed under the Adobe Research License. To view a copy of the license, visit
10
+ # LICENSE.md.
11
+ #
12
+ # ==========================================================================================
13
+
14
+ # adopted from
15
+ # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
16
+ # and
17
+ # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
18
+ # and
19
+ # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
20
+ #
21
+ # thanks!
22
+
23
+
24
+ import os
25
+ import math
26
+ import torch
27
+ import torch.nn as nn
28
+ import numpy as np
29
+ from einops import repeat
30
+
31
+ from ldm.util import instantiate_from_config
32
+
33
+
34
+ def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
35
+ if schedule == "linear":
36
+ betas = (
37
+ torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
38
+ )
39
+
40
+ elif schedule == "cosine":
41
+ timesteps = (
42
+ torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
43
+ )
44
+ alphas = timesteps / (1 + cosine_s) * np.pi / 2
45
+ alphas = torch.cos(alphas).pow(2)
46
+ alphas = alphas / alphas[0]
47
+ betas = 1 - alphas[1:] / alphas[:-1]
48
+ betas = np.clip(betas, a_min=0, a_max=0.999)
49
+
50
+ elif schedule == "sqrt_linear":
51
+ betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
52
+ elif schedule == "sqrt":
53
+ betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
54
+ else:
55
+ raise ValueError(f"schedule '{schedule}' unknown.")
56
+ return betas.numpy()
57
+
58
+
59
+ def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True, steps=None):
60
+ if ddim_discr_method == 'uniform':
61
+ c = num_ddpm_timesteps // num_ddim_timesteps
62
+ ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
63
+ elif ddim_discr_method == 'quad':
64
+ ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int)
65
+ elif ddim_discr_method == 'manual':
66
+ assert steps is not None
67
+ ddim_timesteps = np.asarray(steps)
68
+ else:
69
+ raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')
70
+
71
+ # assert ddim_timesteps.shape[0] == num_ddim_timesteps
72
+ # add one to get the final alpha values right (the ones from first scale to data during sampling)
73
+ steps_out = ddim_timesteps + 1
74
+ if verbose:
75
+ print(f'Selected timesteps for ddim sampler: {steps_out}')
76
+ return steps_out
77
+
78
+
79
+ def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
80
+ # select alphas for computing the variance schedule
81
+ alphas = alphacums[ddim_timesteps]
82
+ alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
83
+
84
+ # according to the formula provided in https://arxiv.org/abs/2010.02502
85
+ sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
86
+ if verbose:
87
+ print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
88
+ print(f'For the chosen value of eta, which is {eta}, '
89
+ f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
90
+ return sigmas, alphas, alphas_prev
91
+
92
+
93
+ def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
94
+ """
95
+ Create a beta schedule that discretizes the given alpha_t_bar function,
96
+ which defines the cumulative product of (1-beta) over time from t = [0,1].
97
+ :param num_diffusion_timesteps: the number of betas to produce.
98
+ :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
99
+ produces the cumulative product of (1-beta) up to that
100
+ part of the diffusion process.
101
+ :param max_beta: the maximum beta to use; use values lower than 1 to
102
+ prevent singularities.
103
+ """
104
+ betas = []
105
+ for i in range(num_diffusion_timesteps):
106
+ t1 = i / num_diffusion_timesteps
107
+ t2 = (i + 1) / num_diffusion_timesteps
108
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
109
+ return np.array(betas)
110
+
111
+
112
+ def extract_into_tensor(a, t, x_shape):
113
+ b, *_ = t.shape
114
+ out = a.gather(-1, t)
115
+ return out.reshape(b, *((1,) * (len(x_shape) - 1)))
116
+
117
+
118
+ def checkpoint(func, inputs, params, flag):
119
+ """
120
+ Evaluate a function without caching intermediate activations, allowing for
121
+ reduced memory at the expense of extra compute in the backward pass.
122
+ :param func: the function to evaluate.
123
+ :param inputs: the argument sequence to pass to `func`.
124
+ :param params: a sequence of parameters `func` depends on but does not
125
+ explicitly take as arguments.
126
+ :param flag: if False, disable gradient checkpointing.
127
+ """
128
+ if flag:
129
+ args = tuple(inputs) + tuple(params)
130
+ return CheckpointFunction.apply(func, len(inputs), *args)
131
+ else:
132
+ return func(*inputs)
133
+
134
+
135
+ class CheckpointFunction(torch.autograd.Function):
136
+ @staticmethod
137
+ # @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32) # added this for map
138
+ def forward(ctx, run_function, length, *args):
139
+ ctx.run_function = run_function
140
+ ctx.input_tensors = list(args[:length])
141
+ ctx.input_params = list(args[length:])
142
+
143
+ with torch.no_grad():
144
+ output_tensors = ctx.run_function(*ctx.input_tensors)
145
+ return output_tensors
146
+
147
+ @staticmethod
148
+ # @torch.cuda.amp.custom_bwd # added this for map
149
+ def backward(ctx, *output_grads):
150
+ ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
151
+ with torch.enable_grad():
152
+ # Fixes a bug where the first op in run_function modifies the
153
+ # Tensor storage in place, which is not allowed for detach()'d
154
+ # Tensors.
155
+ shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
156
+ output_tensors = ctx.run_function(*shallow_copies)
157
+ input_grads = torch.autograd.grad(
158
+ output_tensors,
159
+ ctx.input_tensors + ctx.input_params,
160
+ output_grads,
161
+ allow_unused=True,
162
+ )
163
+ del ctx.input_tensors
164
+ del ctx.input_params
165
+ del output_tensors
166
+ return (None, None) + input_grads
167
+
168
+
169
+ def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
170
+ """
171
+ Create sinusoidal timestep embeddings.
172
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
173
+ These may be fractional.
174
+ :param dim: the dimension of the output.
175
+ :param max_period: controls the minimum frequency of the embeddings.
176
+ :return: an [N x dim] Tensor of positional embeddings.
177
+ """
178
+ if not repeat_only:
179
+ half = dim // 2
180
+ freqs = torch.exp(
181
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
182
+ ).to(device=timesteps.device)
183
+ args = timesteps[:, None].float() * freqs[None]
184
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
185
+ if dim % 2:
186
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
187
+ else:
188
+ embedding = repeat(timesteps, 'b -> b d', d=dim)
189
+ return embedding
190
+
191
+
192
+ def zero_module(module):
193
+ """
194
+ Zero out the parameters of a module and return it.
195
+ """
196
+ for p in module.parameters():
197
+ p.detach().zero_()
198
+ return module
199
+
200
+
201
+ def scale_module(module, scale):
202
+ """
203
+ Scale the parameters of a module and return it.
204
+ """
205
+ for p in module.parameters():
206
+ p.detach().mul_(scale)
207
+ return module
208
+
209
+
210
+ def mean_flat(tensor):
211
+ """
212
+ Take the mean over all non-batch dimensions.
213
+ """
214
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
215
+
216
+
217
+ def normalization(channels):
218
+ """
219
+ Make a standard normalization layer.
220
+ :param channels: number of input channels.
221
+ :return: an nn.Module for normalization.
222
+ """
223
+ return GroupNorm32(32, channels)
224
+
225
+
226
+ # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
227
+ class SiLU(nn.Module):
228
+ def forward(self, x):
229
+ return x * torch.sigmoid(x)
230
+
231
+
232
+ class GroupNorm32(nn.GroupNorm):
233
+ def forward(self, x):
234
+ return super().forward(x.float()).type(x.dtype)
235
+
236
+ def conv_nd(dims, *args, **kwargs):
237
+ """
238
+ Create a 1D, 2D, or 3D convolution module.
239
+ """
240
+ if dims == 1:
241
+ return nn.Conv1d(*args, **kwargs)
242
+ elif dims == 2:
243
+ return nn.Conv2d(*args, **kwargs)
244
+ elif dims == 3:
245
+ return nn.Conv3d(*args, **kwargs)
246
+ raise ValueError(f"unsupported dimensions: {dims}")
247
+
248
+
249
+ def linear(*args, **kwargs):
250
+ """
251
+ Create a linear module.
252
+ """
253
+ return nn.Linear(*args, **kwargs)
254
+
255
+
256
+ def avg_pool_nd(dims, *args, **kwargs):
257
+ """
258
+ Create a 1D, 2D, or 3D average pooling module.
259
+ """
260
+ if dims == 1:
261
+ return nn.AvgPool1d(*args, **kwargs)
262
+ elif dims == 2:
263
+ return nn.AvgPool2d(*args, **kwargs)
264
+ elif dims == 3:
265
+ return nn.AvgPool3d(*args, **kwargs)
266
+ raise ValueError(f"unsupported dimensions: {dims}")
267
+
268
+
269
+ class HybridConditioner(nn.Module):
270
+
271
+ def __init__(self, c_concat_config, c_crossattn_config):
272
+ super().__init__()
273
+ self.concat_conditioner = instantiate_from_config(c_concat_config)
274
+ self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)
275
+
276
+ def forward(self, c_concat, c_crossattn):
277
+ c_concat = self.concat_conditioner(c_concat)
278
+ c_crossattn = self.crossattn_conditioner(c_crossattn)
279
+ return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]}
280
+
281
+
282
+ def noise_like(shape, device, repeat=False):
283
+ repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
284
+ noise = lambda: torch.randn(shape, device=device)
285
+ return repeat_noise() if repeat else noise()
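The helpers above cover the pieces a sampler needs: a beta schedule, a DDIM timestep subset, and sinusoidal timestep embeddings. A short sketch of how they fit together, assuming this repository is importable (values are illustrative):

import numpy as np
import torch

from ldm.modules.diffusionmodules.util import (
    make_beta_schedule, make_ddim_timesteps, timestep_embedding,
)

betas = make_beta_schedule("linear", n_timestep=1000)       # numpy array of 1000 betas
alphas_cumprod = np.cumprod(1.0 - betas)                    # cumulative product used by the samplers
print(betas.shape, float(alphas_cumprod[-1]))

ddim_steps = make_ddim_timesteps("uniform", num_ddim_timesteps=50,
                                 num_ddpm_timesteps=1000, verbose=False)
print(ddim_steps[:5])                                       # e.g. [ 1 21 41 61 81]

t = torch.tensor([0, 250, 999])
emb = timestep_embedding(t, dim=320)                        # sinusoidal [N x dim] embedding
print(emb.shape)                                            # torch.Size([3, 320])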
ldm/modules/distributions/__init__.py ADDED
File without changes
ldm/modules/distributions/distributions.py ADDED
@@ -0,0 +1,105 @@
1
+ # This code is built from the Stable Diffusion repository: https://github.com/CompVis/stable-diffusion, and
2
+ # Paint-by-Example repo https://github.com/Fantasy-Studio/Paint-by-Example
3
+ # Copyright (c) 2022 Robin Rombach and Patrick Esser and contributors.
4
+ # CreativeML Open RAIL-M
5
+ #
6
+ # ==========================================================================================
7
+ #
8
+ # Adobe’s modifications are Copyright 2024 Adobe Research. All rights reserved.
9
+ # Adobe’s modifications are licensed under the Adobe Research License. To view a copy of the license, visit
10
+ # LICENSE.md.
11
+ #
12
+ # ==========================================================================================
13
+
14
+ import torch
15
+ import numpy as np
16
+
17
+
18
+ class AbstractDistribution:
19
+ def sample(self):
20
+ raise NotImplementedError()
21
+
22
+ def mode(self):
23
+ raise NotImplementedError()
24
+
25
+
26
+ class DiracDistribution(AbstractDistribution):
27
+ def __init__(self, value):
28
+ self.value = value
29
+
30
+ def sample(self):
31
+ return self.value
32
+
33
+ def mode(self):
34
+ return self.value
35
+
36
+
37
+ class DiagonalGaussianDistribution(object):
38
+ def __init__(self, parameters, deterministic=False):
39
+ self.parameters = parameters
40
+ self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
41
+ self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
42
+ self.deterministic = deterministic
43
+ self.std = torch.exp(0.5 * self.logvar)
44
+ self.var = torch.exp(self.logvar)
45
+ if self.deterministic:
46
+ self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
47
+
48
+ def sample(self):
49
+ x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
50
+ return x
51
+
52
+ def kl(self, other=None):
53
+ if self.deterministic:
54
+ return torch.Tensor([0.])
55
+ else:
56
+ if other is None:
57
+ return 0.5 * torch.sum(torch.pow(self.mean, 2)
58
+ + self.var - 1.0 - self.logvar,
59
+ dim=[1, 2, 3])
60
+ else:
61
+ return 0.5 * torch.sum(
62
+ torch.pow(self.mean - other.mean, 2) / other.var
63
+ + self.var / other.var - 1.0 - self.logvar + other.logvar,
64
+ dim=[1, 2, 3])
65
+
66
+ def nll(self, sample, dims=[1,2,3]):
67
+ if self.deterministic:
68
+ return torch.Tensor([0.])
69
+ logtwopi = np.log(2.0 * np.pi)
70
+ return 0.5 * torch.sum(
71
+ logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
72
+ dim=dims)
73
+
74
+ def mode(self):
75
+ return self.mean
76
+
77
+
78
+ def normal_kl(mean1, logvar1, mean2, logvar2):
79
+ """
80
+ source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
81
+ Compute the KL divergence between two gaussians.
82
+ Shapes are automatically broadcasted, so batches can be compared to
83
+ scalars, among other use cases.
84
+ """
85
+ tensor = None
86
+ for obj in (mean1, logvar1, mean2, logvar2):
87
+ if isinstance(obj, torch.Tensor):
88
+ tensor = obj
89
+ break
90
+ assert tensor is not None, "at least one argument must be a Tensor"
91
+
92
+ # Force variances to be Tensors. Broadcasting helps convert scalars to
93
+ # Tensors, but it does not work for torch.exp().
94
+ logvar1, logvar2 = [
95
+ x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
96
+ for x in (logvar1, logvar2)
97
+ ]
98
+
99
+ return 0.5 * (
100
+ -1.0
101
+ + logvar2
102
+ - logvar1
103
+ + torch.exp(logvar1 - logvar2)
104
+ + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
105
+ )
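DiagonalGaussianDistribution above is the posterior object a KL-regularized first-stage autoencoder typically returns: the parameter tensor stacks mean and log-variance along the channel dimension, and kl()/nll() reduce over everything but the batch. A small sketch, assuming the repository is importable (shapes are illustrative):

import torch

from ldm.modules.distributions.distributions import DiagonalGaussianDistribution

params = torch.randn(2, 8, 32, 32)         # 4 mean channels + 4 log-variance channels
dist = DiagonalGaussianDistribution(params)
z = dist.sample()                          # reparameterized draw, shape [2, 4, 32, 32]
print(z.shape, dist.kl().shape)            # KL against a standard normal, one value per batch element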
ldm/modules/ema.py ADDED
@@ -0,0 +1,89 @@
1
+ # This code is built from the Stable Diffusion repository: https://github.com/CompVis/stable-diffusion, and
2
+ # Paint-by-Example repo https://github.com/Fantasy-Studio/Paint-by-Example
3
+ # Copyright (c) 2022 Robin Rombach and Patrick Esser and contributors.
4
+ # CreativeML Open RAIL-M
5
+ #
6
+ # ==========================================================================================
7
+ #
8
+ # Adobe’s modifications are Copyright 2024 Adobe Research. All rights reserved.
9
+ # Adobe’s modifications are licensed under the Adobe Research License. To view a copy of the license, visit
10
+ # LICENSE.md.
11
+ #
12
+ # ==========================================================================================
13
+
14
+ import torch
15
+ from torch import nn
16
+
17
+
18
+ class LitEma(nn.Module):
19
+ def __init__(self, model, decay=0.9999, use_num_upates=True):
20
+ super().__init__()
21
+ if decay < 0.0 or decay > 1.0:
22
+ raise ValueError('Decay must be between 0 and 1')
23
+
24
+ self.m_name2s_name = {}
25
+ self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
26
+ self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates
27
+ else torch.tensor(-1,dtype=torch.int))
28
+
29
+ for name, p in model.named_parameters():
30
+ if p.requires_grad:
31
+ #remove as '.'-character is not allowed in buffers
32
+ s_name = name.replace('.','')
33
+ self.m_name2s_name.update({name:s_name})
34
+ self.register_buffer(s_name,p.clone().detach().data)
35
+
36
+ self.collected_params = []
37
+
38
+ def forward(self,model):
39
+ decay = self.decay
40
+
41
+ if self.num_updates >= 0:
42
+ self.num_updates += 1
43
+ decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates))
44
+
45
+ one_minus_decay = 1.0 - decay
46
+
47
+ with torch.no_grad():
48
+ m_param = dict(model.named_parameters())
49
+ shadow_params = dict(self.named_buffers())
50
+
51
+ for key in m_param:
52
+ if m_param[key].requires_grad:
53
+ sname = self.m_name2s_name[key]
54
+ shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
55
+ shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
56
+ else:
57
+ assert not key in self.m_name2s_name
58
+
59
+ def copy_to(self, model):
60
+ m_param = dict(model.named_parameters())
61
+ shadow_params = dict(self.named_buffers())
62
+ for key in m_param:
63
+ if m_param[key].requires_grad:
64
+ m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
65
+ else:
66
+ assert not key in self.m_name2s_name
67
+
68
+ def store(self, parameters):
69
+ """
70
+ Save the current parameters for restoring later.
71
+ Args:
72
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
73
+ temporarily stored.
74
+ """
75
+ self.collected_params = [param.clone() for param in parameters]
76
+
77
+ def restore(self, parameters):
78
+ """
79
+ Restore the parameters stored with the `store` method.
80
+ Useful to validate the model with EMA parameters without affecting the
81
+ original optimization process. Store the parameters before the
82
+ `copy_to` method. After validation (or model saving), use this to
83
+ restore the former parameters.
84
+ Args:
85
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
86
+ updated with the stored parameters.
87
+ """
88
+ for c_param, param in zip(self.collected_params, parameters):
89
+ param.data.copy_(c_param.data)
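LitEma above keeps a shadow buffer for every trainable parameter and decays it toward the live weights on each call, while store(), copy_to(), and restore() swap the EMA weights in for evaluation and back out afterwards. A minimal standalone usage sketch, assuming the repository is importable (the toy model and step count are illustrative):

import torch
from torch import nn

from ldm.modules.ema import LitEma

model = nn.Linear(4, 4)
ema = LitEma(model, decay=0.999)

for _ in range(3):                          # pretend training steps
    model.weight.data.add_(0.01)
    ema(model)                              # update the shadow parameters

ema.store(model.parameters())               # stash the live weights
ema.copy_to(model)                          # evaluate with the EMA weights
ema.restore(model.parameters())             # put the live weights back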
ldm/modules/encoders/__init__.py ADDED
File without changes