diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..efa407c35ff028586b7ef5456c537971fefa5cea --- /dev/null +++ b/.gitignore @@ -0,0 +1,162 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/latest/usage/project/#working-with-version-control +.pdm.toml +.pdm-python +.pdm-build/ + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ \ No newline at end of file diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md new file mode 100644 index 0000000000000000000000000000000000000000..cee55da041bf1c1f31278ef42bcffb19c32b3d86 --- /dev/null +++ b/CODE_OF_CONDUCT.md @@ -0,0 +1,79 @@ +# Adobe Code of Conduct + +## Our Pledge + +We as members, contributors, and leaders pledge to make participation in our project and community a harassment-free experience for everyone, regardless of age, body size, visible or invisible disability, ethnicity, gender identity and expression, level of experience, education, socio-economic status, nationality, personal appearance, race, caste, color, religion, or sexual identity and orientation. + +We pledge to act and interact in ways that contribute to an open, welcoming, diverse, inclusive, and healthy community. + +## Our Standards + +Examples of behavior that contribute to a positive environment for our project and community include: + +* Demonstrating empathy and kindness toward other people +* Being respectful of differing opinions, viewpoints, and experiences +* Giving and gracefully accepting constructive feedback +* Accepting responsibility and apologizing to those affected by our mistakes, and learning from the experience +* Focusing on what is best, not just for us as individuals but for the overall community + +Examples of unacceptable behavior include: + +* The use of sexualized language or imagery, and sexual attention or advances of any kind +* Trolling, insulting or derogatory comments, and personal or political attacks +* Public or private harassment +* Publishing others’ private information, such as a physical or email address, without their explicit permission +* Other conduct which could reasonably be considered inappropriate in a professional setting + +## Our Responsibilities + +Project maintainers are responsible for clarifying and enforcing our standards of acceptable behavior and will take appropriate and fair corrective action in response to any instances of unacceptable behavior. + +Project maintainers have the right and responsibility to remove, edit, or reject comments, commits, code, wiki edits, issues, and other contributions that are not aligned to this Code of Conduct, or to ban temporarily or permanently any contributor for behaviors that they deem inappropriate, threatening, offensive, or harmful. + +## Scope + +This Code of Conduct applies when an individual is representing the project or its community both within project spaces and in public spaces. Examples of representing a project or community include using an official e-mail address, posting via an official social media account, or acting as an appointed representative at an online or offline event. Representation of a project may be further defined and clarified by project maintainers. + +## Enforcement + +Instances of abusive, harassing, or otherwise unacceptable behavior may be reported by first contacting the project team. Oversight of Adobe projects is handled by the Adobe Open Source Office, which has final say in any violations and enforcement of this Code of Conduct and can be reached at Grp-opensourceoffice@adobe.com. All complaints will be reviewed and investigated promptly and fairly. + +The project team must respect the privacy and security of the reporter of any incident. + +Project maintainers who do not follow or enforce the Code of Conduct may face temporary or permanent repercussions as determined by other members of the project's leadership or the Adobe Open Source Office. + +## Enforcement Guidelines + +Project maintainers will follow these Community Impact Guidelines in determining the consequences for any action they deem to be in violation of this Code of Conduct: + +**1. Correction** + +Community Impact: Use of inappropriate language or other behavior deemed unprofessional or unwelcome in the community. + +Consequence: A private, written warning from project maintainers describing the violation and why the behavior was unacceptable. A public apology may be requested from the violator before any further involvement in the project by violator. + +**2. Warning** + +Community Impact: A relatively minor violation through a single incident or series of actions. + +Consequence: A written warning from project maintainers that includes stated consequences for continued unacceptable behavior. Violator must refrain from interacting with the people involved for a specified period of time as determined by the project maintainers, including, but not limited to, unsolicited interaction with those enforcing the Code of Conduct through channels such as community spaces and social media. Continued violations may lead to a temporary or permanent ban. + +**3. Temporary Ban** + +Community Impact: A more serious violation of community standards, including sustained unacceptable behavior. + +Consequence: A temporary ban from any interaction or public communication with the community for a specified period of time. No public or private interaction with the people involved, including unsolicited interaction with those enforcing the Code of Conduct, is allowed during this period. Failure to comply with the temporary ban may lead to a permanent ban. + +**4. Permanent Ban** + +Community Impact: Demonstrating a consistent pattern of violation of community standards or an egregious violation of community standards, including, but not limited to, sustained inappropriate behavior, harassment of an individual, or aggression toward or disparagement of classes of individuals. + +Consequence: A permanent ban from any interaction with the community. + +## Attribution + +This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 2.1, +available at [https://contributor-covenant.org/version/2/1][version] + +[homepage]: https://contributor-covenant.org +[version]: https://contributor-covenant.org/version/2/1 diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000000000000000000000000000000000000..a09069b3f8d9e99be1d4361e4353f1a6173cef62 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,47 @@ +# Contributing + +Thanks for choosing to contribute! + +The following are a set of guidelines to follow when contributing to this project. + +## Code Of Conduct + +This project adheres to the Adobe [code of conduct](./CODE_OF_CONDUCT.md). By participating, +you are expected to uphold this code. Please report unacceptable behavior to +[Grp-opensourceoffice@adobe.com](mailto:Grp-opensourceoffice@adobe.com). + +## Have A Question? + +Start by filing an issue. The existing committers on this project work to reach +consensus around project direction and issue solutions within issue threads +(when appropriate). + +## Contributor License Agreement + +All third-party contributions to this project must be accompanied by a signed contributor +license agreement. This gives Adobe permission to redistribute your contributions +as part of the project. [Sign our CLA](https://opensource.adobe.com/cla.html). You +only need to submit an Adobe CLA one time, so if you have submitted one previously, +you are good to go! + +## Code Reviews + +All submissions should come in the form of pull requests and need to be reviewed +by project committers. Read [GitHub's pull request documentation](https://help.github.com/articles/about-pull-requests/) +for more information on sending pull requests. + +Lastly, please follow the [pull request template](PULL_REQUEST_TEMPLATE.md) when +submitting a pull request! + +## From Contributor To Committer + +We love contributions from our community! If you'd like to go a step beyond contributor +and become a committer with full write access and a say in the project, you must +be invited to the project. The existing committers employ an internal nomination +process that must reach lazy consensus (silence is approval) before invitations +are issued. If you feel you are qualified and want to get more deeply involved, +feel free to reach out to existing committers to have a conversation about that. + +## Security Issues + +Security issues shouldn't be reported on this issue tracker. Instead, [file an issue to our security experts](https://helpx.adobe.com/security/alertus.html). diff --git a/LICENSE.md b/LICENSE.md new file mode 100644 index 0000000000000000000000000000000000000000..2ac9b47f9dd6893aa9a920838ff746df83578683 --- /dev/null +++ b/LICENSE.md @@ -0,0 +1,13 @@ +Copyright 2024, Adobe Inc. and its licensors. All rights reserved. + +ADOBE RESEARCH LICENSE + +Adobe grants any person or entity ("you" or "your") obtaining a copy of these certain research materials that are owned by Adobe ("Licensed Materials") a nonexclusive, worldwide, royalty-free, revocable, fully paid license to (A) reproduce, use, modify, and publicly display the Licensed Materials; and (B) redistribute the Licensed Materials, and modifications or derivative works thereof, provided the following conditions are met: + +The rights granted herein may be exercised for noncommercial research purposes (i.e., academic research and teaching) only. Noncommercial research purposes do not include commercial licensing or distribution, development of commercial products, or any other activity that results in commercial gain. +You may add your own copyright statement to your modifications and/or provide additional or different license terms for use, reproduction, modification, public display, and redistribution of your modifications and derivative works, provided that such license terms limit the use, reproduction, modification, public display, and redistribution of such modifications and derivative works to noncommercial research purposes only. +You acknowledge that Adobe and its licensors own all right, title, and interest in the Licensed Materials. +All copies of the Licensed Materials must include the above copyright notice, this list of conditions, and the disclaimer below. +Failure to meet any of the above conditions will automatically terminate the rights granted herein. + +THE LICENSED MATERIALS ARE PROVIDED "AS IS" WITHOUT WARRANTY OF ANY KIND. THE ENTIRE RISK AS TO THE USE, RESULTS, AND PERFORMANCE OF THE LICENSED MATERIALS IS ASSUMED BY YOU. ADOBE DISCLAIMS ALL WARRANTIES, EXPRESS, IMPLIED OR STATUTORY, WITH REGARD TO YOUR USE OF THE LICENSED MATERIALS, INCLUDING, BUT NOT LIMITED TO, NONINFRINGEMENT OF THIRD-PARTY RIGHTS. IN NO EVENT WILL ADOBE BE LIABLE FOR ANY ACTUAL, INCIDENTAL, SPECIAL OR CONSEQUENTIAL DAMAGES, INCLUDING WITHOUT LIMITATION, LOSS OF PROFITS OR OTHER COMMERCIAL LOSS, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THE LICENSED MATERIALS, EVEN IF ADOBE HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. \ No newline at end of file diff --git a/README.md b/README.md index 4ed47f2bb2c82804279b0c66fdb143a2090c5fa8..d5c45d5078e5d0ec136be94fa4167fca9f04eefd 100644 --- a/README.md +++ b/README.md @@ -1,13 +1,70 @@ ---- -title: Oilkkkkbb -emoji: 🖼 -colorFrom: purple -colorTo: red -sdk: gradio -sdk_version: 4.26.0 -app_file: app.py -pinned: false -license: apache-2.0 ---- - -Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference +# MagicFixup +This is the repo for the paper [Magic Fixup: Streamlining Photo Editing by Watching Dynamic Videos](https://magic-fixup.github.io) +## Installation +We provide an `environment.yaml` file to assist with installation. All what you need for setup is to run the following script +``` +conda env create -f environment.yaml -v +``` +and this will create a conda environment that you can activate using `conda activate MagicFixup` + +## Inference + +#### Downloading Magic Fixup checkpoint +You can download the model trained on the Moments in Time dataset using this [Google Drive Link](https://drive.google.com/file/d/1zOcDcJzCijbGr9I9adC0Cv6yzW60U9TQ/view?usp=share_link) + + +### Inference script +The inference scripts is `run_magicfu.py`. It takes the path of the reference image (the original image), and the edited image. Note that it assumes that the alpha channel is set appropriately in the edited image PNG, as we use the alpha channel to set the disocclusion mask. You can run the inference script with + +``` +python run_magicfu.py --checkpoint --reference --edit +``` + +### gradio demo +We have a gradio demo that allows you to test out your inputs with a friendly user interface. Simply start the demo with +``` +python magicfu_gradio.py --checkpoint +``` + + +## Training your own model +To train your own model, first you need to process a video dataset, train the model using the processed pairs from your videos. In our model, we used the Momnets in Time dataset to train the weights we provided above. + +#### Pretrained SD1.4 diffusion model +We start training from the official SD1.4 model (with the first layer modified to take our 9 channel input). You can either download the official SD1.4 model and modify the first layer using `scripts/modify_checkpoints.py` and place it under `pretrained_models` folder. + +### Data Processing +The data processing code can be found under the `data_processing` folder. You can simply put all the videos in a directory, and pass the directory as the folder name in `data_processing/moments_processing.py`. If your videos are long (~ex more than 5 seconds and contain cut scenes), then you would want to use pyscenedetect to detect cut scenes and split the videos accordingly. +For data processing, you also need to download the checkpoint for SegmentAnything, and install soft-splatting. You can setup softmax-splatting and SAM, by following +``` +cd data_processing +git clone https://github.com/sniklaus/softmax-splatting.git +pip install segment_anything +cd sam_model +wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth +``` +For softmax-splatting to run, you need to install `pip install cupy` (or you might need to use `pip install cupy-cuda11x` or `pip install cupy-cuda12x` depending on your cuda version, and load the appropriate cuda module) + +Then run `python moments_processing.py` to start processing frames from the provided examples video (included under `data_processing/example_videos`). For the version provided, we used the [Moments in Time Dataset](http://moments.csail.mit.edu) + +### Running the training script +Make sure that you have downloaded the pretrained SD1.4 model above. Once you download the training dataset and pretrained model, you can simply start training the model with +``` +./train.sh +``` +The training code is in `main.py`, and relies mainly on pytorch_lightning in training. + +Note that you need to modify the train and val paths in the chosen config file to the location where you have the processed data. + +Note: we use Deepspeed to lower the memory requirements, so the saved model weights will be sharded. The script to reconstruct the model weights will be created in the checkpoint directory with name `zero_to_fp32.py`. One bug in the file is that it wouldn't recognize files with deepspeed1 (which is the one we use), so simply find and replace the string `== 2` with the string `<= 2` and it will work. + +### Saving the Full Model Weights +To save storage requirements, we only checkpoint the learnable parameters in training (i.e. the frozen autoencoder params are not saved). To create a checkpoint that contains all the parameters, you can combine the frozen pretrained weights and learned parameters by running + +``` +python combine_model_params.py --pretrained_sd --learned_params --save_path +``` + + +##### Acknowledgement +The diffusion code was built on top of the codebase adapted in [PaintByExample](https://github.com/Fantasy-Studio/Paint-by-Example) \ No newline at end of file diff --git a/app.py b/app.py index 7fd71f3866eed8b72f1a4c5185fb07ff5de7cb8d..4f7cd13e12670a16b14efb5d1ad7f3ced5e7c7a2 100644 --- a/app.py +++ b/app.py @@ -1,146 +1,53 @@ +# Copyright 2024 Adobe. All rights reserved. + +from run_magicfu import MagicFixup +import os +import pathlib +import torchvision +from torch import autocast +from PIL import Image import gradio as gr import numpy as np -import random -from diffusers import DiffusionPipeline -import torch - -device = "cuda" if torch.cuda.is_available() else "cpu" - -if torch.cuda.is_available(): - torch.cuda.max_memory_allocated(device=device) - pipe = DiffusionPipeline.from_pretrained("stabilityai/sdxl-turbo", torch_dtype=torch.float16, variant="fp16", use_safetensors=True) - pipe.enable_xformers_memory_efficient_attention() - pipe = pipe.to(device) -else: - pipe = DiffusionPipeline.from_pretrained("stabilityai/sdxl-turbo", use_safetensors=True) - pipe = pipe.to(device) - -MAX_SEED = np.iinfo(np.int32).max -MAX_IMAGE_SIZE = 1024 - -def infer(prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps): - - if randomize_seed: - seed = random.randint(0, MAX_SEED) +import argparse + + +def sample(original_image, coarse_edit): + to_tensor = torchvision.transforms.ToTensor() + with autocast("cuda"): + w, h = coarse_edit.size + ref_image_t = to_tensor(original_image.resize((512,512))).half().cuda() + coarse_edit_t = to_tensor(coarse_edit.resize((512,512))).half().cuda() + # get mask from coarse + coarse_edit_mask_t = to_tensor(coarse_edit.resize((512,512))).half().cuda() + mask_t = (coarse_edit_mask_t[-1][None, None,...]).half() # do center crop + coarse_edit_t_rgb = coarse_edit_t[:-1] - generator = torch.Generator().manual_seed(seed) - - image = pipe( - prompt = prompt, - negative_prompt = negative_prompt, - guidance_scale = guidance_scale, - num_inference_steps = num_inference_steps, - width = width, - height = height, - generator = generator - ).images[0] - - return image - -examples = [ - "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k", - "An astronaut riding a green horse", - "A delicious ceviche cheesecake slice", -] - -css=""" -#col-container { - margin: 0 auto; - max-width: 520px; -} -""" - -if torch.cuda.is_available(): - power_device = "GPU" -else: - power_device = "CPU" - -with gr.Blocks(css=css) as demo: + out_rgb = magic_fixup.edit_image(ref_image_t, coarse_edit_t_rgb, mask_t, start_step=1.0, steps=50) + output = out_rgb.squeeze().cpu().detach().moveaxis(0, -1).float().numpy() + output = (output * 255.0).astype(np.uint8) + output_pil = Image.fromarray(output) + output_pil = output_pil.resize((w, h)) + return output_pil + +def file_exists(path): + """ Check if a file exists and is not a directory. """ + if not os.path.isfile(path): + raise argparse.ArgumentTypeError(f"{path} is not a valid file.") + return path + +def parse_arguments(): + """ Parses command-line arguments. """ + parser = argparse.ArgumentParser(description="Process images based on provided paths.") + parser.add_argument("--checkpoint", type=file_exists, required=True, help="Path to the MagicFixup checkpoint file.") + + return parser.parse_args() + +demo = gr.Interface(fn=sample, inputs=[gr.Image(type="pil", image_mode='RGB'), gr.Image(type="pil", image_mode='RGBA')], outputs=gr.Image(), + examples='examples') - with gr.Column(elem_id="col-container"): - gr.Markdown(f""" - # Text-to-Image Gradio Template - Currently running on {power_device}. - """) - - with gr.Row(): - - prompt = gr.Text( - label="Prompt", - show_label=False, - max_lines=1, - placeholder="Enter your prompt", - container=False, - ) - - run_button = gr.Button("Run", scale=0) - - result = gr.Image(label="Result", show_label=False) - - with gr.Accordion("Advanced Settings", open=False): - - negative_prompt = gr.Text( - label="Negative prompt", - max_lines=1, - placeholder="Enter a negative prompt", - visible=False, - ) - - seed = gr.Slider( - label="Seed", - minimum=0, - maximum=MAX_SEED, - step=1, - value=0, - ) - - randomize_seed = gr.Checkbox(label="Randomize seed", value=True) - - with gr.Row(): - - width = gr.Slider( - label="Width", - minimum=256, - maximum=MAX_IMAGE_SIZE, - step=32, - value=512, - ) - - height = gr.Slider( - label="Height", - minimum=256, - maximum=MAX_IMAGE_SIZE, - step=32, - value=512, - ) - - with gr.Row(): - - guidance_scale = gr.Slider( - label="Guidance scale", - minimum=0.0, - maximum=10.0, - step=0.1, - value=0.0, - ) - - num_inference_steps = gr.Slider( - label="Number of inference steps", - minimum=1, - maximum=12, - step=1, - value=2, - ) - - gr.Examples( - examples = examples, - inputs = [prompt] - ) - - run_button.click( - fn = infer, - inputs = [prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps], - outputs = [result] - ) +if __name__ == "__main__": + args = parse_arguments() -demo.queue().launch() \ No newline at end of file + # create magic fixup model + magic_fixup = MagicFixup(model_path=args.checkpoint) + demo.launch(share=True) diff --git a/configs/collage_composite_train.yaml b/configs/collage_composite_train.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0687bd5f08200a1a7fabfb0c64cdb5b0a8195156 --- /dev/null +++ b/configs/collage_composite_train.yaml @@ -0,0 +1,114 @@ +# Copyright 2024 Adobe. All rights reserved. +model: + base_learning_rate: 1.0e-05 + target: ldm.models.diffusion.ddpm.LatentDiffusion + params: + linear_start: 0.00085 + linear_end: 0.0120 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: "inpaint" + cond_stage_key: "image" + image_size: 64 + channels: 4 + cond_stage_trainable: true # Note: different from the one we trained before + conditioning_key: "rewarp" + monitor: val/loss_simple_ema + u_cond_percent: 0.2 + scale_factor: 0.18215 + use_ema: False + context_embedding_dim: 768 # TODO embedding # 1024 clip, DINO: 'small': 384,'big': 768,'large': 1024,'huge': 1536 + + + scheduler_config: # 10000 warmup steps + target: ldm.lr_scheduler.LambdaLinearScheduler + params: + warm_up_steps: [ 10000 ] + cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases + f_start: [ 1.e-6 ] + f_max: [ 1. ] + f_min: [ 1. ] + + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + image_size: 32 # unused + in_channels: 9 + out_channels: 4 + model_channels: 320 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_heads: 8 + use_spatial_transformer: True + transformer_depth: 1 + context_dim: 768 + use_checkpoint: True + legacy: False + add_conv_in_front_of_unet: False + + first_stage_config: + target: ldm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + cond_stage_config: + target: ldm.modules.encoders.modules.DINOEmbedder # TODO embedding + params: + dino_version: "big" # [small, big, large, huge] + +data: + target: main.DataModuleFromConfig + params: + batch_size: 2 + num_workers: 8 + use_worker_init_fn: False + wrap: False + train: + target: ldm.data.collage_dataset.CollageDataset + params: + split_files: "" + image_size: 512 + embedding_type: 'dino' # TODO embedding + warping_type: 'collage' + validation: + target: ldm.data.collage_dataset.CollageDataset + params: + split_files: "" + image_size: 512 + embedding_type: 'dino' # TODO embedding + warping_type: 'mix' + test: + target: ldm.data.collage_dataset.CollageDataset + params: + split_files: "" + image_size: 512 + embedding_type: 'dino' # TODO embedding + warping_type: 'mix' + +lightning: + trainer: + max_epochs: 500 + num_nodes: 1 + num_sanity_val_steps: 0 + accelerator: 'gpu' + gpus: "0,1,2,3,4,5,6,7" diff --git a/configs/collage_flow_train.yaml b/configs/collage_flow_train.yaml new file mode 100644 index 0000000000000000000000000000000000000000..aded8a07fad4d5d1f3c0c3a6e5d64be8fa155573 --- /dev/null +++ b/configs/collage_flow_train.yaml @@ -0,0 +1,114 @@ +# Copyright 2024 Adobe. All rights reserved. +model: + base_learning_rate: 1.0e-05 + target: ldm.models.diffusion.ddpm.LatentDiffusion + params: + linear_start: 0.00085 + linear_end: 0.0120 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: "inpaint" + cond_stage_key: "image" + image_size: 64 + channels: 4 + cond_stage_trainable: true # Note: different from the one we trained before + conditioning_key: "rewarp" + monitor: val/loss_simple_ema + u_cond_percent: 0.2 + scale_factor: 0.18215 + use_ema: False + context_embedding_dim: 768 # TODO embedding # 1024 clip, DINO: 'small': 384,'big': 768,'large': 1024,'huge': 1536 + + + scheduler_config: # 10000 warmup steps + target: ldm.lr_scheduler.LambdaLinearScheduler + params: + warm_up_steps: [ 10000 ] + cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases + f_start: [ 1.e-6 ] + f_max: [ 1. ] + f_min: [ 1. ] + + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + image_size: 32 # unused + in_channels: 9 + out_channels: 4 + model_channels: 320 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_heads: 8 + use_spatial_transformer: True + transformer_depth: 1 + context_dim: 768 + use_checkpoint: True + legacy: False + add_conv_in_front_of_unet: False + + first_stage_config: + target: ldm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + cond_stage_config: + target: ldm.modules.encoders.modules.DINOEmbedder # TODO embedding + params: + dino_version: "big" # [small, big, large, huge] + +data: + target: main.DataModuleFromConfig + params: + batch_size: 2 + num_workers: 8 + use_worker_init_fn: False + wrap: False + train: + target: ldm.data.collage_dataset.CollageDataset + params: + split_files: /mnt/localssd/new_train + image_size: 512 + embedding_type: 'dino' # TODO embedding + warping_type: 'flow' + validation: + target: ldm.data.collage_dataset.CollageDataset + params: + split_files: /mnt/localssd/new_val + image_size: 512 + embedding_type: 'dino' # TODO embedding + warping_type: 'mix' + test: + target: ldm.data.collage_dataset.CollageDataset + params: + split_files: /mnt/localssd/new_val + image_size: 512 + embedding_type: 'dino' # TODO embedding + warping_type: 'mix' + +lightning: + trainer: + max_epochs: 500 + num_nodes: 1 + num_sanity_val_steps: 0 + accelerator: 'gpu' + gpus: "0,1,2,3,4,5,6,7" diff --git a/configs/collage_mix_train.yaml b/configs/collage_mix_train.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2b76211309306634700accb0db60ff3c90a92535 --- /dev/null +++ b/configs/collage_mix_train.yaml @@ -0,0 +1,115 @@ +# Copyright 2024 Adobe. All rights reserved. +model: + base_learning_rate: 1.0e-05 + target: ldm.models.diffusion.ddpm.LatentDiffusion + params: + linear_start: 0.00085 + linear_end: 0.0120 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: "inpaint" + cond_stage_key: "image" + image_size: 64 + channels: 4 + cond_stage_trainable: true # Note: different from the one we trained before + conditioning_key: "rewarp" + monitor: val/loss_simple_ema + u_cond_percent: 0.2 + scale_factor: 0.18215 + use_ema: False + context_embedding_dim: 384 # TODO embedding # 1024 clip, DINO: 'small': 384,'big': 768,'large': 1024,'huge': 1536 + dropping_warped_latent_prob: 0.2 + + + scheduler_config: # 10000 warmup steps + target: ldm.lr_scheduler.LambdaLinearScheduler + params: + warm_up_steps: [ 10000 ] + cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases + f_start: [ 1.e-6 ] + f_max: [ 1. ] + f_min: [ 1. ] + + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + image_size: 32 # unused + in_channels: 9 + out_channels: 4 + model_channels: 320 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_heads: 8 + use_spatial_transformer: True + transformer_depth: 1 + context_dim: 768 + use_checkpoint: True + legacy: False + add_conv_in_front_of_unet: False + + first_stage_config: + target: ldm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + cond_stage_config: + target: ldm.modules.encoders.modules.DINOEmbedder # TODO embedding + params: + dino_version: "small" # [small, big, large, huge] + +data: + target: main.DataModuleFromConfig + params: + batch_size: 4 + num_workers: 8 + use_worker_init_fn: False + wrap: False + train: + target: ldm.data.collage_dataset.CollageDataset + params: + split_files: /mnt/localssd/new_train + image_size: 512 + embedding_type: 'dino' # TODO embedding + warping_type: 'mix' + validation: + target: ldm.data.collage_dataset.CollageDataset + params: + split_files: /mnt/localssd/new_val + image_size: 512 + embedding_type: 'dino' # TODO embedding + warping_type: 'mix' + test: + target: ldm.data.collage_dataset.CollageDataset + params: + split_files: /mnt/localssd/new_val + image_size: 512 + embedding_type: 'dino' # TODO embedding + warping_type: 'mix' + +lightning: + trainer: + max_epochs: 500 + num_nodes: 1 + num_sanity_val_steps: 0 + accelerator: 'gpu' + gpus: "0,1,2,3,4,5,6,7" diff --git a/data_processing/example_videos/getty-soccer-ball-jordan-video-id473239807_26.mp4 b/data_processing/example_videos/getty-soccer-ball-jordan-video-id473239807_26.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..5e8ac53690574214c9fc71bfbdde26e82d999b8b Binary files /dev/null and b/data_processing/example_videos/getty-soccer-ball-jordan-video-id473239807_26.mp4 differ diff --git a/data_processing/example_videos/getty-video-of-american-flags-being-sewn-together-at-flagsource-in-batavia-video-id804937470_87.mp4 b/data_processing/example_videos/getty-video-of-american-flags-being-sewn-together-at-flagsource-in-batavia-video-id804937470_87.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..06fe4aad23df2f9ddb5ffad5acf336bae2490d36 Binary files /dev/null and b/data_processing/example_videos/getty-video-of-american-flags-being-sewn-together-at-flagsource-in-batavia-video-id804937470_87.mp4 differ diff --git a/data_processing/example_videos/giphy-fgiT2cbsTxl8k_0.mp4 b/data_processing/example_videos/giphy-fgiT2cbsTxl8k_0.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..1962a5de60929b1b489239bb5583917620b9efcc Binary files /dev/null and b/data_processing/example_videos/giphy-fgiT2cbsTxl8k_0.mp4 differ diff --git a/data_processing/example_videos/giphy-gkvCpHRX9IqkM_3.mp4 b/data_processing/example_videos/giphy-gkvCpHRX9IqkM_3.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..896e0f5bbe968b2f882f501a711ef5b19ba5a7ac Binary files /dev/null and b/data_processing/example_videos/giphy-gkvCpHRX9IqkM_3.mp4 differ diff --git a/data_processing/example_videos/yt--4Fx5XUD-9Y_345.mp4 b/data_processing/example_videos/yt--4Fx5XUD-9Y_345.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..3c1d8ea0ca8e466b7abd0166b271f0b6c9a2ddcf Binary files /dev/null and b/data_processing/example_videos/yt--4Fx5XUD-9Y_345.mp4 differ diff --git a/data_processing/example_videos/yt-mNdvtOO7UqY_15.mp4 b/data_processing/example_videos/yt-mNdvtOO7UqY_15.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..f9d6d28a435f4625a2dea25b4ef9d14a7377137b Binary files /dev/null and b/data_processing/example_videos/yt-mNdvtOO7UqY_15.mp4 differ diff --git a/data_processing/moments_dataset.py b/data_processing/moments_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..7cf988b3a81da05f321fc34cfb7f2d36328c9a00 --- /dev/null +++ b/data_processing/moments_dataset.py @@ -0,0 +1,54 @@ +# Copyright 2024 Adobe. All rights reserved. + +#%% +import glob +import torch +import torchvision +import matplotlib.pyplot as plt +from torch.utils.data import Dataset +import numpy as np + + +# %% +class MomentsDataset(Dataset): + def __init__(self, videos_folder, num_frames, samples_per_video, frame_size=512) -> None: + super().__init__() + + self.videos_paths = glob.glob(f'{videos_folder}/*mp4') + self.resize = torchvision.transforms.Resize(size=frame_size) + self.center_crop = torchvision.transforms.CenterCrop(size=frame_size) + self.num_samples_per_video = samples_per_video + self.num_frames = num_frames + + def __len__(self): + return len(self.videos_paths) * self.num_samples_per_video + + def __getitem__(self, idx): + video_idx = idx // self.num_samples_per_video + video_path = self.videos_paths[video_idx] + + try: + start_idx = np.random.randint(0, 20) + + unsampled_video_frames, audio_frames, info = torchvision.io.read_video(video_path,output_format="TCHW") + sampled_indices = torch.tensor(np.linspace(start_idx, len(unsampled_video_frames)-1, self.num_frames).astype(int)) + sampled_frames = unsampled_video_frames[sampled_indices] + processed_frames = [] + + for frame in sampled_frames: + resized_cropped_frame = self.center_crop(self.resize(frame)) + processed_frames.append(resized_cropped_frame) + frames = torch.stack(processed_frames, dim=0) + frames = frames.float() / 255.0 + except Exception as e: + print('oops', e) + rand_idx = np.random.randint(0, len(self)) + return self.__getitem__(rand_idx) + + out_dict = {'frames': frames, + 'caption': 'none', + 'keywords': 'none'} + + return out_dict + + diff --git a/data_processing/moments_processing.py b/data_processing/moments_processing.py new file mode 100644 index 0000000000000000000000000000000000000000..a7ec7522c29bf3f151e2c694b2966ea24114a11a --- /dev/null +++ b/data_processing/moments_processing.py @@ -0,0 +1,359 @@ +# Copyright 2024 Adobe. All rights reserved. + +#%% +from torchvision.transforms import ToPILImage +import torch +import pandas as pd +import matplotlib.pyplot as plt +import numpy as np +import torchvision +import cv2 +import tqdm +import matplotlib.pyplot as plt +import torchvision.transforms.functional as F +from PIL import Image +from torchvision.utils import save_image +import time +import os +import sys +import pathlib +from torchvision.utils import flow_to_image +from torch.utils.data import DataLoader +from einops import rearrange +# %matplotlib inline +from kornia.filters.median import MedianBlur +median_filter = MedianBlur(kernel_size=(15,15)) +from moments_dataset import MomentsDataset + +try: + from processing_utils import aggregate_frames + import processing_utils +except Exception as e: + print(e) + print('process failed') + exit() + + + + +import pytorch_lightning as pl +import torch +from omegaconf import OmegaConf + +# %% + +def load_image(img_path, resize_size=None,crop_size=None): + + img1_pil = Image.open(img_path) + img1_frames = torchvision.transforms.functional.pil_to_tensor(img1_pil) + + if resize_size: + img1_frames = torchvision.transforms.functional.resize(img1_frames, resize_size) + + if crop_size: + img1_frames = torchvision.transforms.functional.center_crop(img1_frames, crop_size) + + img1_batch = torch.unsqueeze(img1_frames, dim=0) + + return img1_batch + +def get_grid(size): + y = np.repeat(np.arange(size)[None, ...], size) + y = y.reshape(size, size) + x = y.transpose() + out = np.stack([y,x], -1) + return out + +def collage_from_frames(frames_t): + # decide forward or backward + if np.random.randint(0, 2) == 0: + # flip + frames_t = frames_t.flip(0) + + # decide how deep you would go + tgt_idx_guess = np.random.randint(1, min(len(frames_t), 20)) + tgt_idx = 1 + pairwise_flows = [] + flow = None + init_time = time.time() + unsmoothed_agg = None + for cur_idx in range(1, tgt_idx_guess+1): + # cur_idx = i+1 + cur_flow, pairwise_flows = aggregate_frames(frames_t[:cur_idx+1] , pairwise_flows, unsmoothed_agg) # passing pairwise flows for efficiency + unsmoothed_agg = cur_flow.clone() + agg_cur_flow = median_filter(cur_flow) + + flow_norm = torch.norm(agg_cur_flow.squeeze(), dim=0).flatten() + # flow_10 = np.percentile(flow_norm.cpu().numpy(), 10) + flow_90 = np.percentile(flow_norm.cpu().numpy(), 90) + + # flow_10 = np.percentile(flow_norm.cpu().numpy(), 10) + flow_90 = np.percentile(flow_norm.cpu().numpy(), 90) + flow_95 = np.percentile(flow_norm.cpu().numpy(), 95) + + if cur_idx == 5: # if still small flow then drop + if flow_95 < 20.0: + # no motion in the frame. skip + print('flow is tiny :(') + return None + + if cur_idx == tgt_idx_guess-1: # if still small flow then drop + if flow_95 < 50.0: + # no motion in the frame. skip + print('flow is tiny :(') + return None + + if flow is None: # means first iter + if flow_90 < 1.0: + # no motion in the frame. skip + return None + flow = agg_cur_flow + + if flow_90 <= 300: # maybe should increase this part + # update idx + tgt_idx = cur_idx + flow = agg_cur_flow + else: + break + final_time = time.time() + print('time guessing idx', final_time - init_time) + + _, flow_warping_mask = processing_utils.forward_warp(frames_t[0], frames_t[tgt_idx], flow, grid=None, alpha_mask=None) + flow_warping_mask = flow_warping_mask.squeeze().numpy() > 0.5 + + if np.mean(flow_warping_mask) < 0.6: + return + + + src_array = frames_t[0].moveaxis(0, -1).cpu().numpy() * 1.0 + init_time = time.time() + depth = get_depth_from_array(frames_t[0]) + finish_time = time.time() + print('time getting depth', finish_time - init_time) + # flow, pairwise_flows = aggregate_frames(frames_t) + # agg_flow = median_filter(flow) + + src_array_uint = src_array * 255.0 + src_array_uint = src_array_uint.astype(np.uint8) + segments = processing_utils.mask_generator.generate(src_array_uint) + + size = src_array.shape[1] + grid_np = get_grid(size).astype(np.float16) / size # 512 x 512 x 2get + grid_t = torch.tensor(grid_np).moveaxis(-1, 0) # 512 x 512 x 2 + + + collage, canvas_alpha, lost_alpha = collage_warp(src_array, flow.squeeze(), depth, segments, grid_array=grid_np) + lost_alpha_t = torch.tensor(lost_alpha).squeeze().unsqueeze(0) + warping_alpha = (lost_alpha_t < 0.5).float() + + rgb_grid_splatted, actual_warped_mask = processing_utils.forward_warp(frames_t[0], frames_t[tgt_idx], flow, grid=grid_t, alpha_mask=warping_alpha) + + + # basic blending now + # print('rgb grid splatted', rgb_grid_splatted.shape) + warped_src = (rgb_grid_splatted * actual_warped_mask).moveaxis(0, -1).cpu().numpy() + canvas_alpha_mask = canvas_alpha == 0.0 + collage_mask = canvas_alpha.squeeze() + actual_warped_mask.squeeze().cpu().numpy() + collage_mask = collage_mask > 0.5 + + composite_grid = warped_src * canvas_alpha_mask + collage + rgb_grid_splatted_np = rgb_grid_splatted.moveaxis(0, -1).cpu().numpy() + + return frames_t[0], frames_t[tgt_idx], rgb_grid_splatted_np, composite_grid, flow_warping_mask, collage_mask + +def collage_warp(rgb_array, flow, depth, segments, grid_array): + avg_depths = [] + avg_flows = [] + + # src_array = src_array.moveaxis(-1, 0).cpu().numpy() #np.array(Image.open(src_path).convert('RGB')) / 255.0 + src_array = np.concatenate([rgb_array, grid_array], axis=-1) + canvas = np.zeros_like(src_array) + canvas_alpha = np.zeros_like(canvas[...,-1:]).astype(float) + lost_regions = np.zeros_like(canvas[...,-1:]).astype(float) + z_buffer = np.ones_like(depth)[..., None] * -1.0 + unsqueezed_depth = depth[..., None] + + affine_transforms = [] + + filtered_segments = [] + for segment in segments: + if segment['area'] > 300: + filtered_segments.append(segment) + + for segment in filtered_segments: + seg_mask = segment['segmentation'] + avg_flow = torch.mean(flow[:, seg_mask],dim=1) + avg_flows.append(avg_flow) + # median depth (conversion from disparity) + avg_depth = torch.median(1.0 / (depth[seg_mask] + 1e-6)) + avg_depths.append(avg_depth) + + all_y, all_x = np.nonzero(segment['segmentation']) + rand_indices = np.random.randint(0, len(all_y), size=50) + rand_x = [all_x[i] for i in rand_indices] + rand_y = [all_y[i] for i in rand_indices] + + src_pairs = [(x, y) for x, y in zip(rand_x, rand_y)] + # tgt_pairs = [(x + w, y) for x, y in src_pairs] + tgt_pairs = [] + # print('estimating affine') # TODO this can be faster + for i in range(len(src_pairs)): + x, y = src_pairs[i] + dx, dy = flow[:, y, x] + tgt_pairs.append((x+dx, y+dy)) + + # affine_trans, inliers = cv2.estimateAffine2D(np.array(src_pairs).astype(np.float32), np.array(tgt_pairs).astype(np.float32)) + affine_trans, inliers = cv2.estimateAffinePartial2D(np.array(src_pairs).astype(np.float32), np.array(tgt_pairs).astype(np.float32)) + # print('num inliers', np.sum(inliers)) + # # print('num inliers', np.sum(inliers)) + affine_transforms.append(affine_trans) + + depth_sorted_indices = np.arange(len(avg_depths)) + depth_sorted_indices = sorted(depth_sorted_indices, key=lambda x: avg_depths[x]) + # sorted_masks = [] + # print('warping stuff') + for idx in depth_sorted_indices: + # sorted_masks.append(mask[idx]) + alpha_mask = filtered_segments[idx]['segmentation'][..., None] * (lost_regions < 0.5).astype(float) + src_rgba = np.concatenate([src_array, alpha_mask, unsqueezed_depth], axis=-1) + warp_dst = cv2.warpAffine(src_rgba, affine_transforms[idx], (src_array.shape[1], src_array.shape[0])) + warped_mask = warp_dst[..., -2:-1] # this is warped alpha + warped_depth = warp_dst[..., -1:] + warped_rgb = warp_dst[...,:-2] + + good_z_region = warped_depth > z_buffer + + warped_mask = np.logical_and(warped_mask > 0.5, good_z_region).astype(float) + + kernel = np.ones((3,3), float) + # print('og masked shape', warped_mask.shape) + # warped_mask = cv2.erode(warped_mask,(5,5))[..., None] + # print('eroded masked shape', warped_mask.shape) + canvas_alpha += cv2.erode(warped_mask,kernel)[..., None] + + lost_regions += alpha_mask + canvas = canvas * (1.0 - warped_mask) + warped_mask * warped_rgb # TODO check if need to dialate here + z_buffer = z_buffer * (1.0 - warped_mask) + warped_mask * warped_depth # TODO check if need to dialate here # print('max lost region', np.max(lost_regions)) + return canvas, canvas_alpha, lost_regions + +def get_depth_from_array(img_t): + img_arr = img_t.moveaxis(0, -1).cpu().numpy() * 1.0 + # print(img_arr.shape) + img_arr *= 255.0 + img_arr = img_arr.astype(np.uint8) + input_batch = processing_utils.depth_transform(img_arr).cuda() + + with torch.no_grad(): + prediction = processing_utils.midas(input_batch) + + prediction = torch.nn.functional.interpolate( + prediction.unsqueeze(1), + size=img_arr.shape[:2], + mode="bicubic", + align_corners=False, + ).squeeze() + + output = prediction.cpu() + return output + + +# %% + +def main(): + print('starting main') + video_folder = './example_videos' + save_dir = pathlib.Path('./processed_data') + process_video_folder(video_folder, save_dir) + +def process_video_folder(video_folder, save_dir): + all_counter = 0 + success_counter = 0 + + # save_folder = pathlib.Path('/dev/shm/processed') + # save_dir = save_folder / foldername #pathlib.Path('/sensei-fs/users/halzayer/collage2photo/testing_partitioning_dilate_extreme') + os.makedirs(save_dir, exist_ok=True) + + dataset = MomentsDataset(videos_folder=video_folder, num_frames=20, samples_per_video=5) + batch_size = 4 + dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True) + + with torch.no_grad(): + for i, batch in tqdm.tqdm(enumerate(dataloader), total=len(dataset)//batch_size): + frames_to_visualize = batch["frames"] + bs = frames_to_visualize.shape[0] + + for j in range(bs): + frames = frames_to_visualize[j] + caption = batch["caption"][j] + + collage_init_time = time.time() + out = collage_from_frames(frames) + collage_finish_time = time.time() + print('collage processing time', collage_finish_time - collage_init_time) + all_counter += 1 + if out is not None: + src_image, tgt_image, splatted, collage, flow_mask, collage_mask = out + + splatted_rgb = splatted[...,:3] + splatted_grid = splatted[...,3:].astype(np.float16) + + collage_rgb = collage[...,:3] + collage_grid = collage[...,3:].astype(np.float16) + success_counter += 1 + else: + continue + + id_str = f'{success_counter:08d}' + + src_path = str(save_dir / f'src_{id_str}.png') + tgt_path = str(save_dir / f'tgt_{id_str}.png') + flow_warped_path = str(save_dir / f'flow_warped_{id_str}.png') + composite_path = str(save_dir / f'composite_{id_str}.png') + flow_mask_path = str(save_dir / f'flow_mask_{id_str}.png') + composite_mask_path = str(save_dir / f'composite_mask_{id_str}.png') + + flow_grid_path = str(save_dir / f'flow_warped_grid_{id_str}.npy') + composite_grid_path = str(save_dir / f'composite_grid_{id_str}.npy') + + save_image(src_image, src_path) + save_image(tgt_image, tgt_path) + + collage_pil = Image.fromarray((collage_rgb * 255).astype(np.uint8)) + collage_pil.save(composite_path) + + splatted_pil = Image.fromarray((splatted_rgb * 255).astype(np.uint8)) + splatted_pil.save(flow_warped_path) + + flow_mask_pil = Image.fromarray((flow_mask.astype(float) * 255).astype(np.uint8)) + flow_mask_pil.save(flow_mask_path) + + composite_mask_pil = Image.fromarray((collage_mask.astype(float) * 255).astype(np.uint8)) + composite_mask_pil.save(composite_mask_path) + + splatted_grid_t = torch.tensor(splatted_grid).moveaxis(-1, 0) + splatted_grid_resized = torchvision.transforms.functional.resize(splatted_grid_t, (64,64)) + + collage_grid_t = torch.tensor(collage_grid).moveaxis(-1, 0) + collage_grid_resized = torchvision.transforms.functional.resize(collage_grid_t, (64,64)) + np.save(flow_grid_path, splatted_grid_resized.cpu().numpy()) + np.save(composite_grid_path, collage_grid_resized.cpu().numpy()) + + + del out + del splatted_grid + del collage_grid + del frames + + del frames_to_visualize + + + +#%% + +if __name__ == '__main__': + try: + main() + except Exception as e: + print(e) + print('process failed') + diff --git a/data_processing/processing_utils.py b/data_processing/processing_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c640fee2248afbc538776e6933ac8a0885c62961 --- /dev/null +++ b/data_processing/processing_utils.py @@ -0,0 +1,304 @@ +import torch +import cv2 +import numpy as np +import sys +import torchvision +from PIL import Image +from torchvision.models.optical_flow import Raft_Large_Weights +from torchvision.models.optical_flow import raft_large +from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor +import matplotlib.pyplot as plt +import torchvision.transforms.functional as F +sys.path.append('./softmax-splatting') +import softsplat + + +sam_checkpoint = "./sam_model/sam_vit_h_4b8939.pth" +model_type = "vit_h" + +device = "cuda" + +sam = sam_model_registry[model_type](checkpoint=sam_checkpoint) +sam.to(device=device) +# mask_generator = SamAutomaticMaskGenerator(sam, +# crop_overlap_ratio=0.05, +# box_nms_thresh=0.2, +# points_per_side=32, +# pred_iou_thresh=0.86, +# stability_score_thresh=0.8, + +# min_mask_region_area=100,) +# mask_generator = SamAutomaticMaskGenerator(sam) +mask_generator = SamAutomaticMaskGenerator(sam, + # box_nms_thresh=0.5, + # crop_overlap_ratio=0.75, + # min_mask_region_area=200, + ) + +def get_mask(img_path): + image = cv2.imread(img_path) + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + masks = mask_generator.generate(image) + return masks + +def get_mask_from_array(arr): + return mask_generator.generate(arr) + +# depth model + +import cv2 +import torch +import urllib.request + +import matplotlib.pyplot as plt + +# potentially downgrade this. just need rough depths. benchmark this +model_type = "DPT_Large" # MiDaS v3 - Large (highest accuracy, slowest inference speed) +#model_type = "DPT_Hybrid" # MiDaS v3 - Hybrid (medium accuracy, medium inference speed) +#model_type = "MiDaS_small" # MiDaS v2.1 - Small (lowest accuracy, highest inference speed) + +# midas = torch.hub.load("intel-isl/MiDaS", model_type) +midas = torch.hub.load("/sensei-fs/users/halzayer/collage2photo/model_cache/intel-isl_MiDaS_master", model_type, source='local') + +device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") +midas.to(device) +midas.eval() + +# midas_transforms = torch.hub.load("intel-isl/MiDaS", "transforms") +midas_transforms = torch.hub.load("/sensei-fs/users/halzayer/collage2photo/model_cache/intel-isl_MiDaS_master", "transforms", source='local') + +if model_type == "DPT_Large" or model_type == "DPT_Hybrid": + depth_transform = midas_transforms.dpt_transform +else: + depth_transform = midas_transforms.small_transform + +# img_path = '/sensei-fs/users/halzayer/valid/JPEGImages/45597680/00005.jpg' +def get_depth(img_path): + img = cv2.imread(img_path) + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + + input_batch = depth_transform(img).to(device) + + with torch.no_grad(): + prediction = midas(input_batch) + + prediction = torch.nn.functional.interpolate( + prediction.unsqueeze(1), + size=img.shape[:2], + mode="bicubic", + align_corners=False, + ).squeeze() + + output = prediction.cpu() + return output + +def get_depth_from_array(img): + input_batch = depth_transform(img).to(device) + + with torch.no_grad(): + prediction = midas(input_batch) + + prediction = torch.nn.functional.interpolate( + prediction.unsqueeze(1), + size=img.shape[:2], + mode="bicubic", + align_corners=False, + ).squeeze() + + output = prediction.cpu() + return output + + +def load_image(img_path): + img1_names = [img_path] + + img1_pil = [Image.open(fn) for fn in img1_names] + img1_frames = [torchvision.transforms.functional.pil_to_tensor(fn) for fn in img1_pil] + + img1_batch = torch.stack(img1_frames) + + return img1_batch + +weights = Raft_Large_Weights.DEFAULT +transforms = weights.transforms() + +device = "cuda" if torch.cuda.is_available() else "cpu" + +model = raft_large(weights=Raft_Large_Weights.DEFAULT, progress=False).to(device) +model = model.eval() + +print('created model') + +def preprocess(img1_batch, img2_batch, size=[520,960], transform_batch=True): + img1_batch = F.resize(img1_batch, size=size, antialias=False) + img2_batch = F.resize(img2_batch, size=size, antialias=False) + if transform_batch: + return transforms(img1_batch, img2_batch) + else: + return img1_batch, img2_batch + +def compute_flow(img_path_1, img_path_2): + img1_batch_og, img2_batch_og = load_image(img_path_1), load_image(img_path_2) + B, C, H, W = img1_batch_og.shape + + img1_batch, img2_batch = preprocess(img1_batch_og, img2_batch_og, transform_batch=False) + img1_batch_t, img2_batch_t = transforms(img1_batch, img2_batch) + + # If you can, run this example on a GPU, it will be a lot faster. + with torch.no_grad(): + list_of_flows = model(img1_batch_t.to(device), img2_batch_t.to(device)) + predicted_flows = list_of_flows[-1] + # flows.append(predicted_flows) + + resized_flow = F.resize(predicted_flows, size=(H, W), antialias=False) + + _, _, flow_H, flow_W = predicted_flows.shape + + resized_flow[:,0] *= (W / flow_W) + resized_flow[:,1] *= (H / flow_H) + + return resized_flow.detach().cpu().squeeze() + +def compute_flow_from_tensors(img1_batch_og, img2_batch_og): + if len(img1_batch_og.shape) < 4: + img1_batch_og = img1_batch_og.unsqueeze(0) + if len(img2_batch_og.shape) < 4: + img2_batch_og = img2_batch_og.unsqueeze(0) + + B, C, H, W = img1_batch_og.shape + img1_batch, img2_batch = preprocess(img1_batch_og, img2_batch_og, transform_batch=False) + img1_batch_t, img2_batch_t = transforms(img1_batch, img2_batch) + + # If you can, run this example on a GPU, it will be a lot faster. + with torch.no_grad(): + list_of_flows = model(img1_batch_t.to(device), img2_batch_t.to(device)) + predicted_flows = list_of_flows[-1] + # flows.append(predicted_flows) + + resized_flow = F.resize(predicted_flows, size=(H, W), antialias=False) + + _, _, flow_H, flow_W = predicted_flows.shape + + resized_flow[:,0] *= (W / flow_W) + resized_flow[:,1] *= (H / flow_H) + + return resized_flow.detach().cpu().squeeze() + + + +# import run +backwarp_tenGrid = {} + +def backwarp(tenIn, tenFlow): + if str(tenFlow.shape) not in backwarp_tenGrid: + tenHor = torch.linspace(start=-1.0, end=1.0, steps=tenFlow.shape[3], dtype=tenFlow.dtype, device=tenFlow.device).view(1, 1, 1, -1).repeat(1, 1, tenFlow.shape[2], 1) + tenVer = torch.linspace(start=-1.0, end=1.0, steps=tenFlow.shape[2], dtype=tenFlow.dtype, device=tenFlow.device).view(1, 1, -1, 1).repeat(1, 1, 1, tenFlow.shape[3]) + + backwarp_tenGrid[str(tenFlow.shape)] = torch.cat([tenHor, tenVer], 1).cuda() + # end + + tenFlow = torch.cat([tenFlow[:, 0:1, :, :] / ((tenIn.shape[3] - 1.0) / 2.0), tenFlow[:, 1:2, :, :] / ((tenIn.shape[2] - 1.0) / 2.0)], 1) + + return torch.nn.functional.grid_sample(input=tenIn, grid=(backwarp_tenGrid[str(tenFlow.shape)] + tenFlow).permute(0, 2, 3, 1), mode='bilinear', padding_mode='zeros', align_corners=True) + +torch.backends.cudnn.enabled = True # make sure to use cudnn for computational performance + +########################################################## +def forward_splt(src, tgt, flow, partial=False): + tenTwo = tgt.unsqueeze(0).cuda() #torch.FloatTensor(numpy.ascontiguousarray(cv2.imread(filename='./images/one.png', flags=-1).transpose(2, 0, 1)[None, :, :, :].astype(numpy.float32) * (1.0 / 255.0))).cuda() + tenOne = src.unsqueeze(0).cuda() #torch.FloatTensor(numpy.ascontiguousarray(cv2.imread(filename='./images/two.png', flags=-1).transpose(2, 0, 1)[None, :, :, :].astype(numpy.float32) * (1.0 / 255.0))).cuda() + tenFlow = flow.unsqueeze(0).cuda() #torch.FloatTensor(numpy.ascontiguousarray(run.read_flo('./images/flow.flo').transpose(2, 0, 1)[None, :, :, :])).cuda() + + if not partial: + tenMetric = torch.nn.functional.l1_loss(input=tenOne, target=backwarp(tenIn=tenTwo, tenFlow=tenFlow), reduction='none').mean([1], True) + else: + tenMetric = torch.nn.functional.l1_loss(input=tenOne[:,:3], target=backwarp(tenIn=tenTwo[:,:3], tenFlow=tenFlow[:,:3]), reduction='none').mean([1], True) + # for intTime, fltTime in enumerate(np.linspace(0.0, 1.0, 11).tolist()): + tenSoftmax = softsplat.softsplat(tenIn=tenOne, tenFlow=tenFlow , tenMetric=(-20.0 * tenMetric).clip(-20.0, 20.0), strMode='soft') # -20.0 is a hyperparameter, called 'alpha' in the paper, that could be learned using a torch.Parameter + + return tenSoftmax.cpu() + + +def aggregate_frames(frames, pairwise_flows=None, agg_flow=None): + if pairwise_flows is None: + # store pairwise flows + pairwise_flows = [] + + if agg_flow is None: + start_idx = 0 + else: + start_idx = len(pairwise_flows) + + og_image = frames[start_idx] + prev_frame = og_image + + for i in range(start_idx, len(frames)-1): + tgt_frame = frames[i+1] + + if i < len(pairwise_flows): + flow = pairwise_flows[i] + else: + flow = compute_flow_from_tensors(prev_frame, tgt_frame) + pairwise_flows.append(flow.clone()) + + _, H, W = flow.shape + B=1 + + xx = torch.arange(0, W).view(1,-1).repeat(H,1) + + yy = torch.arange(0, H).view(-1,1).repeat(1,W) + + xx = xx.view(1,1,H,W).repeat(B,1,1,1) + + yy = yy.view(1,1,H,W).repeat(B,1,1,1) + + grid = torch.cat((xx,yy),1).float() + + flow = flow.unsqueeze(0) + if agg_flow is None: + agg_flow = torch.zeros_like(flow) + + vgrid = grid + agg_flow + vgrid[:,0,:,:] = 2.0*vgrid[:,0,:,:].clone() / max(W-1,1) - 1 + + vgrid[:,1,:,:] = 2.0*vgrid[:,1,:,:].clone() / max(H-1,1) - 1 + + flow_out = torch.nn.functional.grid_sample(flow, vgrid.permute(0,2,3,1), 'nearest') + + agg_flow += flow_out + + + # mask = forward_splt(torch.ones_like(og_image), torch.ones_like(og_image), agg_flow.squeeze()).squeeze() + # blur_t = torchvision.transforms.GaussianBlur(kernel_size=(25,25), sigma=5.0) + # warping_mask = (blur_t(mask)[0:1] > 0.8) + # masks.append(warping_mask) + prev_frame = tgt_frame + + return agg_flow, pairwise_flows #og_splatted_img, agg_flow, actual_warped_mask + + +def forward_warp(src_frame, tgt_frame, flow, grid=None, alpha_mask=None): + if alpha_mask is None: + alpha_mask = torch.ones_like(src_frame[:1]) + + if grid is not None: + src_list = [src_frame, grid, alpha_mask] + tgt_list = [tgt_frame, grid, alpha_mask] + else: + src_list = [src_frame, alpha_mask] + tgt_list = [tgt_frame, alpha_mask] + + og_image_padded = torch.concat(src_list, dim=0) + tgt_frame_padded = torch.concat(tgt_list, dim=0) + + og_splatted_img = forward_splt(og_image_padded, tgt_frame_padded, flow.squeeze(), partial=True).squeeze() + # print('og splatted image shape') + # grid_transformed = og_splatted_img[3:-1] + # print('grid transformed shape', grid_transformed) + + # grid *= grid_size + # grid_transformed *= grid_size + actual_warped_mask = og_splatted_img[-1:] + splatted_rgb_grid = og_splatted_img[:-1] + + return splatted_rgb_grid, actual_warped_mask \ No newline at end of file diff --git a/environment.yaml b/environment.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d4ee761cbf8419332c7f4bcca98ab357ff94c804 --- /dev/null +++ b/environment.yaml @@ -0,0 +1,33 @@ +name: MagicFixup +channels: + - pytorch + - defaults +dependencies: + - python=3.8.5 + - pip=20.3 + - cudatoolkit=11.3 + - pytorch=1.11.0 + - torchvision=0.12.0 + - numpy=1.19.2 + - pip: + - albumentations==0.4.3 + - diffusers + - bezier + - gradio + - opencv-python==4.1.2.30 + - pudb==2019.2 + - invisible-watermark + - imageio==2.9.0 + - imageio-ffmpeg==0.4.2 + - pytorch-lightning==2.0.0 + - omegaconf==2.1.1 + - test-tube>=0.7.5 + - streamlit>=0.73.1 + - einops==0.3.0 + - torch-fidelity==0.3.0 + - transformers==4.19.2 + - torchmetrics==0.7.0 + - kornia==0.6 + - -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers + - -e git+https://github.com/openai/CLIP.git@main#egg=clip + - -e . diff --git a/examples/dog_beach__edit__003.png b/examples/dog_beach__edit__003.png new file mode 100644 index 0000000000000000000000000000000000000000..9abc7919299eb5359bc26c0df3d88f9c139db893 Binary files /dev/null and b/examples/dog_beach__edit__003.png differ diff --git a/examples/dog_beach_og.png b/examples/dog_beach_og.png new file mode 100644 index 0000000000000000000000000000000000000000..3e46857c5e1b2afaa015d494d7876efe7b3aae89 Binary files /dev/null and b/examples/dog_beach_og.png differ diff --git a/examples/fox_drinking__edit__01.png b/examples/fox_drinking__edit__01.png new file mode 100644 index 0000000000000000000000000000000000000000..1a739811880cbc5339f4f3e48e52a1d60e2b2c55 Binary files /dev/null and b/examples/fox_drinking__edit__01.png differ diff --git a/examples/fox_drinking__edit__02.png b/examples/fox_drinking__edit__02.png new file mode 100644 index 0000000000000000000000000000000000000000..2c90cc5554999884f4803d0261367e74a1d2d23d Binary files /dev/null and b/examples/fox_drinking__edit__02.png differ diff --git a/examples/fox_drinking_og.png b/examples/fox_drinking_og.png new file mode 100644 index 0000000000000000000000000000000000000000..bb5b3153205afa31cfda23516a285fe6f0a8f3ea Binary files /dev/null and b/examples/fox_drinking_og.png differ diff --git a/examples/kingfisher__edit__001.png b/examples/kingfisher__edit__001.png new file mode 100644 index 0000000000000000000000000000000000000000..044c868d959f79fb42f0768b403e6a1fd53ef641 Binary files /dev/null and b/examples/kingfisher__edit__001.png differ diff --git a/examples/kingfisher_og.png b/examples/kingfisher_og.png new file mode 100644 index 0000000000000000000000000000000000000000..2db8a578fe7e1e49eca28303bf832d3d1736fd10 Binary files /dev/null and b/examples/kingfisher_og.png differ diff --git a/examples/log.csv b/examples/log.csv new file mode 100644 index 0000000000000000000000000000000000000000..658b02fe2403be84319995d2511256c06fdf73c0 --- /dev/null +++ b/examples/log.csv @@ -0,0 +1,6 @@ +fox_drinking_og.png,fox_drinking__edit__01.png +palm_tree_og.png,palm_tree__edit__01.png +kingfisher_og.png,kingfisher__edit__001.png +pipes_og.png,pipes__edit__01.png +dog_beach_og.png,dog_beach__edit__003.png +fox_drinking_og.png,fox_drinking__edit__02.png \ No newline at end of file diff --git a/examples/palm_tree__edit__01.png b/examples/palm_tree__edit__01.png new file mode 100644 index 0000000000000000000000000000000000000000..5c6bf8b6424af1f69451a53cb70f7c22a680f9c2 Binary files /dev/null and b/examples/palm_tree__edit__01.png differ diff --git a/examples/palm_tree_og.png b/examples/palm_tree_og.png new file mode 100644 index 0000000000000000000000000000000000000000..b529715dcbeabd4e5abe389351bb6814855add23 Binary files /dev/null and b/examples/palm_tree_og.png differ diff --git a/examples/pipes__edit__01.png b/examples/pipes__edit__01.png new file mode 100644 index 0000000000000000000000000000000000000000..c5bf8f2124afc210297387970ee6cc67fbba386d Binary files /dev/null and b/examples/pipes__edit__01.png differ diff --git a/examples/pipes_og.png b/examples/pipes_og.png new file mode 100644 index 0000000000000000000000000000000000000000..9a2b42e1c85fcb37cce0e2908b4ab9946632c722 Binary files /dev/null and b/examples/pipes_og.png differ diff --git a/ku.py b/ku.py deleted file mode 100644 index e76116c8b0f0b843673601f8bedf77e6a0434ad0..0000000000000000000000000000000000000000 --- a/ku.py +++ /dev/null @@ -1 +0,0 @@ -jsj diff --git a/ldm/data/__init__.py b/ldm/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ldm/data/collage_dataset.py b/ldm/data/collage_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..5ec42a5cb22666ce573c33e2a942faea2d6997ab --- /dev/null +++ b/ldm/data/collage_dataset.py @@ -0,0 +1,230 @@ +# Copyright 2024 Adobe. All rights reserved. + +import numpy as np +import torch +import matplotlib.pyplot as plt +import torchvision.transforms.functional as F +import glob +import torchvision +from PIL import Image +import time +import os +import tqdm +from torch.utils.data import Dataset +import pathlib +import cv2 +from PIL import Image +import os +import json +import albumentations as A + +def get_tensor(normalize=True, toTensor=True): + transform_list = [] + if toTensor: + transform_list += [torchvision.transforms.ToTensor()] + + if normalize: + # transform_list += [torchvision.transforms.Normalize((0.0, 0.0, 0.0), + # (10.0, 10.0, 10.0))] + transform_list += [torchvision.transforms.Normalize((0.5, 0.5, 0.5), + (0.5, 0.5, 0.5))] + return torchvision.transforms.Compose(transform_list) + +def get_tensor_clip(normalize=True, toTensor=True): + transform_list = [torchvision.transforms.Resize((224,224))] + if toTensor: + transform_list += [torchvision.transforms.ToTensor()] + + if normalize: + transform_list += [torchvision.transforms.Normalize((0.48145466, 0.4578275, 0.40821073), + (0.26862954, 0.26130258, 0.27577711))] + return torchvision.transforms.Compose(transform_list) + +def get_tensor_dino(normalize=True, toTensor=True): + transform_list = [torchvision.transforms.Resize((224,224))] + if toTensor: + transform_list += [torchvision.transforms.ToTensor()] + + if normalize: + transform_list += [lambda x: 255.0 * x[:3], + torchvision.transforms.Normalize( + mean=(123.675, 116.28, 103.53), + std=(58.395, 57.12, 57.375), + )] + return torchvision.transforms.Compose(transform_list) + +def crawl_folders(folder_path): + # glob crawl + all_files = [] + folders = glob.glob(f'{folder_path}/*') + + for folder in folders: + src_paths = glob.glob(f'{folder}/src_*png') + all_files.extend(src_paths) + return all_files + +def get_grid(size): + y = np.repeat(np.arange(size)[None, ...], size) + y = y.reshape(size, size) + x = y.transpose() + out = np.stack([y,x], -1) + return out + + +class CollageDataset(Dataset): + def __init__(self, split_files, image_size, embedding_type, warping_type, blur_warped=False): + self.size = image_size + # depends on the embedding type + if embedding_type == 'clip': + self.get_embedding_vector = get_tensor_clip() + elif embedding_type == 'dino': + self.get_embedding_vector = get_tensor_dino() + self.get_tensor = get_tensor() + self.resize = torchvision.transforms.Resize(size=(image_size, image_size)) + self.to_mask_tensor = get_tensor(normalize=False) + + self.src_paths = crawl_folders(split_files) + print('current split size', len(self.src_paths)) + print('for dir', split_files) + + assert warping_type in ['collage', 'flow', 'mix'] + self.warping_type = warping_type + + self.mask_threshold = 0.85 + + self.blur_t = torchvision.transforms.GaussianBlur(kernel_size=51, sigma=20.0) + self.blur_warped = blur_warped + + # self.save_folder = '/mnt/localssd/collage_out' + # os.makedirs(self.save_folder, exist_ok=True) + self.save_counter = 0 + self.save_subfolder = None + + def __len__(self): + return len(self.src_paths) + + + def __getitem__(self, idx, depth=0): + + if self.warping_type == 'mix': + # randomly sample + warping_type = np.random.choice(['collage', 'flow']) + else: + warping_type = self.warping_type + + src_path = self.src_paths[idx] + tgt_path = src_path.replace('src_', 'tgt_') + + if warping_type == 'collage': + warped_path = src_path.replace('src_', 'composite_') + mask_path = src_path.replace('src_', 'composite_mask_') + corresp_path = src_path.replace('src_', 'composite_grid_') + corresp_path = corresp_path.split('.')[0] + corresp_path += '.npy' + elif warping_type == 'flow': + warped_path = src_path.replace('src_', 'flow_warped_') + mask_path = src_path.replace('src_', 'flow_mask_') + corresp_path = src_path.replace('src_', 'flow_warped_grid_') + corresp_path = corresp_path.split('.')[0] + corresp_path += '.npy' + else: + raise ValueError + + # load reference image, warped image, and target GT image + reference_img = Image.open(src_path).convert('RGB') + gt_img = Image.open(tgt_path).convert('RGB') + warped_img = Image.open(warped_path).convert('RGB') + warping_mask = Image.open(mask_path).convert('RGB') + + # resize all + reference_img = self.resize(reference_img) + gt_img = self.resize(gt_img) + warped_img = self.resize(warped_img) + warping_mask = self.resize(warping_mask) + + + # NO CROPPING PLEASE. ALL INPUTS ARE 512X512 + # Random crop + # i, j, h, w = torchvision.transforms.RandomCrop.get_params( + # reference_img, output_size=(512, 512)) + + # reference_img = torchvision.transforms.functional.crop(reference_img, i, j, h, w) + # gt_img = torchvision.transforms.functional.crop(gt_img, i, j, h, w) + # warped_img = torchvision.transforms.functional.crop(warped_img, i, j, h, w) + # # TODO start using the warping mask + # warping_mask = torchvision.transforms.functional.crop(warping_mask, i, j, h, w) + + grid_transformed = torch.tensor(np.load(corresp_path)) + # grid_transformed = torchvision.transforms.functional.crop(grid_transformed, i, j, h, w) + + + + # reference_t = to_tensor(reference_img) + gt_t = self.get_tensor(gt_img) + warped_t = self.get_tensor(warped_img) + warping_mask_t = self.to_mask_tensor(warping_mask) + clean_reference_t = self.get_tensor(reference_img) + # compute error to generate mask + blur_t = torchvision.transforms.GaussianBlur(kernel_size=(11,11), sigma=5.0) + + reference_clip_img = self.get_embedding_vector(reference_img) + + mask = torch.ones_like(gt_t)[:1] + warping_mask_t = warping_mask_t[:1] + + good_region = torch.mean(warping_mask_t) + # print('good region', good_region) + # print('good region frac', good_region) + if good_region < 0.4 and depth < 3: + # example too hard, sample something else + # print('bad image, resampling..') + rand_idx = np.random.randint(len(self.src_paths)) + return self.__getitem__(rand_idx, depth+1) + + # if mask is too large then ignore + + # #gaussian inpainting now + missing_mask = warping_mask_t[0] < 0.5 + + + reference = (warped_t.clone() + 1) / 2.0 + ref_cv = torch.moveaxis(reference, 0, -1).cpu().numpy() + ref_cv = (ref_cv * 255).astype(np.uint8) + cv_mask = missing_mask.int().squeeze().cpu().numpy().astype(np.uint8) + kernel = np.ones((7,7)) + dilated_mask = cv2.dilate(cv_mask, kernel) + # cv_mask = np.stack([cv_mask]*3, axis=-1) + dst = cv2.inpaint(ref_cv,dilated_mask,5,cv2.INPAINT_NS) + + mask_resized = torchvision.transforms.functional.resize(warping_mask_t, (64,64)) + # print(mask_resized) + size=512 + grid_np = (get_grid(size) / size).astype(np.float16)# 512 x 512 x 2 + grid_t = torch.tensor(grid_np).moveaxis(-1, 0) # 512 x 512 x 2 + grid_resized = torchvision.transforms.functional.resize(grid_t, (64,64)).to(torch.float16) + changed_pixels = torch.logical_or((torch.abs(grid_resized - grid_transformed)[0] > 0.04) , (torch.abs(grid_resized - grid_transformed)[1] > 0.04)) + changed_pixels = changed_pixels.unsqueeze(0) + # changed_pixels = torch.logical_and(changed_pixels, (mask_resized >= 0.3)) + changed_pixels = changed_pixels.float() + + inpainted_warped = (torch.tensor(dst).moveaxis(-1, 0).float() / 255.0) * 2.0 - 1.0 + + if self.blur_warped: + inpainted_warped= self.blur_t(inpainted_warped) + + out = {"GT": gt_t,"inpaint_image": inpainted_warped,"inpaint_mask": warping_mask_t, "ref_imgs": reference_clip_img, "clean_reference": clean_reference_t, 'grid_transformed': grid_transformed, "changed_pixels": changed_pixels} + # out = {"GT": gt_t,"inpaint_image": inpainted_warped * 0.0,"inpaint_mask": torch.ones_like(warping_mask_t), "ref_imgs": reference_clip_img * 0.0, "clean_reference": gt_t, 'grid_transformed': grid_transformed, "changed_pixels": changed_pixels} + # out = {"GT": gt_t,"inpaint_image": inpainted_warped * 0.0,"inpaint_mask": warping_mask_t, "ref_imgs": reference_clip_img * 0.0, "clean_reference": clean_reference_t, 'grid_transformed': grid_transformed, "changed_pixels": changed_pixels} + + # out = {"GT": gt_t,"inpaint_image": warped_t,"inpaint_mask": warping_mask_t, "ref_imgs": reference_clip_img, "clean_reference": clean_reference_t, 'grid_transformed': grid_transformed, 'inpainted': inpainted_warped} + # out_half = {key: out[key].half() for key in out} + # if self.save_counter < 50: + # save_path = f'{self.save_folder}/output_{time.time()}.pt' + # torch.save(out, save_path) + # self.save_counter += 1 + + return out + + + + diff --git a/ldm/lr_scheduler.py b/ldm/lr_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..89f2a0bc2cafe59830f958431005eadb72a9d8a4 --- /dev/null +++ b/ldm/lr_scheduler.py @@ -0,0 +1,111 @@ +# This code is built from the Stable Diffusion repository: https://github.com/CompVis/stable-diffusion, and +# Paint-by-Example repo https://github.com/Fantasy-Studio/Paint-by-Example +# Copyright (c) 2022 Robin Rombach and Patrick Esser and contributors. +# CreativeML Open RAIL-M +# +# ========================================================================================== +# +# Adobe’s modifications are Copyright 2024 Adobe Research. All rights reserved. +# Adobe’s modifications are licensed under the Adobe Research License. To view a copy of the license, visit +# LICENSE.md. +# +# ========================================================================================== + +import numpy as np + + +class LambdaWarmUpCosineScheduler: + """ + note: use with a base_lr of 1.0 + """ + def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0): + self.lr_warm_up_steps = warm_up_steps + self.lr_start = lr_start + self.lr_min = lr_min + self.lr_max = lr_max + self.lr_max_decay_steps = max_decay_steps + self.last_lr = 0. + self.verbosity_interval = verbosity_interval + + def schedule(self, n, **kwargs): + if self.verbosity_interval > 0: + if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") + if n < self.lr_warm_up_steps: + lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start + self.last_lr = lr + return lr + else: + t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps) + t = min(t, 1.0) + lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * ( + 1 + np.cos(t * np.pi)) + self.last_lr = lr + return lr + + def __call__(self, n, **kwargs): + return self.schedule(n,**kwargs) + + +class LambdaWarmUpCosineScheduler2: + """ + supports repeated iterations, configurable via lists + note: use with a base_lr of 1.0. + """ + def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0): + assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths) + self.lr_warm_up_steps = warm_up_steps + self.f_start = f_start + self.f_min = f_min + self.f_max = f_max + self.cycle_lengths = cycle_lengths + self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths)) + self.last_f = 0. + self.verbosity_interval = verbosity_interval + + def find_in_interval(self, n): + interval = 0 + for cl in self.cum_cycles[1:]: + if n <= cl: + return interval + interval += 1 + + def schedule(self, n, **kwargs): + cycle = self.find_in_interval(n) + n = n - self.cum_cycles[cycle] + if self.verbosity_interval > 0: + if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " + f"current cycle {cycle}") + if n < self.lr_warm_up_steps[cycle]: + f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] + self.last_f = f + return f + else: + t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]) + t = min(t, 1.0) + f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * ( + 1 + np.cos(t * np.pi)) + self.last_f = f + return f + + def __call__(self, n, **kwargs): + return self.schedule(n, **kwargs) + + +class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2): + + def schedule(self, n, **kwargs): + cycle = self.find_in_interval(n) + n = n - self.cum_cycles[cycle] + if self.verbosity_interval > 0: + if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " + f"current cycle {cycle}") + + if n < self.lr_warm_up_steps[cycle]: + f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] + self.last_f = f + return f + else: + f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle]) + self.last_f = f + return f + diff --git a/ldm/models/autoencoder.py b/ldm/models/autoencoder.py new file mode 100644 index 0000000000000000000000000000000000000000..fe5623efd900dac843c2cac725efed31671225de --- /dev/null +++ b/ldm/models/autoencoder.py @@ -0,0 +1,456 @@ +# This code is built from the Stable Diffusion repository: https://github.com/CompVis/stable-diffusion, and +# Paint-by-Example repo https://github.com/Fantasy-Studio/Paint-by-Example +# Copyright (c) 2022 Robin Rombach and Patrick Esser and contributors. +# CreativeML Open RAIL-M +# +# ========================================================================================== +# +# Adobe’s modifications are Copyright 2024 Adobe Research. All rights reserved. +# Adobe’s modifications are licensed under the Adobe Research License. To view a copy of the license, visit +# LICENSE.md. +# +# ========================================================================================== + +import torch +import pytorch_lightning as pl +import torch.nn.functional as F +from contextlib import contextmanager + +from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer + +from ldm.modules.diffusionmodules.model import Encoder, Decoder +from ldm.modules.distributions.distributions import DiagonalGaussianDistribution + +from ldm.util import instantiate_from_config + + +class VQModel(pl.LightningModule): + def __init__(self, + ddconfig, + lossconfig, + n_embed, + embed_dim, + ckpt_path=None, + ignore_keys=[], + image_key="image", + colorize_nlabels=None, + monitor=None, + batch_resize_range=None, + scheduler_config=None, + lr_g_factor=1.0, + remap=None, + sane_index_shape=False, # tell vector quantizer to return indices as bhw + use_ema=False + ): + super().__init__() + self.embed_dim = embed_dim + self.n_embed = n_embed + self.image_key = image_key + self.encoder = Encoder(**ddconfig) + self.decoder = Decoder(**ddconfig) + self.loss = instantiate_from_config(lossconfig) + self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25, + remap=remap, + sane_index_shape=sane_index_shape) + self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1) + self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) + if colorize_nlabels is not None: + assert type(colorize_nlabels)==int + self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) + if monitor is not None: + self.monitor = monitor + self.batch_resize_range = batch_resize_range + if self.batch_resize_range is not None: + print(f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.") + + self.use_ema = use_ema + if self.use_ema: + self.model_ema = LitEma(self) + print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") + + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) + self.scheduler_config = scheduler_config + self.lr_g_factor = lr_g_factor + + @contextmanager + def ema_scope(self, context=None): + if self.use_ema: + self.model_ema.store(self.parameters()) + self.model_ema.copy_to(self) + if context is not None: + print(f"{context}: Switched to EMA weights") + try: + yield None + finally: + if self.use_ema: + self.model_ema.restore(self.parameters()) + if context is not None: + print(f"{context}: Restored training weights") + + def init_from_ckpt(self, path, ignore_keys=list()): + sd = torch.load(path, map_location="cpu")["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + missing, unexpected = self.load_state_dict(sd, strict=False) + print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") + if len(missing) > 0: + print(f"Missing Keys: {missing}") + print(f"Unexpected Keys: {unexpected}") + + def on_train_batch_end(self, *args, **kwargs): + if self.use_ema: + self.model_ema(self) + + def encode(self, x): + h = self.encoder(x) + h = self.quant_conv(h) + quant, emb_loss, info = self.quantize(h) + return quant, emb_loss, info + + def encode_to_prequant(self, x): + h = self.encoder(x) + h = self.quant_conv(h) + return h + + def decode(self, quant): + quant = self.post_quant_conv(quant) + dec = self.decoder(quant) + return dec + + def decode_code(self, code_b): + quant_b = self.quantize.embed_code(code_b) + dec = self.decode(quant_b) + return dec + + def forward(self, input, return_pred_indices=False): + quant, diff, (_,_,ind) = self.encode(input) + dec = self.decode(quant) + if return_pred_indices: + return dec, diff, ind + return dec, diff + + def get_input(self, batch, k): + x = batch[k] + if len(x.shape) == 3: + x = x[..., None] + x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float() + if self.batch_resize_range is not None: + lower_size = self.batch_resize_range[0] + upper_size = self.batch_resize_range[1] + if self.global_step <= 4: + # do the first few batches with max size to avoid later oom + new_resize = upper_size + else: + new_resize = np.random.choice(np.arange(lower_size, upper_size+16, 16)) + if new_resize != x.shape[2]: + x = F.interpolate(x, size=new_resize, mode="bicubic") + x = x.detach() + return x + + def training_step(self, batch, batch_idx, optimizer_idx): + # https://github.com/pytorch/pytorch/issues/37142 + # try not to fool the heuristics + x = self.get_input(batch, self.image_key) + xrec, qloss, ind = self(x, return_pred_indices=True) + + if optimizer_idx == 0: + # autoencode + aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step, + last_layer=self.get_last_layer(), split="train", + predicted_indices=ind) + + self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True) + return aeloss + + if optimizer_idx == 1: + # discriminator + discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step, + last_layer=self.get_last_layer(), split="train") + self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True) + return discloss + + def validation_step(self, batch, batch_idx): + log_dict = self._validation_step(batch, batch_idx) + with self.ema_scope(): + log_dict_ema = self._validation_step(batch, batch_idx, suffix="_ema") + return log_dict + + def _validation_step(self, batch, batch_idx, suffix=""): + x = self.get_input(batch, self.image_key) + xrec, qloss, ind = self(x, return_pred_indices=True) + aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0, + self.global_step, + last_layer=self.get_last_layer(), + split="val"+suffix, + predicted_indices=ind + ) + + discloss, log_dict_disc = self.loss(qloss, x, xrec, 1, + self.global_step, + last_layer=self.get_last_layer(), + split="val"+suffix, + predicted_indices=ind + ) + rec_loss = log_dict_ae[f"val{suffix}/rec_loss"] + self.log(f"val{suffix}/rec_loss", rec_loss, + prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True) + self.log(f"val{suffix}/aeloss", aeloss, + prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True) + if version.parse(pl.__version__) >= version.parse('1.4.0'): + del log_dict_ae[f"val{suffix}/rec_loss"] + self.log_dict(log_dict_ae) + self.log_dict(log_dict_disc) + return self.log_dict + + def configure_optimizers(self): + lr_d = self.learning_rate + lr_g = self.lr_g_factor*self.learning_rate + print("lr_d", lr_d) + print("lr_g", lr_g) + opt_ae = torch.optim.Adam(list(self.encoder.parameters())+ + list(self.decoder.parameters())+ + list(self.quantize.parameters())+ + list(self.quant_conv.parameters())+ + list(self.post_quant_conv.parameters()), + lr=lr_g, betas=(0.5, 0.9)) + opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), + lr=lr_d, betas=(0.5, 0.9)) + + if self.scheduler_config is not None: + scheduler = instantiate_from_config(self.scheduler_config) + + print("Setting up LambdaLR scheduler...") + scheduler = [ + { + 'scheduler': LambdaLR(opt_ae, lr_lambda=scheduler.schedule), + 'interval': 'step', + 'frequency': 1 + }, + { + 'scheduler': LambdaLR(opt_disc, lr_lambda=scheduler.schedule), + 'interval': 'step', + 'frequency': 1 + }, + ] + return [opt_ae, opt_disc], scheduler + return [opt_ae, opt_disc], [] + + def get_last_layer(self): + return self.decoder.conv_out.weight + + def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs): + log = dict() + x = self.get_input(batch, self.image_key) + x = x.to(self.device) + if only_inputs: + log["inputs"] = x + return log + xrec, _ = self(x) + if x.shape[1] > 3: + # colorize with random projection + assert xrec.shape[1] > 3 + x = self.to_rgb(x) + xrec = self.to_rgb(xrec) + log["inputs"] = x + log["reconstructions"] = xrec + if plot_ema: + with self.ema_scope(): + xrec_ema, _ = self(x) + if x.shape[1] > 3: xrec_ema = self.to_rgb(xrec_ema) + log["reconstructions_ema"] = xrec_ema + return log + + def to_rgb(self, x): + assert self.image_key == "segmentation" + if not hasattr(self, "colorize"): + self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) + x = F.conv2d(x, weight=self.colorize) + x = 2.*(x-x.min())/(x.max()-x.min()) - 1. + return x + + +class VQModelInterface(VQModel): + def __init__(self, embed_dim, *args, **kwargs): + super().__init__(embed_dim=embed_dim, *args, **kwargs) + self.embed_dim = embed_dim + + def encode(self, x): + h = self.encoder(x) + h = self.quant_conv(h) + return h + + def decode(self, h, force_not_quantize=False): + # also go through quantization layer + if not force_not_quantize: + quant, emb_loss, info = self.quantize(h) + else: + quant = h + quant = self.post_quant_conv(quant) + dec = self.decoder(quant) + return dec + + +class AutoencoderKL(pl.LightningModule): + def __init__(self, + ddconfig, + lossconfig, + embed_dim, + ckpt_path=None, + ignore_keys=[], + image_key="image", + colorize_nlabels=None, + monitor=None, + ): + super().__init__() + self.image_key = image_key + self.encoder = Encoder(**ddconfig) + self.decoder = Decoder(**ddconfig) + self.loss = instantiate_from_config(lossconfig) + assert ddconfig["double_z"] + self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1) + self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) + self.embed_dim = embed_dim + if colorize_nlabels is not None: + assert type(colorize_nlabels)==int + self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) + if monitor is not None: + self.monitor = monitor + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) + + def init_from_ckpt(self, path, ignore_keys=list()): + sd = torch.load(path, map_location="cpu")["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + self.load_state_dict(sd, strict=False) + print(f"Restored from {path}") + + def encode(self, x): + h = self.encoder(x) + moments = self.quant_conv(h) + posterior = DiagonalGaussianDistribution(moments) + return posterior + + def decode(self, z): + z = self.post_quant_conv(z) + dec = self.decoder(z) + return dec + + def forward(self, input, sample_posterior=True): + posterior = self.encode(input) + if sample_posterior: + z = posterior.sample() + else: + z = posterior.mode() + dec = self.decode(z) + return dec, posterior + + def get_input(self, batch, k): + x = batch[k] + if len(x.shape) == 3: + x = x[..., None] + x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float() + return x + + def training_step(self, batch, batch_idx, optimizer_idx): + inputs = self.get_input(batch, self.image_key) + reconstructions, posterior = self(inputs) + + if optimizer_idx == 0: + # train encoder+decoder+logvar + aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, + last_layer=self.get_last_layer(), split="train") + self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) + self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False) + return aeloss + + if optimizer_idx == 1: + # train the discriminator + discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, + last_layer=self.get_last_layer(), split="train") + + self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) + self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False) + return discloss + + def validation_step(self, batch, batch_idx): + inputs = self.get_input(batch, self.image_key) + reconstructions, posterior = self(inputs) + aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step, + last_layer=self.get_last_layer(), split="val") + + discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step, + last_layer=self.get_last_layer(), split="val") + + self.log("val/rec_loss", log_dict_ae["val/rec_loss"]) + self.log_dict(log_dict_ae) + self.log_dict(log_dict_disc) + return self.log_dict + + def configure_optimizers(self): + lr = self.learning_rate + opt_ae = torch.optim.Adam(list(self.encoder.parameters())+ + list(self.decoder.parameters())+ + list(self.quant_conv.parameters())+ + list(self.post_quant_conv.parameters()), + lr=lr, betas=(0.5, 0.9)) + opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), + lr=lr, betas=(0.5, 0.9)) + return [opt_ae, opt_disc], [] + + def get_last_layer(self): + return self.decoder.conv_out.weight + + @torch.no_grad() + def log_images(self, batch, only_inputs=False, **kwargs): + log = dict() + x = self.get_input(batch, self.image_key) + x = x.to(self.device) + if not only_inputs: + xrec, posterior = self(x) + if x.shape[1] > 3: + # colorize with random projection + assert xrec.shape[1] > 3 + x = self.to_rgb(x) + xrec = self.to_rgb(xrec) + log["samples"] = self.decode(torch.randn_like(posterior.sample())) + log["reconstructions"] = xrec + log["inputs"] = x + return log + + def to_rgb(self, x): + assert self.image_key == "segmentation" + if not hasattr(self, "colorize"): + self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) + x = F.conv2d(x, weight=self.colorize) + x = 2.*(x-x.min())/(x.max()-x.min()) - 1. + return x + + +class IdentityFirstStage(torch.nn.Module): + def __init__(self, *args, vq_interface=False, **kwargs): + self.vq_interface = vq_interface # TODO: Should be true by default but check to not break older stuff + super().__init__() + + def encode(self, x, *args, **kwargs): + return x + + def decode(self, x, *args, **kwargs): + return x + + def quantize(self, x, *args, **kwargs): + if self.vq_interface: + return x, None, [None, None, None] + return x + + def forward(self, x, *args, **kwargs): + return x diff --git a/ldm/models/diffusion/__init__.py b/ldm/models/diffusion/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ldm/models/diffusion/classifier.py b/ldm/models/diffusion/classifier.py new file mode 100644 index 0000000000000000000000000000000000000000..846dd4b4f0916c2af4ad78c5ab4843dcebd15351 --- /dev/null +++ b/ldm/models/diffusion/classifier.py @@ -0,0 +1,280 @@ +# This code is built from the Stable Diffusion repository: https://github.com/CompVis/stable-diffusion, and +# Paint-by-Example repo https://github.com/Fantasy-Studio/Paint-by-Example +# Copyright (c) 2022 Robin Rombach and Patrick Esser and contributors. +# CreativeML Open RAIL-M +# +# ========================================================================================== +# +# Adobe’s modifications are Copyright 2024 Adobe Research. All rights reserved. +# Adobe’s modifications are licensed under the Adobe Research License. To view a copy of the license, visit +# LICENSE.md. +# +# ========================================================================================== + +import os +import torch +import pytorch_lightning as pl +from omegaconf import OmegaConf +from torch.nn import functional as F +from torch.optim import AdamW +from torch.optim.lr_scheduler import LambdaLR +from copy import deepcopy +from einops import rearrange +from glob import glob +from natsort import natsorted + +from ldm.modules.diffusionmodules.openaimodel import EncoderUNetModel, UNetModel +from ldm.util import log_txt_as_img, default, ismap, instantiate_from_config + +__models__ = { + 'class_label': EncoderUNetModel, + 'segmentation': UNetModel +} + + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +class NoisyLatentImageClassifier(pl.LightningModule): + + def __init__(self, + diffusion_path, + num_classes, + ckpt_path=None, + pool='attention', + label_key=None, + diffusion_ckpt_path=None, + scheduler_config=None, + weight_decay=1.e-2, + log_steps=10, + monitor='val/loss', + *args, + **kwargs): + super().__init__(*args, **kwargs) + self.num_classes = num_classes + # get latest config of diffusion model + diffusion_config = natsorted(glob(os.path.join(diffusion_path, 'configs', '*-project.yaml')))[-1] + self.diffusion_config = OmegaConf.load(diffusion_config).model + self.diffusion_config.params.ckpt_path = diffusion_ckpt_path + self.load_diffusion() + + self.monitor = monitor + self.numd = self.diffusion_model.first_stage_model.encoder.num_resolutions - 1 + self.log_time_interval = self.diffusion_model.num_timesteps // log_steps + self.log_steps = log_steps + + self.label_key = label_key if not hasattr(self.diffusion_model, 'cond_stage_key') \ + else self.diffusion_model.cond_stage_key + + assert self.label_key is not None, 'label_key neither in diffusion model nor in model.params' + + if self.label_key not in __models__: + raise NotImplementedError() + + self.load_classifier(ckpt_path, pool) + + self.scheduler_config = scheduler_config + self.use_scheduler = self.scheduler_config is not None + self.weight_decay = weight_decay + + def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): + sd = torch.load(path, map_location="cpu") + if "state_dict" in list(sd.keys()): + sd = sd["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict( + sd, strict=False) + print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") + if len(missing) > 0: + print(f"Missing Keys: {missing}") + if len(unexpected) > 0: + print(f"Unexpected Keys: {unexpected}") + + def load_diffusion(self): + model = instantiate_from_config(self.diffusion_config) + self.diffusion_model = model.eval() + self.diffusion_model.train = disabled_train + for param in self.diffusion_model.parameters(): + param.requires_grad = False + + def load_classifier(self, ckpt_path, pool): + model_config = deepcopy(self.diffusion_config.params.unet_config.params) + model_config.in_channels = self.diffusion_config.params.unet_config.params.out_channels + model_config.out_channels = self.num_classes + if self.label_key == 'class_label': + model_config.pool = pool + + self.model = __models__[self.label_key](**model_config) + if ckpt_path is not None: + print('#####################################################################') + print(f'load from ckpt "{ckpt_path}"') + print('#####################################################################') + self.init_from_ckpt(ckpt_path) + + @torch.no_grad() + def get_x_noisy(self, x, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x)) + continuous_sqrt_alpha_cumprod = None + if self.diffusion_model.use_continuous_noise: + continuous_sqrt_alpha_cumprod = self.diffusion_model.sample_continuous_noise_level(x.shape[0], t + 1) + # todo: make sure t+1 is correct here + + return self.diffusion_model.q_sample(x_start=x, t=t, noise=noise, + continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod) + + def forward(self, x_noisy, t, *args, **kwargs): + return self.model(x_noisy, t) + + @torch.no_grad() + def get_input(self, batch, k): + x = batch[k] + if len(x.shape) == 3: + x = x[..., None] + x = rearrange(x, 'b h w c -> b c h w') + x = x.to(memory_format=torch.contiguous_format).float() + return x + + @torch.no_grad() + def get_conditioning(self, batch, k=None): + if k is None: + k = self.label_key + assert k is not None, 'Needs to provide label key' + + targets = batch[k].to(self.device) + + if self.label_key == 'segmentation': + targets = rearrange(targets, 'b h w c -> b c h w') + for down in range(self.numd): + h, w = targets.shape[-2:] + targets = F.interpolate(targets, size=(h // 2, w // 2), mode='nearest') + + # targets = rearrange(targets,'b c h w -> b h w c') + + return targets + + def compute_top_k(self, logits, labels, k, reduction="mean"): + _, top_ks = torch.topk(logits, k, dim=1) + if reduction == "mean": + return (top_ks == labels[:, None]).float().sum(dim=-1).mean().item() + elif reduction == "none": + return (top_ks == labels[:, None]).float().sum(dim=-1) + + def on_train_epoch_start(self): + # save some memory + self.diffusion_model.model.to('cpu') + + @torch.no_grad() + def write_logs(self, loss, logits, targets): + log_prefix = 'train' if self.training else 'val' + log = {} + log[f"{log_prefix}/loss"] = loss.mean() + log[f"{log_prefix}/acc@1"] = self.compute_top_k( + logits, targets, k=1, reduction="mean" + ) + log[f"{log_prefix}/acc@5"] = self.compute_top_k( + logits, targets, k=5, reduction="mean" + ) + + self.log_dict(log, prog_bar=False, logger=True, on_step=self.training, on_epoch=True) + self.log('loss', log[f"{log_prefix}/loss"], prog_bar=True, logger=False) + self.log('global_step', self.global_step, logger=False, on_epoch=False, prog_bar=True) + lr = self.optimizers().param_groups[0]['lr'] + self.log('lr_abs', lr, on_step=True, logger=True, on_epoch=False, prog_bar=True) + + def shared_step(self, batch, t=None): + x, *_ = self.diffusion_model.get_input(batch, k=self.diffusion_model.first_stage_key) + targets = self.get_conditioning(batch) + if targets.dim() == 4: + targets = targets.argmax(dim=1) + if t is None: + t = torch.randint(0, self.diffusion_model.num_timesteps, (x.shape[0],), device=self.device).long() + else: + t = torch.full(size=(x.shape[0],), fill_value=t, device=self.device).long() + x_noisy = self.get_x_noisy(x, t) + logits = self(x_noisy, t) + + loss = F.cross_entropy(logits, targets, reduction='none') + + self.write_logs(loss.detach(), logits.detach(), targets.detach()) + + loss = loss.mean() + return loss, logits, x_noisy, targets + + def training_step(self, batch, batch_idx): + loss, *_ = self.shared_step(batch) + return loss + + def reset_noise_accs(self): + self.noisy_acc = {t: {'acc@1': [], 'acc@5': []} for t in + range(0, self.diffusion_model.num_timesteps, self.diffusion_model.log_every_t)} + + def on_validation_start(self): + self.reset_noise_accs() + + @torch.no_grad() + def validation_step(self, batch, batch_idx): + loss, *_ = self.shared_step(batch) + + for t in self.noisy_acc: + _, logits, _, targets = self.shared_step(batch, t) + self.noisy_acc[t]['acc@1'].append(self.compute_top_k(logits, targets, k=1, reduction='mean')) + self.noisy_acc[t]['acc@5'].append(self.compute_top_k(logits, targets, k=5, reduction='mean')) + + return loss + + def configure_optimizers(self): + optimizer = AdamW(self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay) + + if self.use_scheduler: + scheduler = instantiate_from_config(self.scheduler_config) + + print("Setting up LambdaLR scheduler...") + scheduler = [ + { + 'scheduler': LambdaLR(optimizer, lr_lambda=scheduler.schedule), + 'interval': 'step', + 'frequency': 1 + }] + return [optimizer], scheduler + + return optimizer + + @torch.no_grad() + def log_images(self, batch, N=8, *args, **kwargs): + log = dict() + x = self.get_input(batch, self.diffusion_model.first_stage_key) + log['inputs'] = x + + y = self.get_conditioning(batch) + + if self.label_key == 'class_label': + y = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"]) + log['labels'] = y + + if ismap(y): + log['labels'] = self.diffusion_model.to_rgb(y) + + for step in range(self.log_steps): + current_time = step * self.log_time_interval + + _, logits, x_noisy, _ = self.shared_step(batch, t=current_time) + + log[f'inputs@t{current_time}'] = x_noisy + + pred = F.one_hot(logits.argmax(dim=1), num_classes=self.num_classes) + pred = rearrange(pred, 'b h w c -> b c h w') + + log[f'pred@t{current_time}'] = self.diffusion_model.to_rgb(pred) + + for key in log: + log[key] = log[key][:N] + + return log diff --git a/ldm/models/diffusion/ddim.py b/ldm/models/diffusion/ddim.py new file mode 100644 index 0000000000000000000000000000000000000000..50b0d7db6e92f78c9e073c406e8226e617390a50 --- /dev/null +++ b/ldm/models/diffusion/ddim.py @@ -0,0 +1,296 @@ +# This code is built from the Stable Diffusion repository: https://github.com/CompVis/stable-diffusion, and +# Paint-by-Example repo https://github.com/Fantasy-Studio/Paint-by-Example +# Copyright (c) 2022 Robin Rombach and Patrick Esser and contributors. +# CreativeML Open RAIL-M +# +# ========================================================================================== +# +# Adobe’s modifications are Copyright 2024 Adobe Research. All rights reserved. +# Adobe’s modifications are licensed under the Adobe Research License. To view a copy of the license, visit +# LICENSE.md. +# +# ========================================================================================== + +"""SAMPLING ONLY.""" + +import torch +import numpy as np +from tqdm import tqdm +from functools import partial + +from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, \ + extract_into_tensor + + +class DDIMSampler(object): + def __init__(self, model, schedule="linear", **kwargs): + super().__init__() + self.model = model + self.ddpm_num_timesteps = model.num_timesteps + self.schedule = schedule + + def register_buffer(self, name, attr): + if type(attr) == torch.Tensor: + if attr.device != torch.device("cuda"): + attr = attr.to(torch.device("cuda")) + setattr(self, name, attr) + + def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True, steps=None): + self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, + num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose, steps=steps) + alphas_cumprod = self.model.alphas_cumprod + assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' + to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) + + self.register_buffer('betas', to_torch(self.model.betas)) + self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) + self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) + self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) + self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu()))) + self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) + self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1))) + + # ddim sampling parameters + ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), + ddim_timesteps=self.ddim_timesteps, + eta=ddim_eta,verbose=verbose) + self.register_buffer('ddim_sigmas', ddim_sigmas) + self.register_buffer('ddim_alphas', ddim_alphas) + self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) + self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas)) + sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( + (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( + 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) + self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) + + @torch.no_grad() + def sample(self, + S, + batch_size, + shape, + conditioning=None, + callback=None, + normals_sequence=None, + img_callback=None, + quantize_x0=False, + eta=0., + mask=None, + x0=None, + temperature=1., + noise_dropout=0., + score_corrector=None, + corrector_kwargs=None, + verbose=True, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1., + unconditional_conditioning=None, + z_ref=None, + ddim_discretize='uniform', + schedule_steps=None, + # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... + **kwargs + ): + if conditioning is not None: + if isinstance(conditioning, dict): + cbs = conditioning[list(conditioning.keys())[0]].shape[0] + if cbs != batch_size: + print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") + else: + if conditioning.shape[0] != batch_size: + print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") + + self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose, ddim_discretize=ddim_discretize, steps=schedule_steps) + # sampling + C, H, W = shape + size = (batch_size, C, H, W) + + samples, intermediates = self.ddim_sampling(conditioning, size, + callback=callback, + img_callback=img_callback, + quantize_denoised=quantize_x0, + mask=mask, x0=x0, + ddim_use_original_steps=False, + noise_dropout=noise_dropout, + temperature=temperature, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + x_T=x_T, + log_every_t=log_every_t, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + z_ref=z_ref, + **kwargs + ) + return samples, intermediates + + @torch.no_grad() + def ddim_sampling(self, cond, shape, + x_T=None, ddim_use_original_steps=False, + callback=None, timesteps=None, quantize_denoised=False, + mask=None, x0=None, x0_step=None, img_callback=None, log_every_t=100, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None, z_ref=None,**kwargs): + device = self.model.betas.device + b = shape[0] + if x_T is None: + img = torch.randn(shape, device=device) + else: + img = x_T + + if timesteps is None: + timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps + elif timesteps is not None and not ddim_use_original_steps: + subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 + timesteps = self.ddim_timesteps[:subset_end] + + intermediates = {'x_inter': [img], 'pred_x0': [img]} + time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps) + total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] + print(f"Running DDIM Sampling with {total_steps} timesteps") + + iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps) + + for i, step in enumerate(iterator): + index = total_steps - i - 1 + ts = torch.full((b,), step, device=device, dtype=torch.long) + + if x0_step is not None and i < x0_step: + assert x0 is not None + img = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? + # img = img_orig * mask + (1. - mask) * img + outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, + quantize_denoised=quantize_denoised, temperature=temperature, + noise_dropout=noise_dropout, score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + unconditional_guidance_scale=unconditional_guidance_scale, + z_ref=z_ref, + unconditional_conditioning=unconditional_conditioning,**kwargs) + img, pred_x0 = outs + if callback: callback(i) + if img_callback: img_callback(pred_x0, i) + + if index % log_every_t == 0 or index == total_steps - 1: + intermediates['x_inter'].append(img) + intermediates['pred_x0'].append(pred_x0) + + return img, intermediates + + @torch.no_grad() + def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None, z_ref=None, drop_latent_guidance=1.0,**kwargs): + b, *_, device = *x.shape, x.device + if 'test_model_kwargs' in kwargs: + kwargs=kwargs['test_model_kwargs'] + if f'inpaint_mask_{index}' in kwargs: + x = torch.cat([x, kwargs['inpaint_image'], kwargs[f'inpaint_mask_{index}']],dim=1) + print('using proxy mask', index) + else: + x = torch.cat([x, kwargs['inpaint_image'], kwargs[f'inpaint_mask']],dim=1) + if 'changed_pixels' in kwargs: + x = torch.cat([x, kwargs['changed_pixels']],dim=1) + elif 'rest' in kwargs: + x = torch.cat((x, kwargs['rest']), dim=1) + else: + raise Exception("kwargs must contain either 'test_model_kwargs' or 'rest' key") + + # maybe should assert not both of these are true + # print('index', index) + if isinstance(drop_latent_guidance, list): + cur_drop_latent_guidance = drop_latent_guidance[index] + else: + cur_drop_latent_guidance = drop_latent_guidance + # print('cur drop guidance', cur_drop_latent_guidance) + + if (unconditional_conditioning is None or unconditional_guidance_scale == 1.) and cur_drop_latent_guidance == 1.: + e_t = self.model.apply_model(x, t, c, z_ref=z_ref) + elif cur_drop_latent_guidance != 1.: + assert (unconditional_conditioning is None or unconditional_guidance_scale == 1.) + x_dropped = x.clone() + # print('x dropped shape', x_dropped.shape) + x_dropped[:,4:9] *= 0.0 + x_in = torch.cat([x_dropped, x]) + t_in = torch.cat([t] * 2) + z_ref_in = torch.cat([z_ref] * 2) + c_in = torch.cat([c] * 2) + e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in, z_ref=z_ref_in).chunk(2) + e_t = e_t_uncond + cur_drop_latent_guidance * (e_t - e_t_uncond) + + else: + x_in = torch.cat([x] * 2) + t_in = torch.cat([t] * 2) + z_ref_in = torch.cat([z_ref] * 2) + # print('uncond shape', unconditional_conditioning.shape, 'c shape', c.shape) + c_in = torch.cat([unconditional_conditioning, c]) + e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in, z_ref=z_ref_in).chunk(2) + e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) + + if score_corrector is not None: + assert self.model.parameterization == "eps" + e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) + + alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas + alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev + sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas + sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas + # select parameters corresponding to the currently considered timestep + a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) + a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) + sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) + sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) + + # current prediction for x_0 + if x.shape[1]!=4: + pred_x0 = (x[:,:4,:,:] - sqrt_one_minus_at * e_t) / a_t.sqrt() + else: + pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() + if quantize_denoised: + pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) + # direction pointing to x_t + dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t + noise = sigma_t * noise_like(dir_xt.shape, device, repeat_noise) * temperature + if noise_dropout > 0.: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise + return x_prev, pred_x0 + + @torch.no_grad() + def stochastic_encode(self, x0, t, use_original_steps=False, noise=None): + # fast, but does not allow for exact reconstruction + # t serves as an index to gather the correct alphas + if use_original_steps: + sqrt_alphas_cumprod = self.sqrt_alphas_cumprod + sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod + else: + sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas) + sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas + + if noise is None: + noise = torch.randn_like(x0) + return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 + + extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise) + + @torch.no_grad() + def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None, + use_original_steps=False): + + timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps + timesteps = timesteps[:t_start] + + time_range = np.flip(timesteps) + total_steps = timesteps.shape[0] + print(f"Running DDIM Sampling with {total_steps} timesteps") + + iterator = tqdm(time_range, desc='Decoding image', total=total_steps) + x_dec = x_latent + for i, step in enumerate(iterator): + index = total_steps - i - 1 + ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long) + x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning) + return x_dec \ No newline at end of file diff --git a/ldm/models/diffusion/ddpm.py b/ldm/models/diffusion/ddpm.py new file mode 100644 index 0000000000000000000000000000000000000000..e69ddaef77811917b8dc164fe11e9e6af7dcf07e --- /dev/null +++ b/ldm/models/diffusion/ddpm.py @@ -0,0 +1,1877 @@ +# This code is built from the Stable Diffusion repository: https://github.com/CompVis/stable-diffusion, and +# Paint-by-Example repo https://github.com/Fantasy-Studio/Paint-by-Example +# Copyright (c) 2022 Robin Rombach and Patrick Esser and contributors. +# CreativeML Open RAIL-M +# +# ========================================================================================== +# +# Adobe’s modifications are Copyright 2024 Adobe Research. All rights reserved. +# Adobe’s modifications are licensed under the Adobe Research License. To view a copy of the license, visit +# LICENSE.md. +# +# ========================================================================================== + +""" +wild mixture of +https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py +https://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_diffusion/gaussian_diffusion.py +https://github.com/CompVis/taming-transformers +-- merci +""" + +import torch +import torch.nn as nn +import torchvision +import numpy as np +import pytorch_lightning as pl +from torch.optim.lr_scheduler import LambdaLR +from einops import rearrange, repeat +from contextlib import contextmanager +from functools import partial +from tqdm import tqdm +from torchvision.utils import make_grid +# from pytorch_lightning.utilities.distributed import rank_zero_only +from pytorch_lightning.utilities.rank_zero import rank_zero_only +from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config +from ldm.modules.ema import LitEma +from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution +from ldm.models.autoencoder import VQModelInterface, IdentityFirstStage, AutoencoderKL +from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like +from ldm.models.diffusion.ddim import DDIMSampler +from torchvision.transforms import Resize +import math +import time +import random +from torch.autograd import Variable +import copy +import os + +__conditioning_keys__ = {'concat': 'c_concat', + 'crossattn': 'c_crossattn', + 'adm': 'y'} + + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +def uniform_on_device(r1, r2, shape, device): + return (r1 - r2) * torch.rand(*shape, device=device) + r2 + + +def rescale_zero_terminal_snr(betas): + """ + Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1) + + + Args: + betas (`torch.FloatTensor`): + the betas that the scheduler is being initialized with. + + Returns: + `torch.FloatTensor`: rescaled betas with zero terminal SNR + """ + # Convert betas to alphas_bar_sqrt + alphas = 1.0 - betas + alphas_cumprod = torch.cumprod(alphas, dim=0) + alphas_bar_sqrt = alphas_cumprod.sqrt() + + # Store old values. + alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() + alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() + + # Shift so the last timestep is zero. + alphas_bar_sqrt -= alphas_bar_sqrt_T + + # Scale so the first timestep is back to the old value. + alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) + + # Convert alphas_bar_sqrt to betas + alphas_bar = alphas_bar_sqrt**2 # Revert sqrt + alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod + alphas = torch.cat([alphas_bar[0:1], alphas]) + betas = 1 - alphas + + return betas + + +class DDPM(pl.LightningModule): + # classic DDPM with Gaussian diffusion, in image space + def __init__(self, + unet_config, + timesteps=1000, + beta_schedule="linear", + loss_type="l2", + ckpt_path=None, + ignore_keys=[], + load_only_unet=False, + monitor="val/loss", + use_ema=True, + first_stage_key="image", + image_size=256, + channels=3, + log_every_t=100, + clip_denoised=True, + linear_start=1e-4, + linear_end=2e-2, + cosine_s=8e-3, + given_betas=None, + original_elbo_weight=0., + v_posterior=0., # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta + l_simple_weight=1., + conditioning_key=None, + parameterization="eps", # all assuming fixed variance schedules + scheduler_config=None, + use_positional_encodings=False, + learn_logvar=False, + logvar_init=0., + u_cond_percent=0, + dropping_warped_latent_prob=0., + remove_warped_latent=False, + gt_flag='GT', + sd_edit_step=850 + ): + super().__init__() + assert parameterization in ["eps", "x0"], 'currently only supporting "eps" and "x0"' + self.parameterization = parameterization + print(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode") + self.cond_stage_model = None + self.clip_denoised = clip_denoised + self.log_every_t = log_every_t + self.first_stage_key = first_stage_key + self.image_size = image_size + self.channels = channels + self.u_cond_percent=u_cond_percent + self.use_positional_encodings = use_positional_encodings + self.gt_flag = gt_flag + self.sd_edit_step = sd_edit_step + + self.remove_warped_latent = remove_warped_latent + self.dropping_warped_latent_prob = dropping_warped_latent_prob + + if dropping_warped_latent_prob > 0.0: + assert not self.remove_warped_latent + + + self.use_ema = use_ema + if self.use_ema: + self.model_ema = LitEma(self.model) + print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") + + self.use_scheduler = scheduler_config is not None + if self.use_scheduler: + self.scheduler_config = scheduler_config + + self.v_posterior = v_posterior + self.original_elbo_weight = original_elbo_weight + self.l_simple_weight = l_simple_weight + + if monitor is not None: + self.monitor = monitor + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet) + + self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps, + linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s) + + self.model = DiffusionWrapper(unet_config, conditioning_key, ddpm_parent=self, + sqrt_alphas_cumprod=self.sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod=self.sqrt_one_minus_alphas_cumprod) + count_params(self.model, verbose=True) + + self.loss_type = loss_type + + self.learn_logvar = learn_logvar + self.logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,)) + if self.learn_logvar: + self.logvar = nn.Parameter(self.logvar, requires_grad=True) + + + def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000, + linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + if exists(given_betas): + betas = given_betas + else: + betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, + cosine_s=cosine_s) + + # rescale beta + rescale_beta = True + if rescale_beta: + betas = rescale_zero_terminal_snr(torch.tensor(betas)).numpy() + + alphas = 1. - betas + alphas_cumprod = np.cumprod(alphas, axis=0) + alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) + + timesteps, = betas.shape + self.num_timesteps = int(timesteps) + self.linear_start = linear_start + self.linear_end = linear_end + assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep' + + to_torch = partial(torch.tensor, dtype=torch.float32) + + self.register_buffer('betas', to_torch(betas)) + self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) + self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) + self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod))) + self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod))) + self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod))) + self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1))) + + # calculations for posterior q(x_{t-1} | x_t, x_0) + posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / ( + 1. - alphas_cumprod) + self.v_posterior * betas + # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) + self.register_buffer('posterior_variance', to_torch(posterior_variance)) + # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain + self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20)))) + self.register_buffer('posterior_mean_coef1', to_torch( + betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))) + self.register_buffer('posterior_mean_coef2', to_torch( + (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod))) + + if self.parameterization == "eps": + lvlb_weights = self.betas ** 2 / ( + 2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod)) + elif self.parameterization == "x0": + lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. * 1 - torch.Tensor(alphas_cumprod)) + else: + raise NotImplementedError("mu not supported") + # pr_odo how to choose this term + lvlb_weights[0] = lvlb_weights[1] + self.register_buffer('lvlb_weights', lvlb_weights, persistent=False) + assert not torch.isnan(self.lvlb_weights).all() + + @contextmanager + def ema_scope(self, context=None): + if self.use_ema: + self.model_ema.store(self.model.parameters()) + self.model_ema.copy_to(self.model) + if context is not None: + print(f"{context}: Switched to EMA weights") + try: + yield None + finally: + if self.use_ema: + self.model_ema.restore(self.model.parameters()) + if context is not None: + print(f"{context}: Restored training weights") + + def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): + sd = torch.load(path, map_location="cpu") + if "state_dict" in list(sd.keys()): + sd = sd["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict( + sd, strict=False) + print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") + if len(missing) > 0: + print(f"Missing Keys: {missing}") + if len(unexpected) > 0: + print(f"Unexpected Keys: {unexpected}") + + def q_mean_variance(self, x_start, t): + """ + Get the distribution q(x_t | x_0). + :param x_start: the [N x C x ...] tensor of noiseless inputs. + :param t: the number of diffusion steps (minus 1). Here, 0 means one step. + :return: A tuple (mean, variance, log_variance), all of x_start's shape. + """ + mean = (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start) + variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape) + log_variance = extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape) + return mean, variance, log_variance + + def predict_start_from_noise(self, x_t, t, noise): + return ( + extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - + extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise + ) + + def q_posterior(self, x_start, x_t, t): + posterior_mean = ( + extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start + + extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t + ) + posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape) + posterior_log_variance_clipped = extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + def p_mean_variance(self, x, t, clip_denoised: bool): + model_out = self.model(x, t) + if self.parameterization == "eps": + x_recon = self.predict_start_from_noise(x, t=t, noise=model_out) + elif self.parameterization == "x0": + x_recon = model_out + if clip_denoised: + x_recon.clamp_(-1., 1.) + + model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t) + return model_mean, posterior_variance, posterior_log_variance + + @torch.no_grad() + def p_sample(self, x, t, clip_denoised=True, repeat_noise=False): + b, *_, device = *x.shape, x.device + model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, clip_denoised=clip_denoised) + noise = noise_like(x.shape, device, repeat_noise) + # no noise when t == 0 + nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise + + @torch.no_grad() + def p_sample_loop(self, shape, return_intermediates=False): + device = self.betas.device + b = shape[0] + img = torch.randn(shape, device=device) + intermediates = [img] + for i in tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps): + img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long), + clip_denoised=self.clip_denoised) + if i % self.log_every_t == 0 or i == self.num_timesteps - 1: + intermediates.append(img) + if return_intermediates: + return img, intermediates + return img + + @torch.no_grad() + def sample(self, batch_size=16, return_intermediates=False): + image_size = self.image_size + channels = self.channels + return self.p_sample_loop((batch_size, channels, image_size, image_size), + return_intermediates=return_intermediates) + + def q_sample(self, x_start, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise) + + def get_loss(self, pred, target, mean=True): + if self.loss_type == 'l1': + loss = (target - pred).abs() + if mean: + loss = loss.mean() + elif self.loss_type == 'l2': + if mean: + loss = torch.nn.functional.mse_loss(target, pred) + else: + loss = torch.nn.functional.mse_loss(target, pred, reduction='none') + else: + raise NotImplementedError("unknown loss type '{loss_type}'") + + return loss + + def p_losses(self, x_start, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) + model_out = self.model(x_noisy, t) + + loss_dict = {} + if self.parameterization == "eps": + target = noise + elif self.parameterization == "x0": + target = x_start + else: + raise NotImplementedError(f"Paramterization {self.parameterization} not yet supported") + + loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3]) + + log_prefix = 'train' if self.training else 'val' + + loss_dict.update({f'{log_prefix}/loss_simple': loss.mean()}) + loss_simple = loss.mean() * self.l_simple_weight + + loss_vlb = (self.lvlb_weights[t] * loss).mean() + loss_dict.update({f'{log_prefix}/loss_vlb': loss_vlb}) + + loss = loss_simple + self.original_elbo_weight * loss_vlb + + loss_dict.update({f'{log_prefix}/loss': loss}) + + return loss, loss_dict + + def forward(self, x, *args, **kwargs): + # b, c, h, w, device, img_size, = *x.shape, x.device, self.image_size + # assert h == img_size and w == img_size, f'height and width of image must be {img_size}' + t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long() + return self.p_losses(x, t, *args, **kwargs) + + def get_input(self, batch, k): + if k == "inpaint": + x = batch[self.gt_flag] + mask = batch['inpaint_mask'] + inpaint = batch['inpaint_image'] + reference = batch['ref_imgs'] + clean_reference = batch['clean_reference'] + grid_transformed = batch['grid_transformed'] + changed_pixels = batch['changed_pixels'] + else: + x = batch[k] + if len(x.shape) == 3: + x = x[..., None] + # x = rearrange(x, 'b h w c -> b c h w') + x = x.to(memory_format=torch.contiguous_format).float() + mask = mask.to(memory_format=torch.contiguous_format).float() + inpaint = inpaint.to(memory_format=torch.contiguous_format).float() + reference = reference.to(memory_format=torch.contiguous_format).float() + clean_reference = clean_reference.to(memory_format=torch.contiguous_format).float() + grid_transformed = grid_transformed.to(memory_format=torch.contiguous_format).float() + return x,inpaint,mask,reference, clean_reference, grid_transformed, changed_pixels + + def shared_step(self, batch): + x = self.get_input(batch, self.first_stage_key) + loss, loss_dict = self(x) + return loss, loss_dict + + def training_step(self, batch, batch_idx): + loss, loss_dict = self.shared_step(batch) + + self.log_dict(loss_dict, prog_bar=True, + logger=True, on_step=True, on_epoch=True) + + self.log("global_step", self.global_step, + prog_bar=True, logger=True, on_step=True, on_epoch=False) + + if self.use_scheduler: + lr = self.optimizers().param_groups[0]['lr'] + self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False) + + return loss + + @torch.no_grad() + def validation_step(self, batch, batch_idx): + _, loss_dict_no_ema = self.shared_step(batch) + with self.ema_scope(): + _, loss_dict_ema = self.shared_step(batch) + loss_dict_ema = {key + '_ema': loss_dict_ema[key] for key in loss_dict_ema} + self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True) + self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True) + + def on_train_batch_end(self, *args, **kwargs): + if self.use_ema: + self.model_ema(self.model) + + def _get_rows_from_list(self, samples): + n_imgs_per_row = len(samples) + denoise_grid = rearrange(samples, 'n b c h w -> b n c h w') + denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w') + denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row) + return denoise_grid + + @torch.no_grad() + def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs): + log = dict() + x = self.get_input(batch, self.first_stage_key) + N = min(x.shape[0], N) + n_row = min(x.shape[0], n_row) + x = x.to(self.device)[:N] + log["inputs"] = x + + # get diffusion row + diffusion_row = list() + x_start = x[:n_row] + + for t in range(self.num_timesteps): + if t % self.log_every_t == 0 or t == self.num_timesteps - 1: + t = repeat(torch.tensor([t]), '1 -> b', b=n_row) + t = t.to(self.device).long() + noise = torch.randn_like(x_start) + x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) + diffusion_row.append(x_noisy) + + log["diffusion_row"] = self._get_rows_from_list(diffusion_row) + + if sample: + # get denoise row + with self.ema_scope("Plotting"): + samples, denoise_row = self.sample(batch_size=N, return_intermediates=True) + + log["samples"] = samples + log["denoise_row"] = self._get_rows_from_list(denoise_row) + + if return_keys: + if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0: + return log + else: + return {key: log[key] for key in return_keys} + return log + + def configure_optimizers(self): + lr = self.learning_rate + params = list(self.model.parameters()) + if self.learn_logvar: + params = params + [self.logvar] + opt = torch.optim.AdamW(params, lr=lr) + return opt + + +class LatentDiffusion(DDPM): + """main class""" + def __init__(self, + first_stage_config, + cond_stage_config, + num_timesteps_cond=None, + cond_stage_key="image", + cond_stage_trainable=False, + concat_mode=True, + cond_stage_forward=None, + conditioning_key=None, + scale_factor=1.0, + scale_by_std=False, + context_embedding_dim=1024, # dim used for clip image encoder + *args, **kwargs): + self.num_timesteps_cond = default(num_timesteps_cond, 1) + self.scale_by_std = scale_by_std + assert self.num_timesteps_cond <= kwargs['timesteps'] + # for backwards compatibility after implementation of DiffusionWrapper + if conditioning_key is None: + conditioning_key = 'concat' if concat_mode else 'crossattn' + if cond_stage_config == '__is_unconditional__': + conditioning_key = None + ckpt_path = kwargs.pop("ckpt_path", None) + ignore_keys = kwargs.pop("ignore_keys", []) + super().__init__(conditioning_key=conditioning_key, *args, **kwargs) + self.learnable_vector = nn.Parameter(torch.randn((1,1,768)), requires_grad=True) + self.proj_out=nn.Linear(context_embedding_dim, 768) + self.concat_mode = concat_mode + self.cond_stage_trainable = cond_stage_trainable + self.cond_stage_key = cond_stage_key + try: + self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1 + except: + self.num_downs = 0 + if not scale_by_std: + self.scale_factor = scale_factor + else: + self.register_buffer('scale_factor', torch.tensor(scale_factor)) + self.instantiate_first_stage(first_stage_config) + self.instantiate_cond_stage(cond_stage_config) + self.cond_stage_forward = cond_stage_forward + self.clip_denoised = False + self.bbox_tokenizer = None + + self.restarted_from_ckpt = False + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys) + self.restarted_from_ckpt = True + + def make_cond_schedule(self, ): + self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long) + ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long() + self.cond_ids[:self.num_timesteps_cond] = ids + + @rank_zero_only + @torch.no_grad() + def on_train_batch_start(self, batch, batch_idx): + # only for very first batch + if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and not self.restarted_from_ckpt: + assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously' + # set rescale weight to 1./std of encodings + print("### USING STD-RESCALING ###") + x = super().get_input(batch, self.first_stage_key) + x = x.to(self.device) + encoder_posterior = self.encode_first_stage(x) + z = self.get_first_stage_encoding(encoder_posterior).detach() + del self.scale_factor + self.register_buffer('scale_factor', 1. / z.flatten().std()) + print(f"setting self.scale_factor to {self.scale_factor}") + print("### USING STD-RESCALING ###") + + def register_schedule(self, + given_betas=None, beta_schedule="linear", timesteps=1000, + linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + super().register_schedule(given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s) + + self.shorten_cond_schedule = self.num_timesteps_cond > 1 + if self.shorten_cond_schedule: + self.make_cond_schedule() + + def instantiate_first_stage(self, config): + model = instantiate_from_config(config) + self.first_stage_model = model.eval() + self.first_stage_model.train = disabled_train + for param in self.first_stage_model.parameters(): + param.requires_grad = False + + def instantiate_cond_stage(self, config): + if not self.cond_stage_trainable: + if config == "__is_first_stage__": + print("Using first stage also as cond stage.") + self.cond_stage_model = self.first_stage_model + elif config == "__is_unconditional__": + print(f"Training {self.__class__.__name__} as an unconditional model.") + self.cond_stage_model = None + # self.be_unconditional = True + else: + model = instantiate_from_config(config) + self.cond_stage_model = model.eval() + self.cond_stage_model.train = disabled_train + for param in self.cond_stage_model.parameters(): + param.requires_grad = False + else: + assert config != '__is_first_stage__' + assert config != '__is_unconditional__' + model = instantiate_from_config(config) + self.cond_stage_model = model + + + def _get_denoise_row_from_list(self, samples, desc='', force_no_decoder_quantization=False): + denoise_row = [] + for zd in tqdm(samples, desc=desc): + denoise_row.append(self.decode_first_stage(zd.to(self.device), + force_not_quantize=force_no_decoder_quantization)) + n_imgs_per_row = len(denoise_row) + denoise_row = torch.stack(denoise_row) # n_log_step, n_row, C, H, W + denoise_grid = rearrange(denoise_row, 'n b c h w -> b n c h w') + denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w') + denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row) + return denoise_grid + + def get_first_stage_encoding(self, encoder_posterior): + if isinstance(encoder_posterior, DiagonalGaussianDistribution): + z = encoder_posterior.sample() + elif isinstance(encoder_posterior, torch.Tensor): + z = encoder_posterior + else: + raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented") + return self.scale_factor * z + + def get_learned_conditioning(self, c): + if self.cond_stage_forward is None: + if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode): + c = self.cond_stage_model.encode(c) + if isinstance(c, DiagonalGaussianDistribution): + c = c.mode() + else: + c = self.cond_stage_model(c) + else: + assert hasattr(self.cond_stage_model, self.cond_stage_forward) + c = getattr(self.cond_stage_model, self.cond_stage_forward)(c) + return c + + + def meshgrid(self, h, w): + y = torch.arange(0, h).view(h, 1, 1).repeat(1, w, 1) + x = torch.arange(0, w).view(1, w, 1).repeat(h, 1, 1) + + arr = torch.cat([y, x], dim=-1) + return arr + + def delta_border(self, h, w): + """ + :param h: height + :param w: width + :return: normalized distance to image border, + wtith min distance = 0 at border and max dist = 0.5 at image center + """ + lower_right_corner = torch.tensor([h - 1, w - 1]).view(1, 1, 2) + arr = self.meshgrid(h, w) / lower_right_corner + dist_left_up = torch.min(arr, dim=-1, keepdims=True)[0] + dist_right_down = torch.min(1 - arr, dim=-1, keepdims=True)[0] + edge_dist = torch.min(torch.cat([dist_left_up, dist_right_down], dim=-1), dim=-1)[0] + return edge_dist + + def get_weighting(self, h, w, Ly, Lx, device): + weighting = self.delta_border(h, w) + weighting = torch.clip(weighting, self.split_input_params["clip_min_weight"], + self.split_input_params["clip_max_weight"], ) + weighting = weighting.view(1, h * w, 1).repeat(1, 1, Ly * Lx).to(device) + + if self.split_input_params["tie_braker"]: + L_weighting = self.delta_border(Ly, Lx) + L_weighting = torch.clip(L_weighting, + self.split_input_params["clip_min_tie_weight"], + self.split_input_params["clip_max_tie_weight"]) + + L_weighting = L_weighting.view(1, 1, Ly * Lx).to(device) + weighting = weighting * L_weighting + return weighting + + def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1): # pr_odo load once not every time, shorten code + """ + :param x: img of size (bs, c, h, w) + :return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1]) + """ + bs, nc, h, w = x.shape + + # number of crops in image + Ly = (h - kernel_size[0]) // stride[0] + 1 + Lx = (w - kernel_size[1]) // stride[1] + 1 + + if uf == 1 and df == 1: + fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride) + unfold = torch.nn.Unfold(**fold_params) + + fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params) + + weighting = self.get_weighting(kernel_size[0], kernel_size[1], Ly, Lx, x.device).to(x.dtype) + normalization = fold(weighting).view(1, 1, h, w) # normalizes the overlap + weighting = weighting.view((1, 1, kernel_size[0], kernel_size[1], Ly * Lx)) + + elif uf > 1 and df == 1: + fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride) + unfold = torch.nn.Unfold(**fold_params) + + fold_params2 = dict(kernel_size=(kernel_size[0] * uf, kernel_size[0] * uf), + dilation=1, padding=0, + stride=(stride[0] * uf, stride[1] * uf)) + fold = torch.nn.Fold(output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2) + + weighting = self.get_weighting(kernel_size[0] * uf, kernel_size[1] * uf, Ly, Lx, x.device).to(x.dtype) + normalization = fold(weighting).view(1, 1, h * uf, w * uf) # normalizes the overlap + weighting = weighting.view((1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx)) + + elif df > 1 and uf == 1: + fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride) + unfold = torch.nn.Unfold(**fold_params) + + fold_params2 = dict(kernel_size=(kernel_size[0] // df, kernel_size[0] // df), + dilation=1, padding=0, + stride=(stride[0] // df, stride[1] // df)) + fold = torch.nn.Fold(output_size=(x.shape[2] // df, x.shape[3] // df), **fold_params2) + + weighting = self.get_weighting(kernel_size[0] // df, kernel_size[1] // df, Ly, Lx, x.device).to(x.dtype) + normalization = fold(weighting).view(1, 1, h // df, w // df) # normalizes the overlap + weighting = weighting.view((1, 1, kernel_size[0] // df, kernel_size[1] // df, Ly * Lx)) + + else: + raise NotImplementedError + + return fold, unfold, normalization, weighting + + @torch.no_grad() + def get_input(self, batch, k, return_first_stage_outputs=False, force_c_encode=False, + cond_key=None, return_original_cond=False, bs=None,get_mask=False,get_reference=False,get_inpaint=False, get_clean_ref=False, get_ref_rec=False, + get_changed_pixels=False): + + x,inpaint,mask,reference, clean_reference, grid_transformed, changed_pixels = super().get_input(batch, k) + if bs is not None: + x = x[:bs] + inpaint = inpaint[:bs] + mask = mask[:bs] + reference = reference[:bs] + clean_reference = clean_reference[:bs] + grid_transformed = grid_transformed[:bs] + changed_pixels = changed_pixels[:bs] + x = x.to(self.device) + encoder_posterior = self.encode_first_stage(x) + z = self.get_first_stage_encoding(encoder_posterior).detach() + encoder_posterior_inpaint = self.encode_first_stage(inpaint) + z_inpaint = self.get_first_stage_encoding(encoder_posterior_inpaint).detach() + + encoder_posterior_inpaint = self.encode_first_stage(clean_reference) + z_reference = self.get_first_stage_encoding(encoder_posterior_inpaint).detach() + # breakpoint() + mask_resize = Resize([z.shape[-1],z.shape[-1]])(mask) + grid_resized = Resize([z.shape[-1],z.shape[-1]])(grid_transformed) + z_new = torch.cat((z,z_inpaint,mask_resize, grid_resized),dim=1) + # z_new = torch.cat((z,z_inpaint,mask_resize, changed_pixels, grid_resized),dim=1) + # z_new = torch.cat((z,z_inpaint,mask_resize, grid_resized),dim=1) + + if self.model.conditioning_key is not None: + if cond_key is None: + cond_key = self.cond_stage_key + if cond_key != self.first_stage_key: + if cond_key in ['txt','caption', 'coordinates_bbox']: + xc = batch[cond_key] + elif cond_key == 'image': + xc = reference + elif cond_key == 'class_label': + xc = batch + else: + xc = super().get_input(batch, cond_key).to(self.device) + else: + xc = x + if not self.cond_stage_trainable or force_c_encode: + if isinstance(xc, dict) or isinstance(xc, list): + # import pudb; pudb.set_trace() + c = self.get_learned_conditioning(xc) + else: + c = self.get_learned_conditioning(xc.to(self.device)) + c = self.proj_out(c) + c = c.float() + else: + c = xc + if bs is not None: + c = c[:bs] + + if self.use_positional_encodings: + pos_x, pos_y = self.compute_latent_shifts(batch) + ckey = __conditioning_keys__[self.model.conditioning_key] + c = {ckey: c, 'pos_x': pos_x, 'pos_y': pos_y} + + else: + c = None + xc = None + if self.use_positional_encodings: + pos_x, pos_y = self.compute_latent_shifts(batch) + c = {'pos_x': pos_x, 'pos_y': pos_y} + + # embed reference latent into cond + # c = [c, z_reference] + out = [z_new, c, z_reference] + if return_first_stage_outputs: + if self.first_stage_key=='inpaint': + xrec = self.decode_first_stage(z[:,:4,:,:]) + else: + xrec = self.decode_first_stage(z) + out.extend([x, xrec]) + if return_original_cond: + out.append(xc) + if get_mask: + out.append(mask) + if get_reference: + out.append(reference) + if get_inpaint: + out.append(inpaint) + if get_clean_ref: + out.append(clean_reference) + if get_ref_rec: + ref_rec = self.decode_first_stage(z_reference) + out.append(ref_rec) + if get_changed_pixels: + out.append(changed_pixels) + return out + + @torch.no_grad() + def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False): + if predict_cids: + if z.dim() == 4: + z = torch.argmax(z.exp(), dim=1).long() + z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None) + z = rearrange(z, 'b h w c -> b c h w').contiguous() + + z = 1. / self.scale_factor * z + + if hasattr(self, "split_input_params"): + if self.split_input_params["patch_distributed_vq"]: + ks = self.split_input_params["ks"] # eg. (128, 128) + stride = self.split_input_params["stride"] # eg. (64, 64) + uf = self.split_input_params["vqf"] + bs, nc, h, w = z.shape + if ks[0] > h or ks[1] > w: + ks = (min(ks[0], h), min(ks[1], w)) + print("reducing Kernel") + + if stride[0] > h or stride[1] > w: + stride = (min(stride[0], h), min(stride[1], w)) + print("reducing stride") + + fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf) + + z = unfold(z) # (bn, nc * prod(**ks), L) + # 1. Reshape to img shape + z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L ) + + # 2. apply model loop over last dim + if isinstance(self.first_stage_model, VQModelInterface): + output_list = [self.first_stage_model.decode(z[:, :, :, :, i], + force_not_quantize=predict_cids or force_not_quantize) + for i in range(z.shape[-1])] + else: + + output_list = [self.first_stage_model.decode(z[:, :, :, :, i]) + for i in range(z.shape[-1])] + + o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L) + o = o * weighting + # Reverse 1. reshape to img shape + o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L) + # stitch crops together + decoded = fold(o) + decoded = decoded / normalization # norm is shape (1, 1, h, w) + return decoded + else: + if isinstance(self.first_stage_model, VQModelInterface): + return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize) + else: + return self.first_stage_model.decode(z) + + else: + if isinstance(self.first_stage_model, VQModelInterface): + return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize) + else: + if self.first_stage_key=='inpaint': + return self.first_stage_model.decode(z[:,:4,:,:]) + else: + return self.first_stage_model.decode(z) + + # same as above but without decorator + def differentiable_decode_first_stage(self, z, predict_cids=False, force_not_quantize=False): + if predict_cids: + if z.dim() == 4: + z = torch.argmax(z.exp(), dim=1).long() + z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None) + z = rearrange(z, 'b h w c -> b c h w').contiguous() + + z = 1. / self.scale_factor * z + + if hasattr(self, "split_input_params"): + if self.split_input_params["patch_distributed_vq"]: + ks = self.split_input_params["ks"] # eg. (128, 128) + stride = self.split_input_params["stride"] # eg. (64, 64) + uf = self.split_input_params["vqf"] + bs, nc, h, w = z.shape + if ks[0] > h or ks[1] > w: + ks = (min(ks[0], h), min(ks[1], w)) + print("reducing Kernel") + + if stride[0] > h or stride[1] > w: + stride = (min(stride[0], h), min(stride[1], w)) + print("reducing stride") + + fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf) + + z = unfold(z) # (bn, nc * prod(**ks), L) + # 1. Reshape to img shape + z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L ) + + # 2. apply model loop over last dim + if isinstance(self.first_stage_model, VQModelInterface): + output_list = [self.first_stage_model.decode(z[:, :, :, :, i], + force_not_quantize=predict_cids or force_not_quantize) + for i in range(z.shape[-1])] + else: + + output_list = [self.first_stage_model.decode(z[:, :, :, :, i]) + for i in range(z.shape[-1])] + + o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L) + o = o * weighting + # Reverse 1. reshape to img shape + o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L) + # stitch crops together + decoded = fold(o) + decoded = decoded / normalization # norm is shape (1, 1, h, w) + return decoded + else: + if isinstance(self.first_stage_model, VQModelInterface): + return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize) + else: + return self.first_stage_model.decode(z) + + else: + if isinstance(self.first_stage_model, VQModelInterface): + return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize) + else: + return self.first_stage_model.decode(z) + + @torch.no_grad() + def encode_first_stage(self, x): + if hasattr(self, "split_input_params"): + if self.split_input_params["patch_distributed_vq"]: + ks = self.split_input_params["ks"] # eg. (128, 128) + stride = self.split_input_params["stride"] # eg. (64, 64) + df = self.split_input_params["vqf"] + self.split_input_params['original_image_size'] = x.shape[-2:] + bs, nc, h, w = x.shape + if ks[0] > h or ks[1] > w: + ks = (min(ks[0], h), min(ks[1], w)) + print("reducing Kernel") + + if stride[0] > h or stride[1] > w: + stride = (min(stride[0], h), min(stride[1], w)) + print("reducing stride") + + fold, unfold, normalization, weighting = self.get_fold_unfold(x, ks, stride, df=df) + z = unfold(x) # (bn, nc * prod(**ks), L) + # Reshape to img shape + z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L ) + + output_list = [self.first_stage_model.encode(z[:, :, :, :, i]) + for i in range(z.shape[-1])] + + o = torch.stack(output_list, axis=-1) + o = o * weighting + + # Reverse reshape to img shape + o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L) + # stitch crops together + decoded = fold(o) + decoded = decoded / normalization + return decoded + + else: + return self.first_stage_model.encode(x) + else: + return self.first_stage_model.encode(x) + + def shared_step(self, batch, **kwargs): + x, c, z_reference = self.get_input(batch, self.first_stage_key) + loss = self(x, c, z_reference) + return loss + + def forward(self, x, c, z_reference, *args, **kwargs): + t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long() + self.u_cond_prop=random.uniform(0, 1) + if self.model.conditioning_key is not None: + assert c is not None + if self.cond_stage_trainable: + c = self.get_learned_conditioning(c) + c = self.proj_out(c) + + if self.shorten_cond_schedule: # pr_odo: drop this option + tc = self.cond_ids[t].to(self.device) + c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float())) + + if self.u_cond_prop (l b) n') + # adapted_cond = self.get_learned_conditioning(adapted_cond) + # adapted_cond = rearrange(adapted_cond, '(l b) n d -> l b n d', l=z.shape[-1]) + + # cond_list = [{'c_crossattn': [e]} for e in adapted_cond] + + # else: + # cond_list = [cond for i in range(z.shape[-1])] # pr_odo make this more efficient + + # # apply model by loop over crops + # output_list = [self.model(z_list[i], t, **cond_list[i]) for i in range(z.shape[-1])] + # assert not isinstance(output_list[0], + # tuple) # pr_odo cant deal with multiple model outputs check this never happens + + # o = torch.stack(output_list, axis=-1) + # o = o * weighting + # # Reverse reshape to img shape + # o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L) + # # stitch crops together + # x_recon = fold(o) / normalization + + else: + # TODO address passing ref + zeroed_out_warped_latent = x_noisy.clone() + if self.remove_warped_latent: + zeroed_out_warped_latent[:,4:8] *= 0.0 + x_recon = self.model(zeroed_out_warped_latent, t, z_ref=z_ref, **cond) + + if isinstance(x_recon, tuple) and not return_ids: + return x_recon[0] + else: + return x_recon + + def _predict_eps_from_xstart(self, x_t, t, pred_xstart): + return (extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart) / \ + extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) + + def _prior_bpd(self, x_start): + """ + Get the prior KL term for the variational lower-bound, measured in + bits-per-dim. + This term can't be optimized, as it only depends on the encoder. + :param x_start: the [N x C x ...] tensor of inputs. + :return: a batch of [N] KL values (in bits), one per batch element. + """ + batch_size = x_start.shape[0] + t = torch.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device) + qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t) + kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0) + return mean_flat(kl_prior) / np.log(2.0) + + def p_losses(self, x_start, cond, t, z_ref, noise=None): + if self.first_stage_key == 'inpaint': + # x_start=x_start[:,:4,:,:] + latents = x_start[:,:4,:,:] + latents_warped = x_start[:,4:8,:,:] + noise = default(noise, lambda: torch.randn_like(x_start[:,:4,:,:])) + # offset noise + # noise += 0.05 * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device) + # TODO address the reference latent + # warped_mask = t > self.sd_edit_step + + x_noisy = self.q_sample(x_start=latents, t=t, noise=noise) + # warped_noisy = self.q_sample(x_start=latents_warped, t=t, noise=noise) + # x_noisy[warped_mask] = warped_noisy[warped_mask] + + # TODO add here + remove_latent_prob=random.uniform(0, 1) + + if remove_latent_prob < self.dropping_warped_latent_prob: + modified_x_start = x_start.clone() + # dropping warped latent and mask + modified_x_start[:, 4:9] *= 0.0 + + # print('using modified x start') + x_noisy = torch.cat((x_noisy,modified_x_start[:,4:,:,:]),dim=1) + else: + x_noisy = torch.cat((x_noisy,x_start[:,4:,:,:]),dim=1) + else: + noise = default(noise, lambda: torch.randn_like(x_start)) + x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) + model_output = self.apply_model(x_noisy, t, cond, z_ref) + + loss_dict = {} + prefix = 'train' if self.training else 'val' + + if self.parameterization == "x0": + target = x_start + elif self.parameterization == "eps": + target = noise + else: + raise NotImplementedError() + + loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3]) + loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()}) + + self.logvar = self.logvar.to(self.device) + logvar_t = self.logvar[t].to(self.device) + loss = loss_simple / torch.exp(logvar_t) + logvar_t + # loss = loss_simple / torch.exp(self.logvar) + self.logvar + if self.learn_logvar: + loss_dict.update({f'{prefix}/loss_gamma': loss.mean()}) + loss_dict.update({'logvar': self.logvar.data.mean()}) + + loss = self.l_simple_weight * loss.mean() + + loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3)) + loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean() + loss_dict.update({f'{prefix}/loss_vlb': loss_vlb}) + loss += (self.original_elbo_weight * loss_vlb) + loss_dict.update({f'{prefix}/loss': loss}) + + return loss, loss_dict + + def p_mean_variance(self, x, c, t, clip_denoised: bool, return_codebook_ids=False, quantize_denoised=False, + return_x0=False, score_corrector=None, corrector_kwargs=None, z_ref=None): + t_in = t + #TODO pass reference + model_out = self.apply_model(x, t_in, c, z_ref=z_ref, return_ids=return_codebook_ids) + + if score_corrector is not None: + assert self.parameterization == "eps" + model_out = score_corrector.modify_score(self, model_out, x, t, c, **corrector_kwargs) + + if return_codebook_ids: + model_out, logits = model_out + + if self.parameterization == "eps": + x_recon = self.predict_start_from_noise(x, t=t, noise=model_out) + elif self.parameterization == "x0": + x_recon = model_out + else: + raise NotImplementedError() + + if clip_denoised: + x_recon.clamp_(-1., 1.) + if quantize_denoised: + x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon) + model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t) + if return_codebook_ids: + return model_mean, posterior_variance, posterior_log_variance, logits + elif return_x0: + return model_mean, posterior_variance, posterior_log_variance, x_recon + else: + return model_mean, posterior_variance, posterior_log_variance + + @torch.no_grad() + def p_sample(self, x, c, t, z_ref=None, clip_denoised=False, repeat_noise=False, + return_codebook_ids=False, quantize_denoised=False, return_x0=False, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None): + b, *_, device = *x.shape, x.device + outputs = self.p_mean_variance(x=x, c=c, t=t, z_ref=z_ref, clip_denoised=clip_denoised, + return_codebook_ids=return_codebook_ids, + quantize_denoised=quantize_denoised, + return_x0=return_x0, + score_corrector=score_corrector, corrector_kwargs=corrector_kwargs) + if return_codebook_ids: + raise DeprecationWarning("Support dropped.") + model_mean, _, model_log_variance, logits = outputs + elif return_x0: + model_mean, _, model_log_variance, x0 = outputs + else: + model_mean, _, model_log_variance = outputs + + noise = noise_like(x.shape, device, repeat_noise) * temperature + if noise_dropout > 0.: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + # no noise when t == 0 + nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) + + if return_codebook_ids: + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, logits.argmax(dim=1) + if return_x0: + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, x0 + else: + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise + + @torch.no_grad() + def progressive_denoising(self, cond, shape, z_ref=None, verbose=True, callback=None, quantize_denoised=False, + img_callback=None, mask=None, x0=None, temperature=1., noise_dropout=0., + score_corrector=None, corrector_kwargs=None, batch_size=None, x_T=None, start_T=None, + log_every_t=None): + if not log_every_t: + log_every_t = self.log_every_t + timesteps = self.num_timesteps + if batch_size is not None: + b = batch_size if batch_size is not None else shape[0] + shape = [batch_size] + list(shape) + else: + b = batch_size = shape[0] + if x_T is None: + img = torch.randn(shape, device=self.device) + else: + img = x_T + intermediates = [] + if cond is not None: + if isinstance(cond, dict): + cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else + list(map(lambda x: x[:batch_size], cond[key])) for key in cond} + else: + cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size] + + if start_T is not None: + timesteps = min(timesteps, start_T) + iterator = tqdm(reversed(range(0, timesteps)), desc='Progressive Generation', + total=timesteps) if verbose else reversed( + range(0, timesteps)) + if type(temperature) == float: + temperature = [temperature] * timesteps + + for i in iterator: + ts = torch.full((b,), i, device=self.device, dtype=torch.long) + if self.shorten_cond_schedule: + assert self.model.conditioning_key != 'hybrid' + tc = self.cond_ids[ts].to(cond.device) + cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond)) + + img, x0_partial = self.p_sample(img, cond, ts, z_ref=z_ref, + clip_denoised=self.clip_denoised, + quantize_denoised=quantize_denoised, return_x0=True, + temperature=temperature[i], noise_dropout=noise_dropout, + score_corrector=score_corrector, corrector_kwargs=corrector_kwargs) + if mask is not None: + assert x0 is not None + img_orig = self.q_sample(x0, ts) + img = img_orig * mask + (1. - mask) * img + + if i % log_every_t == 0 or i == timesteps - 1: + intermediates.append(x0_partial) + if callback: callback(i) + if img_callback: img_callback(img, i) + return img, intermediates + + @torch.no_grad() + def p_sample_loop(self, cond, shape, z_ref=None, return_intermediates=False, + x_T=None, verbose=True, callback=None, timesteps=None, quantize_denoised=False, + mask=None, x0=None, img_callback=None, start_T=None, + log_every_t=None): + + if not log_every_t: + log_every_t = self.log_every_t + device = self.betas.device + b = shape[0] + if x_T is None: + img = torch.randn(shape, device=device) + else: + img = x_T + + intermediates = [img] + if timesteps is None: + timesteps = self.num_timesteps + + if start_T is not None: + timesteps = min(timesteps, start_T) + iterator = tqdm(reversed(range(0, timesteps)), desc='Sampling t', total=timesteps) if verbose else reversed( + range(0, timesteps)) + + if mask is not None: + assert x0 is not None + assert x0.shape[2:3] == mask.shape[2:3] # spatial size has to match + + for i in iterator: + ts = torch.full((b,), i, device=device, dtype=torch.long) + if self.shorten_cond_schedule: + assert self.model.conditioning_key != 'hybrid' + tc = self.cond_ids[ts].to(cond.device) + cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond)) + + img = self.p_sample(img, cond, ts, z_ref=z_ref, + clip_denoised=self.clip_denoised, + quantize_denoised=quantize_denoised) + if mask is not None: + img_orig = self.q_sample(x0, ts) + img = img_orig * mask + (1. - mask) * img + + if i % log_every_t == 0 or i == timesteps - 1: + intermediates.append(img) + if callback: callback(i) + if img_callback: img_callback(img, i) + + if return_intermediates: + return img, intermediates + return img + + @torch.no_grad() + def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None, + verbose=True, timesteps=None, quantize_denoised=False, + mask=None, x0=None, shape=None,**kwargs): + if shape is None: + shape = (batch_size, self.channels, self.image_size, self.image_size) + if cond is not None: + if isinstance(cond, dict): + cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else + list(map(lambda x: x[:batch_size], cond[key])) for key in cond} + else: + cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size] + return self.p_sample_loop(cond, + shape, + return_intermediates=return_intermediates, x_T=x_T, + verbose=verbose, timesteps=timesteps, quantize_denoised=quantize_denoised, + mask=mask, x0=x0) + + @torch.no_grad() + def sample_log(self,cond,batch_size,ddim, ddim_steps, z_ref=None, full_z=None,**kwargs): + + if ddim: + ddim_sampler = DDIMSampler(self) + shape = (self.channels, self.image_size, self.image_size) + z_inpaint = full_z[:,4:8] + step=1 + + + samples, intermediates =ddim_sampler.sample(ddim_steps,batch_size, + shape,cond, z_ref=z_ref,verbose=False, x0=z_inpaint, + x0_step=step,**kwargs) + + else: + samples, intermediates = self.sample(cond=cond, batch_size=batch_size, + return_intermediates=True,**kwargs) + + return samples, intermediates + + + @torch.no_grad() + def log_images(self, batch, N=4, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None, + quantize_denoised=True, inpaint=False, plot_denoise_rows=False, plot_progressive_rows=False, + plot_diffusion_rows=True, **kwargs): + + use_ddim = ddim_steps is not None + + log = dict() + + z, c, z_ref, x, xrec, xc, mask, reference, inpaint_img, clean_ref, ref_rec, changed_pixels = self.get_input(batch, self.first_stage_key, + return_first_stage_outputs=True, + force_c_encode=True, + return_original_cond=True, + get_mask=True, + get_reference=True, + get_inpaint=True, + bs=N, + get_clean_ref=True, + get_ref_rec=True, + get_changed_pixels=True) + + N = min(x.shape[0], N) + n_row = min(x.shape[0], n_row) + log["inputs"] = x + log["reconstruction"] = xrec + log["mask"]=mask + log['changed_pixels'] = changed_pixels + log["warped"]=inpaint_img + log["original"] = clean_ref + log["ref_rec"] = ref_rec + # log["reference"]=reference + if self.model.conditioning_key is not None: + if hasattr(self.cond_stage_model, "decode"): + xc = self.cond_stage_model.decode(c) + log["conditioning"] = xc + elif self.cond_stage_key in ["caption","txt"]: + xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key]) + log["conditioning"] = xc + elif self.cond_stage_key == 'class_label': + xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"]) + log['conditioning'] = xc + elif isimage(xc): + log["conditioning"] = xc + if ismap(xc): + log["original_conditioning"] = self.to_rgb(xc) + + if plot_diffusion_rows: + # get diffusion row + diffusion_row = list() + z_start = z[:n_row] + for t in range(self.num_timesteps): + if t % self.log_every_t == 0 or t == self.num_timesteps - 1: + t = repeat(torch.tensor([t]), '1 -> b', b=n_row) + t = t.to(self.device).long() + noise = torch.randn_like(z_start) + z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise) + diffusion_row.append(self.decode_first_stage(z_noisy)) + + diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W + diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w') + diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w') + diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0]) + log["diffusion_row"] = diffusion_grid + + if sample: + # get denoise row + with self.ema_scope("Plotting"): + if self.first_stage_key=='inpaint': + samples, z_denoise_row = self.sample_log(cond=c, z_ref=z_ref,batch_size=N,ddim=use_ddim, full_z=z, + ddim_steps=ddim_steps,eta=ddim_eta,rest=z[:,4:,:,:]) + else: + samples, z_denoise_row = self.sample_log(cond=c, z_ref=z_ref,batch_size=N,ddim=use_ddim, + ddim_steps=ddim_steps,eta=ddim_eta) + # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True) + x_samples = self.decode_first_stage(samples) + log["samples"] = x_samples + if plot_denoise_rows: + denoise_grid = self._get_denoise_row_from_list(z_denoise_row) + log["denoise_row"] = denoise_grid + + if quantize_denoised and not isinstance(self.first_stage_model, AutoencoderKL) and not isinstance( + self.first_stage_model, IdentityFirstStage): + # also display when quantizing x0 while sampling + with self.ema_scope("Plotting Quantized Denoised"): + samples, z_denoise_row = self.sample_log(cond=c, z_ref=z_ref, batch_size=N,ddim=use_ddim, + ddim_steps=ddim_steps,eta=ddim_eta, + quantize_denoised=True) + # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True, + # quantize_denoised=True) + x_samples = self.decode_first_stage(samples.to(self.device)) + log["samples_x0_quantized"] = x_samples + + if inpaint: + # make a simple center square + b, h, w = z.shape[0], z.shape[2], z.shape[3] + mask = torch.ones(N, h, w).to(self.device) + # zeros will be filled in + mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0. + mask = mask[:, None, ...] + with self.ema_scope("Plotting Inpaint"): + + samples, _ = self.sample_log(cond=c, z_ref=z_ref,batch_size=N,ddim=use_ddim, eta=ddim_eta, + ddim_steps=ddim_steps, x0=z[:N,:4], mask=mask) + x_samples = self.decode_first_stage(samples.to(self.device)) + log["samples_inpainting"] = x_samples + log["mask"] = mask + + # outpaint + with self.ema_scope("Plotting Outpaint"): + samples, _ = self.sample_log(cond=c, z_ref=z_ref, batch_size=N, ddim=use_ddim,eta=ddim_eta, + ddim_steps=ddim_steps, x0=z[:N], mask=mask) + x_samples = self.decode_first_stage(samples.to(self.device)) + log["samples_outpainting"] = x_samples + + if plot_progressive_rows: + with self.ema_scope("Plotting Progressives"): + img, progressives = self.progressive_denoising(c, + z_ref=z_ref, + shape=(self.channels, self.image_size, self.image_size), + batch_size=N) + prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation") + log["progressive_row"] = prog_row + + if return_keys: + if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0: + return log + else: + return {key: log[key] for key in return_keys} + return log + + def configure_optimizers(self): + lr = self.learning_rate + params = list(self.model.parameters()) + + + + if self.cond_stage_trainable: + print(f"{self.__class__.__name__}: Also optimizing conditioner params!") + # need to add final_ln.parameters() TODO + params = params + list(self.cond_stage_model.final_ln.parameters())+list(self.cond_stage_model.mapper.parameters())+list(self.proj_out.parameters()) + if self.learn_logvar: + print('Diffusion model optimizing logvar') + params.append(self.logvar) + params.append(self.learnable_vector) + opt = torch.optim.AdamW(params, lr=lr) + if self.use_scheduler: + assert 'target' in self.scheduler_config + scheduler = instantiate_from_config(self.scheduler_config) + + print("Setting up LambdaLR scheduler...") + scheduler = [ + { + 'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule), + 'interval': 'step', + 'frequency': 1 + }] + return [opt], scheduler + return opt + + @torch.no_grad() + def to_rgb(self, x): + x = x.float() + if not hasattr(self, "colorize"): + self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x) + x = nn.functional.conv2d(x, weight=self.colorize) + x = 2. * (x - x.min()) / (x.max() - x.min()) - 1. + return x + + +class DiffusionWrapper(pl.LightningModule): + def __init__(self, diff_model_config, conditioning_key, sqrt_alphas_cumprod=None, sqrt_one_minus_alphas_cumprod=None, ddpm_parent=None): + super().__init__() + self.diffusion_model = instantiate_from_config(diff_model_config) + self.conditioning_key = conditioning_key + assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm', 'crossref', 'rewarp', 'rewarp_grid'] + # self.save_folder = '/mnt/localssd/collage_latents_lovely_new_data' + # self.save_counter = 0 + # self.save_subfolder = None + + # os.makedirs(self.save_folder, exist_ok=True) + self.sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod + self.sqrt_alphas_cumprod = sqrt_alphas_cumprod + self.og_grid = None + self.transformed_grid = None + if self.conditioning_key == 'crossref' or 'rewarp' in self.conditioning_key: + self.reference_model = copy.deepcopy(self.diffusion_model) + + + def get_grid(self, size, batch_size): + # raise ValueError TODO Fix + y = np.repeat(np.arange(size)[None, ...], size) + y = y.reshape(size, size) + x = y.transpose() + out = np.stack([y,x], 0) + out = torch.tensor(out) + out = out.unsqueeze(0) + out = out.repeat(batch_size, 1, 1, 1) + return out + + def compute_correspondences(self, grid_transformed, masks, original_size=512, add_grids=False): + # create the correspondence map for all the needed sizes + corresp_indices = {} + batch_size = grid_transformed.shape[0] + + if self.og_grid is None: + grid_og = self.get_grid(original_size, batch_size).to(grid_transformed.device) / float(original_size) + else: + grid_og = self.og_grid + + + for d in [8, 16, 32, 64]: + resized_grid_1 = torchvision.transforms.functional.resize(grid_og, size=(d,d)) + resized_grid_2 = torchvision.transforms.functional.resize(grid_transformed, size=(d,d)) + # the mask is at 64x64. 1 means exist in image. 0 is missing (needs inpainting) + resized_mask = torchvision.transforms.functional.resize(masks, size=(d,d)) + + missing_mask = resized_mask.squeeze(1) < 0.7 #torch.sum(resized_grid_2, dim=1) < 0.1 + + src_grid = resized_grid_1.permute(0,2,3,1) # B x 2 x d x d + guide_grid = resized_grid_2.permute(0,2,3,1) + + src1_flat = src_grid.reshape(batch_size, d**2, 2) + src2_flat = guide_grid.reshape(batch_size, d**2, 2) + missing_flat = missing_mask.reshape(batch_size, d**2) + + torch_dist = torch.cdist(src2_flat.float(), src1_flat.float()) + # print('torch dist shape for d', d, torch_dist.shape) + + # missing_masks[d] = missing_flat + min_indices = torch.argmin(torch_dist, dim=-1) + # min_indices.requires_grad = False + # missing_flat.requires_grad = False + if add_grids: + corresp_indices[d] = (min_indices, missing_flat, resized_grid_1, resized_grid_2) + else: + corresp_indices[d] = (min_indices, missing_flat) + return corresp_indices #, missing_masks + + def q_sample(self, x_start, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + sqrt_alphas_cumprod = self.sqrt_alphas_cumprod.to(x_start.device) + sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod.to(x_start.device) + return (extract_into_tensor(sqrt_alphas_cumprod, t, x_start.shape) * x_start + + extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise) + + def forward(self, x, t, c_concat: list = None, c_crossattn: list = None, z_ref = None): + num_ch = x.shape[1] + # print(num_ch) + if num_ch >= 11: + self.transformed_grid = x[:, -2:] + x = x[:, :-2] + # else: + # grid_transformed = None + + if self.conditioning_key is None: + out = self.diffusion_model(x, t) + elif self.conditioning_key == 'concat': + xc = torch.cat([x] + c_concat, dim=1) + out = self.diffusion_model(xc, t) + elif self.conditioning_key == 'crossattn': + cc = torch.cat(c_crossattn, 1) + out = self.diffusion_model(x, t, context=cc) + + # self.save_subfolder = f'{self.save_folder}/saved_{time.time()}' + # os.makedirs(self.save_subfolder, exist_ok=True) + # # just for saving purposes + # assert z_ref is not None + # noisy_z_ref = self.q_sample(z_ref, t) + # # z_new = torch.cat((z,z_inpaint,mask_resize),dim=1) + + # mask = x[:, -1:] + # z_ref_concat = torch.cat([noisy_z_ref, z_ref, mask], dim=1) + + # correspondeces = self.compute_correspondences(self.transformed_grid, mask, original_size=512, add_grids=True) + + # if self.save_counter < 50: + # torch.save(x.cpu(), f'{self.save_subfolder}/z_collage_concat.pt' ) + # torch.save(z_ref_concat.cpu(), f'{self.save_subfolder}/z_ref_concat.pt') + # torch.save(correspondeces, f'{self.save_subfolder}/corresps.pt') + # self.save_counter += 1 + + + elif self.conditioning_key == 'hybrid': + xc = torch.cat([x] + c_concat, dim=1) + cc = torch.cat(c_crossattn, 1) + out = self.diffusion_model(xc, t, context=cc) + elif self.conditioning_key == 'adm': + cc = c_crossattn[0] + out = self.diffusion_model(x, t, y=cc) + # elif self.conditioning_key == 'crossref': + # cc = torch.cat(c_crossattn, 1) + # # qsample z_ref by t to add noise + # # so have noisy z_ref + z_ref + mask + # # compute contexts + # assert z_ref is not None + # noisy_z_ref = self.q_sample(z_ref, t) + # # z_new = torch.cat((z,z_inpaint,mask_resize),dim=1) + # mask = x[:, -1:] + # z_ref_concat = torch.cat([noisy_z_ref, z_ref, mask], dim=1) + + + # # compute contexts + # _, contexts = self.reference_model(z_ref_concat, t, context=cc, get_contexts=True) + + # # input diffusion model with contexts + # out = self.diffusion_model(x, t, context=cc, passed_contexts=contexts) + + elif self.conditioning_key == 'rewarp' or self.conditioning_key == 'crossref': # also include the crossref for now + cc = torch.cat(c_crossattn, 1) + # qsample z_ref by t to add noise + # so have noisy z_ref + z_ref + mask + # compute contexts + if self.conditioning_key == 'crossref': + raise ValueError('currently not implemented properly. please fix attention') + assert z_ref is not None + noisy_z_ref = self.q_sample(z_ref, t) + # z_new = torch.cat((z,z_inpaint,mask_resize),dim=1) + + # mask = x[:, -2:-1] # mask and new regions + # changed_pixels = x[:, -1:] + # z_ref_concat = torch.cat([noisy_z_ref, z_ref, mask, changed_pixels], dim=1) + mask = x[:, -1:] # mask and new regions + z_ref_concat = torch.cat([noisy_z_ref, z_ref, mask], dim=1) + + + init_corresp_time = time.time() + correspondeces = self.compute_correspondences(self.transformed_grid, mask, original_size=512) ## TODO make input dependent + final_corresp_time = time.time() + + # compute contexts + _, contexts = self.reference_model(z_ref_concat, t, context=cc, get_contexts=True) + # input diffusion model with contexts + out = self.diffusion_model(x, t, context=cc, passed_contexts=contexts, corresp=correspondeces) + + elif self.conditioning_key == 'rewarp_grid': + grid_og = self.get_grid(64, batch_size=x.shape[0]).to(x.device) / 64.0 + cc = torch.cat(c_crossattn, 1) + # qsample z_ref by t to add noise + # so have noisy z_ref + z_ref + mask + # compute contexts + + assert z_ref is not None + noisy_z_ref = self.q_sample(z_ref, t) + # z_new = torch.cat((z,z_inpaint,mask_resize),dim=1) + + # mask = x[:, -2:-1] # mask and new regions + # changed_pixels = x[:, -1:] + # z_ref_concat = torch.cat([noisy_z_ref, z_ref, mask, changed_pixels], dim=1) + mask = x[:, -1:] # mask and new regions + z_ref_concat = torch.cat([noisy_z_ref, z_ref, mask, grid_og], dim=1) + x = torch.cat([x, grid_og], dim=1) + + correspondeces = self.compute_correspondences(self.transformed_grid, mask, original_size=512) ## TODO make input dependent + + # compute contexts + _, contexts = self.reference_model(z_ref_concat, t, context=cc, get_contexts=True) + # input diffusion model with contexts + out = self.diffusion_model(x, t, context=cc, passed_contexts=contexts, corresp=correspondeces) + + else: + raise NotImplementedError() + + return out + + +class Layout2ImgDiffusion(LatentDiffusion): + # pr_odo: move all layout-specific hacks to this class + def __init__(self, cond_stage_key, *args, **kwargs): + assert cond_stage_key == 'coordinates_bbox', 'Layout2ImgDiffusion only for cond_stage_key="coordinates_bbox"' + super().__init__(cond_stage_key=cond_stage_key, *args, **kwargs) + + def log_images(self, batch, N=8, *args, **kwargs): + logs = super().log_images(batch=batch, N=N, *args, **kwargs) + + key = 'train' if self.training else 'validation' + dset = self.trainer.datamodule.datasets[key] + mapper = dset.conditional_builders[self.cond_stage_key] + + bbox_imgs = [] + map_fn = lambda catno: dset.get_textual_label(dset.get_category_id(catno)) + for tknzd_bbox in batch[self.cond_stage_key][:N]: + bboximg = mapper.plot(tknzd_bbox.detach().cpu(), map_fn, (256, 256)) + bbox_imgs.append(bboximg) + + cond_img = torch.stack(bbox_imgs, dim=0) + logs['bbox_image'] = cond_img + return logs + +class LatentInpaintDiffusion(LatentDiffusion): + def __init__( + self, + concat_keys=("mask", "masked_image"), + masked_image_key="masked_image", + finetune_keys=None, + *args, + **kwargs, + ): + super().__init__(*args, **kwargs) + self.masked_image_key = masked_image_key + assert self.masked_image_key in concat_keys + self.concat_keys = concat_keys + + + @torch.no_grad() + def get_input( + self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False + ): + # note: restricted to non-trainable encoders currently + assert ( + not self.cond_stage_trainable + ), "trainable cond stages not yet supported for inpainting" + z, c, x, xrec, xc = super().get_input( + batch, + self.first_stage_key, + return_first_stage_outputs=True, + force_c_encode=True, + return_original_cond=True, + bs=bs, + ) + + assert exists(self.concat_keys) + c_cat = list() + for ck in self.concat_keys: + cc = ( + rearrange(batch[ck], "b h w c -> b c h w") + .to(memory_format=torch.contiguous_format) + .float() + ) + if bs is not None: + cc = cc[:bs] + cc = cc.to(self.device) + bchw = z.shape + if ck != self.masked_image_key: + cc = torch.nn.functional.interpolate(cc, size=bchw[-2:]) + else: + cc = self.get_first_stage_encoding(self.encode_first_stage(cc)) + c_cat.append(cc) + c_cat = torch.cat(c_cat, dim=1) + all_conds = {"c_concat": [c_cat], "c_crossattn": [c]} + if return_first_stage_outputs: + return z, all_conds, x, xrec, xc + return z, all_conds diff --git a/ldm/models/diffusion/plms.py b/ldm/models/diffusion/plms.py new file mode 100644 index 0000000000000000000000000000000000000000..739d396939cacbf50eace13d312368614fba4e71 --- /dev/null +++ b/ldm/models/diffusion/plms.py @@ -0,0 +1,251 @@ +# This code is built from the Stable Diffusion repository: https://github.com/CompVis/stable-diffusion, and +# Paint-by-Example repo https://github.com/Fantasy-Studio/Paint-by-Example +# Copyright (c) 2022 Robin Rombach and Patrick Esser and contributors. +# CreativeML Open RAIL-M +# +# ========================================================================================== +# +# Adobe’s modifications are Copyright 2024 Adobe Research. All rights reserved. +# Adobe’s modifications are licensed under the Adobe Research License. To view a copy of the license, visit +# LICENSE.md. +# +# ========================================================================================== + +"""SAMPLING ONLY.""" + +import torch +import numpy as np +from tqdm import tqdm +from functools import partial + +from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like + + +class PLMSSampler(object): + def __init__(self, model, schedule="linear", **kwargs): + super().__init__() + self.model = model + self.ddpm_num_timesteps = model.num_timesteps + self.schedule = schedule + + def register_buffer(self, name, attr): + if type(attr) == torch.Tensor: + if attr.device != torch.device("cuda"): + attr = attr.to(torch.device("cuda")) + setattr(self, name, attr) + + def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): + if ddim_eta != 0: + raise ValueError('ddim_eta must be 0 for PLMS') + self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, + num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose) + alphas_cumprod = self.model.alphas_cumprod + assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' + to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) + + self.register_buffer('betas', to_torch(self.model.betas)) + self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) + self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) + self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) + self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu()))) + self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) + self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1))) + + # ddim sampling parameters + ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), + ddim_timesteps=self.ddim_timesteps, + eta=ddim_eta,verbose=verbose) + self.register_buffer('ddim_sigmas', ddim_sigmas) + self.register_buffer('ddim_alphas', ddim_alphas) + self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) + self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas)) + sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( + (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( + 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) + self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) + + @torch.no_grad() + def sample(self, + S, + batch_size, + shape, + conditioning=None, + callback=None, + normals_sequence=None, + img_callback=None, + quantize_x0=False, + eta=0., + mask=None, + x0=None, + temperature=1., + noise_dropout=0., + score_corrector=None, + corrector_kwargs=None, + verbose=True, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1., + unconditional_conditioning=None, + # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... + **kwargs + ): + if conditioning is not None: + if isinstance(conditioning, dict): + cbs = conditioning[list(conditioning.keys())[0]].shape[0] + if cbs != batch_size: + print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") + else: + if conditioning.shape[0] != batch_size: + print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") + + self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) + # sampling + C, H, W = shape + size = (batch_size, C, H, W) + print(f'Data shape for PLMS sampling is {size}') + + samples, intermediates = self.plms_sampling(conditioning, size, + callback=callback, + img_callback=img_callback, + quantize_denoised=quantize_x0, + mask=mask, x0=x0, + ddim_use_original_steps=False, + noise_dropout=noise_dropout, + temperature=temperature, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + x_T=x_T, + log_every_t=log_every_t, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + **kwargs + ) + return samples, intermediates + + @torch.no_grad() + def plms_sampling(self, cond, shape, + x_T=None, ddim_use_original_steps=False, + callback=None, timesteps=None, quantize_denoised=False, + mask=None, x0=None, img_callback=None, log_every_t=100, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None,**kwargs): + device = self.model.betas.device + b = shape[0] + if x_T is None: + img = torch.randn(shape, device=device) + else: + img = x_T + + if timesteps is None: + timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps + elif timesteps is not None and not ddim_use_original_steps: + subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 + timesteps = self.ddim_timesteps[:subset_end] + + intermediates = {'x_inter': [img], 'pred_x0': [img]} + time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps) + total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] + print(f"Running PLMS Sampling with {total_steps} timesteps") + + iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps) + old_eps = [] + + for i, step in enumerate(iterator): + index = total_steps - i - 1 + ts = torch.full((b,), step, device=device, dtype=torch.long) + ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long) + + if mask is not None: + assert x0 is not None + img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? + img = img_orig * mask + (1. - mask) * img + + outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, + quantize_denoised=quantize_denoised, temperature=temperature, + noise_dropout=noise_dropout, score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + old_eps=old_eps, t_next=ts_next,**kwargs) + img, pred_x0, e_t = outs + old_eps.append(e_t) + if len(old_eps) >= 4: + old_eps.pop(0) + if callback: callback(i) + if img_callback: img_callback(pred_x0, i) + + if index % log_every_t == 0 or index == total_steps - 1: + intermediates['x_inter'].append(img) + intermediates['pred_x0'].append(pred_x0) + + return img, intermediates + + @torch.no_grad() + def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None,**kwargs): + b, *_, device = *x.shape, x.device + def get_model_output(x, t): + if unconditional_conditioning is None or unconditional_guidance_scale == 1.: + e_t = self.model.apply_model(x, t, c) + else: + x_in = torch.cat([x] * 2) + t_in = torch.cat([t] * 2) + c_in = torch.cat([unconditional_conditioning, c]) + e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) + e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) + + if score_corrector is not None: + assert self.model.parameterization == "eps" + e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) + + return e_t + + alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas + alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev + sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas + sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas + + def get_x_prev_and_pred_x0(e_t, index): + # select parameters corresponding to the currently considered timestep + a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) + a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) + sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) + sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) + + # current prediction for x_0 + pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() + if quantize_denoised: + pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) + # direction pointing to x_t + dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t + noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature + if noise_dropout > 0.: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise + return x_prev, pred_x0 + kwargs=kwargs['test_model_kwargs'] + x_new=torch.cat([x,kwargs['inpaint_image'],kwargs['inpaint_mask']],dim=1) + e_t = get_model_output(x_new, t) + if len(old_eps) == 0: + # Pseudo Improved Euler (2nd order) + x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index) + x_prev_new=torch.cat([x_prev,kwargs['inpaint_image'],kwargs['inpaint_mask']],dim=1) + e_t_next = get_model_output(x_prev_new, t_next) + e_t_prime = (e_t + e_t_next) / 2 + elif len(old_eps) == 1: + # 2nd order Pseudo Linear Multistep (Adams-Bashforth) + e_t_prime = (3 * e_t - old_eps[-1]) / 2 + elif len(old_eps) == 2: + # 3nd order Pseudo Linear Multistep (Adams-Bashforth) + e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12 + elif len(old_eps) >= 3: + # 4nd order Pseudo Linear Multistep (Adams-Bashforth) + e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24 + + x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index) + + return x_prev, pred_x0, e_t diff --git a/ldm/modules/attention.py b/ldm/modules/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..d9064b61ec11aeaad55bf08f22af49d72167a465 --- /dev/null +++ b/ldm/modules/attention.py @@ -0,0 +1,372 @@ +# This code is built from the Stable Diffusion repository: https://github.com/CompVis/stable-diffusion, and +# Paint-by-Example repo https://github.com/Fantasy-Studio/Paint-by-Example +# Copyright (c) 2022 Robin Rombach and Patrick Esser and contributors. +# CreativeML Open RAIL-M +# +# ========================================================================================== +# +# Adobe’s modifications are Copyright 2024 Adobe Research. All rights reserved. +# Adobe’s modifications are licensed under the Adobe Research License. To view a copy of the license, visit +# LICENSE.md. +# +# ========================================================================================== + +from inspect import isfunction +import math +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn, einsum +from einops import rearrange, repeat +import glob + +from ldm.modules.diffusionmodules.util import checkpoint + + +def exists(val): + return val is not None + + +def uniq(arr): + return{el: True for el in arr}.keys() + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def max_neg_value(t): + return -torch.finfo(t.dtype).max + + +def init_(tensor): + dim = tensor.shape[-1] + std = 1 / math.sqrt(dim) + tensor.uniform_(-std, std) + return tensor + + +# feedforward +class GEGLU(nn.Module): + def __init__(self, dim_in, dim_out): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2) + + def forward(self, x): + x, gate = self.proj(x).chunk(2, dim=-1) + return x * F.gelu(gate) + + +class FeedForward(nn.Module): + def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): + super().__init__() + inner_dim = int(dim * mult) + dim_out = default(dim_out, dim) + project_in = nn.Sequential( + nn.Linear(dim, inner_dim), + nn.GELU() + ) if not glu else GEGLU(dim, inner_dim) + + self.net = nn.Sequential( + project_in, + nn.Dropout(dropout), + nn.Linear(inner_dim, dim_out) + ) + + def forward(self, x): + return self.net(x) + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def Normalize(in_channels): + return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + + +class LinearAttention(nn.Module): + def __init__(self, dim, heads=4, dim_head=32): + super().__init__() + self.heads = heads + hidden_dim = dim_head * heads + self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False) + self.to_out = nn.Conv2d(hidden_dim, dim, 1) + + def forward(self, x): + b, c, h, w = x.shape + qkv = self.to_qkv(x) + q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3) + k = k.softmax(dim=-1) + context = torch.einsum('bhdn,bhen->bhde', k, v) + out = torch.einsum('bhde,bhdn->bhen', context, q) + out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w) + return self.to_out(out) + + +class SpatialSelfAttention(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.k = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.v = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b,c,h,w = q.shape + q = rearrange(q, 'b c h w -> b (h w) c') + k = rearrange(k, 'b c h w -> b c (h w)') + w_ = torch.einsum('bij,bjk->bik', q, k) + + w_ = w_ * (int(c)**(-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = rearrange(v, 'b c h w -> b c (h w)') + w_ = rearrange(w_, 'b i j -> b j i') + h_ = torch.einsum('bij,bjk->bik', v, w_) + h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h) + h_ = self.proj_out(h_) + + return x+h_ + + +class CrossAttention(nn.Module): + def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., only_crossref=False): + super().__init__() + inner_dim = dim_head * heads + # forcing attention to only attend on vectors of same size + # breaking the image2text attention + context_dim = default(context_dim, query_dim) + + # print('creating cross attention. Query dim', query_dim, ' context dim', context_dim) + + self.scale = dim_head ** -0.5 + self.heads = heads + + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_k = nn.Linear(context_dim, inner_dim, bias=False) + self.to_v = nn.Linear(context_dim, inner_dim, bias=False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, query_dim), + nn.Dropout(dropout) + ) + + self.only_crossref = only_crossref + if only_crossref: + self.merge_attentions = zero_module(nn.Conv2d(self.heads * 2, + self.heads, + kernel_size=1, + stride=1, + padding=0)) + else: + self.merge_attentions = zero_module(nn.Conv2d(self.heads * 3, + self.heads, + kernel_size=1, + stride=1, + padding=0)) + + + self.merge_attentions_missing = zero_module(nn.Conv2d(self.heads * 2, + self.heads, + kernel_size=1, + stride=1, + padding=0)) + + + def forward(self, x, context=None, mask=None, passed_qkv=None, masks=None, corresp=None, missing_region=None): + is_self_attention = context is None + + # if masks is not None: + # print(is_self_attention, masks.keys()) + + h = self.heads + + # if passed_qkv is not None: + # assert context is None + + # _,_,_,_, x_features = passed_qkv + # assert x_features is not None + + # # print('x shape', x.shape, 'x features', x_features.shape) + # # breakpoint() + # x = torch.concat([x, x_features], dim=1) + + q = self.to_q(x) + context = default(context, x) + k = self.to_k(context) + v = self.to_v(context) + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) + + sim = einsum('b i d, b j d -> b i j', q, k) * self.scale + + if exists(mask): + assert False + mask = rearrange(mask, 'b ... -> b (...)') + max_neg_value = -torch.finfo(sim.dtype).max + mask = repeat(mask, 'b j -> (b h) () j', h=h) + sim.masked_fill_(~mask, max_neg_value) + + # attention, what we cannot get enough of + attn = sim.softmax(dim=-1) + out = einsum('b i j, b j d -> b i d', attn, v) + inter_out = rearrange(out, '(b h) n d -> b h n d', h=h) + + combined_attention = inter_out + out = rearrange(combined_attention, 'b h n d -> b n (h d)', h=h) + + final_out = self.to_out(out) + + if is_self_attention: + return final_out, q, k, v, inter_out #TODO add attn out + else: + return final_out + + +class BasicTransformerBlock(nn.Module): + def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True): + super().__init__() + self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout) # is a self-attention + self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) + self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim, + heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none + self.attn3 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout) + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + self.norm3 = nn.LayerNorm(dim) + self.checkpoint = checkpoint + + # TODO add attn in + def forward(self, x, context=None, passed_qkv=None, masks=None, corresp=None): + if passed_qkv is None: + return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint) + else: + q, k, v, attn, x_features = passed_qkv + d = int(np.sqrt(q.shape[1])) + current_mask = masks[d] + if corresp: + current_corresp, missing_region = corresp[d] + current_corresp = current_corresp.float() + missing_region = missing_region.float() + else: + raise ValueError('cannot have empty corresp') + current_corresp = None + missing_region = current_mask.float() + # breakpoint() + stuff = [q, k, v, attn, x_features, current_mask, current_corresp, missing_region] + for element in stuff: + assert element is not None + return checkpoint(self._forward, (x, context, q, k, v, attn, x_features, current_mask, current_corresp, missing_region), self.parameters(), self.checkpoint) + + # TODO add attn in + def _forward(self, x, context=None, q=None, k=None, v=None, attn=None, passed_x=None, masks=None, corresp=None, missing_region=None): + if q is not None: + passed_qkv = (q, k, v, attn, passed_x) + else: + passed_qkv = None + x_features = self.norm1(x) + attended_x, q, k, v, attn = self.attn1(x_features, passed_qkv=passed_qkv, masks=masks, corresp=corresp, missing_region=missing_region) + x = attended_x + x + # killing CLIP features + + if passed_x is not None: + normed_x = self.norm2(x) + attn_out = self.attn3(normed_x, context=passed_x) + x = attn_out + x + # then use y + x + # print('y shape', y.shape, ' x shape', x.shape) + + x = self.ff(self.norm3(x)) + x + return x, q, k, v, attn, x_features + + +class SpatialTransformer(nn.Module): + """ + Transformer block for image-like data. + First, project the input (aka embedding) + and reshape to b, t, d. + Then apply standard transformer action. + Finally, reshape to image + """ + def __init__(self, in_channels, n_heads, d_head, + depth=1, dropout=0., context_dim=None): + super().__init__() + self.in_channels = in_channels + inner_dim = n_heads * d_head + self.norm = Normalize(in_channels) + + # print('creating spatial transformer') + # print('in channels', in_channels, 'inner dim', inner_dim) + + self.proj_in = nn.Conv2d(in_channels, + inner_dim, + kernel_size=1, + stride=1, + padding=0) + + self.transformer_blocks = nn.ModuleList( + [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim) + for d in range(depth)] + ) + + self.proj_out = zero_module(nn.Conv2d(inner_dim, + in_channels, + kernel_size=1, + stride=1, + padding=0)) + + # TODO add attn in and corresp + def forward(self, x, context=None, passed_qkv=None, masks=None, corresp=None): + # note: if no context is given, cross-attention defaults to self-attention + b, c, h, w = x.shape + # print('spatial transformer x shape given', x.shape) + # if context is not None: + # print('also context was provided with shape ', context.shape) + x_in = x + x = self.norm(x) + x = self.proj_in(x) + x = rearrange(x, 'b c h w -> b (h w) c') + + qkvs = [] + for block in self.transformer_blocks: + x, q, k, v, attn, x_features = block(x, context=context, passed_qkv=passed_qkv, masks=masks, corresp=corresp) + qkv = (q,k,v,attn, x_features) + qkvs.append(qkv) + x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w) + x = self.proj_out(x) + return x + x_in, qkvs \ No newline at end of file diff --git a/ldm/modules/diffusionmodules/__init__.py b/ldm/modules/diffusionmodules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ldm/modules/diffusionmodules/model.py b/ldm/modules/diffusionmodules/model.py new file mode 100644 index 0000000000000000000000000000000000000000..eb31f57a9e802b85d41be7b22693c3cf80a12dbf --- /dev/null +++ b/ldm/modules/diffusionmodules/model.py @@ -0,0 +1,848 @@ +# This code is built from the Stable Diffusion repository: https://github.com/CompVis/stable-diffusion, and +# Paint-by-Example repo https://github.com/Fantasy-Studio/Paint-by-Example +# Copyright (c) 2022 Robin Rombach and Patrick Esser and contributors. +# CreativeML Open RAIL-M +# +# ========================================================================================== +# +# Adobe’s modifications are Copyright 2024 Adobe Research. All rights reserved. +# Adobe’s modifications are licensed under the Adobe Research License. To view a copy of the license, visit +# LICENSE.md. +# +# ========================================================================================== + +# pytorch_diffusion + derived encoder decoder +import math +import torch +import torch.nn as nn +import numpy as np +from einops import rearrange + +from ldm.util import instantiate_from_config +from ldm.modules.attention import LinearAttention + + +def get_timestep_embedding(timesteps, embedding_dim): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: + From Fairseq. + Build sinusoidal embeddings. + This matches the implementation in tensor2tensor, but differs slightly + from the description in Section 3.5 of "Attention Is All You Need". + """ + assert len(timesteps.shape) == 1 + + half_dim = embedding_dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) + emb = emb.to(device=timesteps.device) + emb = timesteps.float()[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0,1,0,0)) + return emb + + +def nonlinearity(x): + # swish + return x*torch.sigmoid(x) + + +def Normalize(in_channels, num_groups=32): + return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) + + +class Upsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + if self.with_conv: + x = self.conv(x) + return x + + +class Downsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=3, + stride=2, + padding=0) + + def forward(self, x): + if self.with_conv: + pad = (0,1,0,1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + else: + x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) + return x + + +class ResnetBlock(nn.Module): + def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, + dropout, temb_channels=512): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels) + self.conv1 = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + if temb_channels > 0: + self.temb_proj = torch.nn.Linear(temb_channels, + out_channels) + self.norm2 = Normalize(out_channels) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = torch.nn.Conv2d(out_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + else: + self.nin_shortcut = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x, temb): + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + + if temb is not None: + h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None] + + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + + return x+h + + +class LinAttnBlock(LinearAttention): + """to match AttnBlock usage""" + def __init__(self, in_channels): + super().__init__(dim=in_channels, heads=1, dim_head=in_channels) + + +class AttnBlock(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.k = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.v = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b,c,h,w = q.shape + q = q.reshape(b,c,h*w) + q = q.permute(0,2,1) # b,hw,c + k = k.reshape(b,c,h*w) # b,c,hw + w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w_ = w_ * (int(c)**(-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b,c,h*w) + w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q) + h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + h_ = h_.reshape(b,c,h,w) + + h_ = self.proj_out(h_) + + return x+h_ + + +def make_attn(in_channels, attn_type="vanilla"): + assert attn_type in ["vanilla", "linear", "none"], f'attn_type {attn_type} unknown' + print(f"making attention of type '{attn_type}' with {in_channels} in_channels") + if attn_type == "vanilla": + return AttnBlock(in_channels) + elif attn_type == "none": + return nn.Identity(in_channels) + else: + return LinAttnBlock(in_channels) + + +class Model(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution, use_timestep=True, use_linear_attn=False, attn_type="vanilla"): + super().__init__() + if use_linear_attn: attn_type = "linear" + self.ch = ch + self.temb_ch = self.ch*4 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + self.use_timestep = use_timestep + if self.use_timestep: + # timestep embedding + self.temb = nn.Module() + self.temb.dense = nn.ModuleList([ + torch.nn.Linear(self.ch, + self.temb_ch), + torch.nn.Linear(self.temb_ch, + self.temb_ch), + ]) + + # downsampling + self.conv_in = torch.nn.Conv2d(in_channels, + self.ch, + kernel_size=3, + stride=1, + padding=1) + + curr_res = resolution + in_ch_mult = (1,)+tuple(ch_mult) + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch*in_ch_mult[i_level] + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions-1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch*ch_mult[i_level] + skip_in = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks+1): + if i_block == self.num_res_blocks: + skip_in = ch*in_ch_mult[i_level] + block.append(ResnetBlock(in_channels=block_in+skip_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + out_ch, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x, t=None, context=None): + #assert x.shape[2] == x.shape[3] == self.resolution + if context is not None: + # assume aligned context, cat along channel axis + x = torch.cat((x, context), dim=1) + if self.use_timestep: + # timestep embedding + assert t is not None + temb = get_timestep_embedding(t, self.ch) + temb = self.temb.dense[0](temb) + temb = nonlinearity(temb) + temb = self.temb.dense[1](temb) + else: + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions-1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks+1): + h = self.up[i_level].block[i_block]( + torch.cat([h, hs.pop()], dim=1), temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + def get_last_layer(self): + return self.conv_out.weight + + +class Encoder(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla", + **ignore_kwargs): + super().__init__() + if use_linear_attn: attn_type = "linear" + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + # downsampling + self.conv_in = torch.nn.Conv2d(in_channels, + self.ch, + kernel_size=3, + stride=1, + padding=1) + + curr_res = resolution + in_ch_mult = (1,)+tuple(ch_mult) + self.in_ch_mult = in_ch_mult + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch*in_ch_mult[i_level] + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions-1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + 2*z_channels if double_z else z_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + # timestep embedding + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions-1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class Decoder(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False, + attn_type="vanilla", **ignorekwargs): + super().__init__() + if use_linear_attn: attn_type = "linear" + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.give_pre_end = give_pre_end + self.tanh_out = tanh_out + + # compute in_ch_mult, block_in and curr_res at lowest res + in_ch_mult = (1,)+tuple(ch_mult) + block_in = ch*ch_mult[self.num_resolutions-1] + curr_res = resolution // 2**(self.num_resolutions-1) + self.z_shape = (1,z_channels,curr_res,curr_res) + print("Working with z of shape {} = {} dimensions.".format( + self.z_shape, np.prod(self.z_shape))) + + # z to block_in + self.conv_in = torch.nn.Conv2d(z_channels, + block_in, + kernel_size=3, + stride=1, + padding=1) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks+1): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + out_ch, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, z): + #assert z.shape[1:] == self.z_shape[1:] + self.last_z_shape = z.shape + + # timestep embedding + temb = None + + # z to block_in + h = self.conv_in(z) + + # middle + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks+1): + h = self.up[i_level].block[i_block](h, temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + if self.give_pre_end: + return h + + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + if self.tanh_out: + h = torch.tanh(h) + return h + + +class SimpleDecoder(nn.Module): + def __init__(self, in_channels, out_channels, *args, **kwargs): + super().__init__() + self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1), + ResnetBlock(in_channels=in_channels, + out_channels=2 * in_channels, + temb_channels=0, dropout=0.0), + ResnetBlock(in_channels=2 * in_channels, + out_channels=4 * in_channels, + temb_channels=0, dropout=0.0), + ResnetBlock(in_channels=4 * in_channels, + out_channels=2 * in_channels, + temb_channels=0, dropout=0.0), + nn.Conv2d(2*in_channels, in_channels, 1), + Upsample(in_channels, with_conv=True)]) + # end + self.norm_out = Normalize(in_channels) + self.conv_out = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + for i, layer in enumerate(self.model): + if i in [1,2,3]: + x = layer(x, None) + else: + x = layer(x) + + h = self.norm_out(x) + h = nonlinearity(h) + x = self.conv_out(h) + return x + + +class UpsampleDecoder(nn.Module): + def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution, + ch_mult=(2,2), dropout=0.0): + super().__init__() + # upsampling + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + block_in = in_channels + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.res_blocks = nn.ModuleList() + self.upsample_blocks = nn.ModuleList() + for i_level in range(self.num_resolutions): + res_block = [] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + res_block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + self.res_blocks.append(nn.ModuleList(res_block)) + if i_level != self.num_resolutions - 1: + self.upsample_blocks.append(Upsample(block_in, True)) + curr_res = curr_res * 2 + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + out_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + # upsampling + h = x + for k, i_level in enumerate(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.res_blocks[i_level][i_block](h, None) + if i_level != self.num_resolutions - 1: + h = self.upsample_blocks[k](h) + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class LatentRescaler(nn.Module): + def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2): + super().__init__() + # residual block, interpolate, residual block + self.factor = factor + self.conv_in = nn.Conv2d(in_channels, + mid_channels, + kernel_size=3, + stride=1, + padding=1) + self.res_block1 = nn.ModuleList([ResnetBlock(in_channels=mid_channels, + out_channels=mid_channels, + temb_channels=0, + dropout=0.0) for _ in range(depth)]) + self.attn = AttnBlock(mid_channels) + self.res_block2 = nn.ModuleList([ResnetBlock(in_channels=mid_channels, + out_channels=mid_channels, + temb_channels=0, + dropout=0.0) for _ in range(depth)]) + + self.conv_out = nn.Conv2d(mid_channels, + out_channels, + kernel_size=1, + ) + + def forward(self, x): + x = self.conv_in(x) + for block in self.res_block1: + x = block(x, None) + x = torch.nn.functional.interpolate(x, size=(int(round(x.shape[2]*self.factor)), int(round(x.shape[3]*self.factor)))) + x = self.attn(x) + for block in self.res_block2: + x = block(x, None) + x = self.conv_out(x) + return x + + +class MergedRescaleEncoder(nn.Module): + def __init__(self, in_channels, ch, resolution, out_ch, num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, + ch_mult=(1,2,4,8), rescale_factor=1.0, rescale_module_depth=1): + super().__init__() + intermediate_chn = ch * ch_mult[-1] + self.encoder = Encoder(in_channels=in_channels, num_res_blocks=num_res_blocks, ch=ch, ch_mult=ch_mult, + z_channels=intermediate_chn, double_z=False, resolution=resolution, + attn_resolutions=attn_resolutions, dropout=dropout, resamp_with_conv=resamp_with_conv, + out_ch=None) + self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=intermediate_chn, + mid_channels=intermediate_chn, out_channels=out_ch, depth=rescale_module_depth) + + def forward(self, x): + x = self.encoder(x) + x = self.rescaler(x) + return x + + +class MergedRescaleDecoder(nn.Module): + def __init__(self, z_channels, out_ch, resolution, num_res_blocks, attn_resolutions, ch, ch_mult=(1,2,4,8), + dropout=0.0, resamp_with_conv=True, rescale_factor=1.0, rescale_module_depth=1): + super().__init__() + tmp_chn = z_channels*ch_mult[-1] + self.decoder = Decoder(out_ch=out_ch, z_channels=tmp_chn, attn_resolutions=attn_resolutions, dropout=dropout, + resamp_with_conv=resamp_with_conv, in_channels=None, num_res_blocks=num_res_blocks, + ch_mult=ch_mult, resolution=resolution, ch=ch) + self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=z_channels, mid_channels=tmp_chn, + out_channels=tmp_chn, depth=rescale_module_depth) + + def forward(self, x): + x = self.rescaler(x) + x = self.decoder(x) + return x + + +class Upsampler(nn.Module): + def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2): + super().__init__() + assert out_size >= in_size + num_blocks = int(np.log2(out_size//in_size))+1 + factor_up = 1.+ (out_size % in_size) + print(f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}") + self.rescaler = LatentRescaler(factor=factor_up, in_channels=in_channels, mid_channels=2*in_channels, + out_channels=in_channels) + self.decoder = Decoder(out_ch=out_channels, resolution=out_size, z_channels=in_channels, num_res_blocks=2, + attn_resolutions=[], in_channels=None, ch=in_channels, + ch_mult=[ch_mult for _ in range(num_blocks)]) + + def forward(self, x): + x = self.rescaler(x) + x = self.decoder(x) + return x + + +class Resize(nn.Module): + def __init__(self, in_channels=None, learned=False, mode="bilinear"): + super().__init__() + self.with_conv = learned + self.mode = mode + if self.with_conv: + print(f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode") + raise NotImplementedError() + assert in_channels is not None + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=4, + stride=2, + padding=1) + + def forward(self, x, scale_factor=1.0): + if scale_factor==1.0: + return x + else: + x = torch.nn.functional.interpolate(x, mode=self.mode, align_corners=False, scale_factor=scale_factor) + return x + +class FirstStagePostProcessor(nn.Module): + + def __init__(self, ch_mult:list, in_channels, + pretrained_model:nn.Module=None, + reshape=False, + n_channels=None, + dropout=0., + pretrained_config=None): + super().__init__() + if pretrained_config is None: + assert pretrained_model is not None, 'Either "pretrained_model" or "pretrained_config" must not be None' + self.pretrained_model = pretrained_model + else: + assert pretrained_config is not None, 'Either "pretrained_model" or "pretrained_config" must not be None' + self.instantiate_pretrained(pretrained_config) + + self.do_reshape = reshape + + if n_channels is None: + n_channels = self.pretrained_model.encoder.ch + + self.proj_norm = Normalize(in_channels,num_groups=in_channels//2) + self.proj = nn.Conv2d(in_channels,n_channels,kernel_size=3, + stride=1,padding=1) + + blocks = [] + downs = [] + ch_in = n_channels + for m in ch_mult: + blocks.append(ResnetBlock(in_channels=ch_in,out_channels=m*n_channels,dropout=dropout)) + ch_in = m * n_channels + downs.append(Downsample(ch_in, with_conv=False)) + + self.model = nn.ModuleList(blocks) + self.downsampler = nn.ModuleList(downs) + + + def instantiate_pretrained(self, config): + model = instantiate_from_config(config) + self.pretrained_model = model.eval() + # self.pretrained_model.train = False + for param in self.pretrained_model.parameters(): + param.requires_grad = False + + + @torch.no_grad() + def encode_with_pretrained(self,x): + c = self.pretrained_model.encode(x) + if isinstance(c, DiagonalGaussianDistribution): + c = c.mode() + return c + + def forward(self,x): + z_fs = self.encode_with_pretrained(x) + z = self.proj_norm(z_fs) + z = self.proj(z) + z = nonlinearity(z) + + for submodel, downmodel in zip(self.model,self.downsampler): + z = submodel(z,temb=None) + z = downmodel(z) + + if self.do_reshape: + z = rearrange(z,'b c h w -> b (h w) c') + return z + diff --git a/ldm/modules/diffusionmodules/openaimodel.py b/ldm/modules/diffusionmodules/openaimodel.py new file mode 100644 index 0000000000000000000000000000000000000000..f3e45f539649b6d6ee0b4b492fa8b11060db3e70 --- /dev/null +++ b/ldm/modules/diffusionmodules/openaimodel.py @@ -0,0 +1,1225 @@ +# This code is built from the Stable Diffusion repository: https://github.com/CompVis/stable-diffusion, and +# Paint-by-Example repo https://github.com/Fantasy-Studio/Paint-by-Example +# Copyright (c) 2022 Robin Rombach and Patrick Esser and contributors. +# CreativeML Open RAIL-M +# +# ========================================================================================== +# +# Adobe’s modifications are Copyright 2024 Adobe Research. All rights reserved. +# Adobe’s modifications are licensed under the Adobe Research License. To view a copy of the license, visit +# LICENSE.md. +# +# ========================================================================================== + +from abc import abstractmethod +from functools import partial +import math +from typing import Iterable +from collections import deque + +import numpy as np +import torch as th +import torch.nn as nn +import torch.nn.functional as F +import glob +import os + +import torchvision + +from ldm.modules.diffusionmodules.util import ( + checkpoint, + conv_nd, + linear, + avg_pool_nd, + zero_module, + normalization, + timestep_embedding, +) +from ldm.modules.attention import SpatialTransformer + + +# dummy replace +def convert_module_to_f16(x): + pass + +def convert_module_to_f32(x): + pass + + +## go +class AttentionPool2d(nn.Module): + """ + Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py + """ + + def __init__( + self, + spacial_dim: int, + embed_dim: int, + num_heads_channels: int, + output_dim: int = None, + ): + super().__init__() + self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5) + self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1) + self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1) + self.num_heads = embed_dim // num_heads_channels + self.attention = QKVAttention(self.num_heads) + + def forward(self, x): + b, c, *_spatial = x.shape + x = x.reshape(b, c, -1) # NC(HW) + x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1) + x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1) + x = self.qkv_proj(x) + x = self.attention(x) + x = self.c_proj(x) + return x[:, :, 0] + + +class TimestepBlock(nn.Module): + """ + Any module where forward() takes timestep embeddings as a second argument. + """ + + @abstractmethod + def forward(self, x, emb): + """ + Apply the module to `x` given `emb` timestep embeddings. + """ + + +class TimestepEmbedSequential(nn.Sequential, TimestepBlock): + """ + A sequential module that passes timestep embeddings to the children that + support it as an extra input. + """ + + def forward(self, x, emb, context=None, passed_kqv=None, kqv_idx=None, masks=None, corresp=None): + attention_vals = [] + # print('processing a layer') + # print('idx', kqv_idx) + for layer in self: + # print('processing a layer', layer.__class__.__name__) + if isinstance(layer, TimestepBlock): + x = layer(x, emb) + elif isinstance(layer, SpatialTransformer): + if passed_kqv is not None: + assert kqv_idx is not None + passed_item = passed_kqv[kqv_idx] + # print('pre passed item len', len(passed_item)) + if len(passed_item) == 1: + passed_item = passed_item[0][0] + # print('success passed item', len(passed_item)) + else: + passed_item = None + x, kqv = layer(x, context, passed_item, masks=masks, corresp=corresp) + attention_vals.append(kqv) + else: + x = layer(x) + # print('length of attn vals', len(attention_vals)) + return x, attention_vals + + +class Upsample(nn.Module): + """ + An upsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + upsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + if use_conv: + self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding) + + def forward(self, x): + assert x.shape[1] == self.channels + if self.dims == 3: + x = F.interpolate( + x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest" + ) + else: + x = F.interpolate(x, scale_factor=2, mode="nearest") + if self.use_conv: + x = self.conv(x) + return x + +class TransposedUpsample(nn.Module): + 'Learned 2x upsampling without padding' + def __init__(self, channels, out_channels=None, ks=5): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + + self.up = nn.ConvTranspose2d(self.channels,self.out_channels,kernel_size=ks,stride=2) + + def forward(self,x): + return self.up(x) + + +class Downsample(nn.Module): + """ + A downsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + downsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + stride = 2 if dims != 3 else (1, 2, 2) + if use_conv: + self.op = conv_nd( + dims, self.channels, self.out_channels, 3, stride=stride, padding=padding + ) + else: + assert self.channels == self.out_channels + self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) + + def forward(self, x): + assert x.shape[1] == self.channels + return self.op(x) + + +class ResBlock(TimestepBlock): + """ + A residual block that can optionally change the number of channels. + :param channels: the number of input channels. + :param emb_channels: the number of timestep embedding channels. + :param dropout: the rate of dropout. + :param out_channels: if specified, the number of out channels. + :param use_conv: if True and out_channels is specified, use a spatial + convolution instead of a smaller 1x1 convolution to change the + channels in the skip connection. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param use_checkpoint: if True, use gradient checkpointing on this module. + :param up: if True, use this block for upsampling. + :param down: if True, use this block for downsampling. + """ + + def __init__( + self, + channels, + emb_channels, + dropout, + out_channels=None, + use_conv=False, + use_scale_shift_norm=False, + dims=2, + use_checkpoint=False, + up=False, + down=False, + ): + super().__init__() + self.channels = channels + self.emb_channels = emb_channels + self.dropout = dropout + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_checkpoint = use_checkpoint + self.use_scale_shift_norm = use_scale_shift_norm + + self.in_layers = nn.Sequential( + normalization(channels), + nn.SiLU(), + conv_nd(dims, channels, self.out_channels, 3, padding=1), + ) + + self.updown = up or down + + if up: + self.h_upd = Upsample(channels, False, dims) + self.x_upd = Upsample(channels, False, dims) + elif down: + self.h_upd = Downsample(channels, False, dims) + self.x_upd = Downsample(channels, False, dims) + else: + self.h_upd = self.x_upd = nn.Identity() + + self.emb_layers = nn.Sequential( + nn.SiLU(), + linear( + emb_channels, + 2 * self.out_channels if use_scale_shift_norm else self.out_channels, + ), + ) + self.out_layers = nn.Sequential( + normalization(self.out_channels), + nn.SiLU(), + nn.Dropout(p=dropout), + zero_module( + conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1) + ), + ) + + if self.out_channels == channels: + self.skip_connection = nn.Identity() + elif use_conv: + self.skip_connection = conv_nd( + dims, channels, self.out_channels, 3, padding=1 + ) + else: + self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) + + def forward(self, x, emb): + """ + Apply the block to a Tensor, conditioned on a timestep embedding. + :param x: an [N x C x ...] Tensor of features. + :param emb: an [N x emb_channels] Tensor of timestep embeddings. + :return: an [N x C x ...] Tensor of outputs. + """ + return checkpoint( + self._forward, (x, emb), self.parameters(), self.use_checkpoint + ) + + + def _forward(self, x, emb): + if self.updown: + in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] + h = in_rest(x) + h = self.h_upd(h) + x = self.x_upd(x) + h = in_conv(h) + else: + h = self.in_layers(x) + emb_out = self.emb_layers(emb).type(h.dtype) + while len(emb_out.shape) < len(h.shape): + emb_out = emb_out[..., None] + if self.use_scale_shift_norm: + out_norm, out_rest = self.out_layers[0], self.out_layers[1:] + scale, shift = th.chunk(emb_out, 2, dim=1) + h = out_norm(h) * (1 + scale) + shift + h = out_rest(h) + else: + h = h + emb_out + h = self.out_layers(h) + return self.skip_connection(x) + h + + +class My_ResBlock(TimestepBlock): + """ + A residual block that can optionally change the number of channels. + :param channels: the number of input channels. + :param emb_channels: the number of timestep embedding channels. + :param dropout: the rate of dropout. + :param out_channels: if specified, the number of out channels. + :param use_conv: if True and out_channels is specified, use a spatial + convolution instead of a smaller 1x1 convolution to change the + channels in the skip connection. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param use_checkpoint: if True, use gradient checkpointing on this module. + :param up: if True, use this block for upsampling. + :param down: if True, use this block for downsampling. + """ + + def __init__( + self, + channels, + emb_channels, + dropout, + out_channels=None, + use_conv=False, + use_scale_shift_norm=False, + dims=2, + use_checkpoint=False, + up=False, + down=False, + ): + super().__init__() + self.channels = channels + self.emb_channels = emb_channels + self.dropout = dropout + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_checkpoint = use_checkpoint + self.use_scale_shift_norm = use_scale_shift_norm + + self.in_layers = nn.Sequential( + normalization(channels), + nn.SiLU(), + conv_nd(dims, channels, self.out_channels, 3, padding=1), + ) + + self.updown = up or down + + if up: + self.h_upd = Upsample(channels, False, dims) + self.x_upd = Upsample(channels, False, dims) + elif down: + self.h_upd = Downsample(channels, False, dims) + self.x_upd = Downsample(channels, False, dims) + else: + self.h_upd = self.x_upd = nn.Identity() + + self.emb_layers = nn.Sequential( + nn.SiLU(), + linear( + emb_channels, + 2 * self.out_channels if use_scale_shift_norm else self.out_channels, + ), + ) + self.out_layers = nn.Sequential( + normalization(self.out_channels), + nn.SiLU(), + nn.Dropout(p=dropout), + zero_module( + conv_nd(dims, self.out_channels, 4, 3, padding=1) + ), + ) + + if self.out_channels == channels: + self.skip_connection = nn.Identity() + elif use_conv: + self.skip_connection = conv_nd( + dims, channels, self.out_channels, 3, padding=1 + ) + else: + self.skip_connection = conv_nd(dims, channels, 4, 1) + + def forward(self, x, emb): + """ + Apply the block to a Tensor, conditioned on a timestep embedding. + :param x: an [N x C x ...] Tensor of features. + :param emb: an [N x emb_channels] Tensor of timestep embeddings. + :return: an [N x C x ...] Tensor of outputs. + """ + return checkpoint( + self._forward, (x, emb), self.parameters(), self.use_checkpoint + ) + + + def _forward(self, x, emb): + if self.updown: + in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] + h = in_rest(x) + h = self.h_upd(h) + x = self.x_upd(x) + h = in_conv(h) + else: + h = self.in_layers(x) + emb_out = self.emb_layers(emb).type(h.dtype) + while len(emb_out.shape) < len(h.shape): + emb_out = emb_out[..., None] + if self.use_scale_shift_norm: + out_norm, out_rest = self.out_layers[0], self.out_layers[1:] + scale, shift = th.chunk(emb_out, 2, dim=1) + h = out_norm(h) * (1 + scale) + shift + h = out_rest(h) + else: + h = h + emb_out + h = self.out_layers(h) + return h + + +class AttentionBlock(nn.Module): + """ + An attention block that allows spatial positions to attend to each other. + Originally ported from here, but adapted to the N-d case. + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. + """ + + def __init__( + self, + channels, + num_heads=1, + num_head_channels=-1, + use_checkpoint=False, + use_new_attention_order=False, + ): + super().__init__() + self.channels = channels + if num_head_channels == -1: + self.num_heads = num_heads + else: + assert ( + channels % num_head_channels == 0 + ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" + self.num_heads = channels // num_head_channels + self.use_checkpoint = use_checkpoint + self.norm = normalization(channels) + self.qkv = conv_nd(1, channels, channels * 3, 1) + if use_new_attention_order: + # split qkv before split heads + self.attention = QKVAttention(self.num_heads) + else: + # split heads before split qkv + self.attention = QKVAttentionLegacy(self.num_heads) + + self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) + + def forward(self, x): + return checkpoint(self._forward, (x,), self.parameters(), True) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!! + #return pt_checkpoint(self._forward, x) # pytorch + + def _forward(self, x): + b, c, *spatial = x.shape + x = x.reshape(b, c, -1) + qkv = self.qkv(self.norm(x)) + h = self.attention(qkv) + h = self.proj_out(h) + return (x + h).reshape(b, c, *spatial) + + +def count_flops_attn(model, _x, y): + """ + A counter for the `thop` package to count the operations in an + attention operation. + Meant to be used like: + macs, params = thop.profile( + model, + inputs=(inputs, timestamps), + custom_ops={QKVAttention: QKVAttention.count_flops}, + ) + """ + b, c, *spatial = y[0].shape + num_spatial = int(np.prod(spatial)) + # We perform two matmuls with the same number of ops. + # The first computes the weight matrix, the second computes + # the combination of the value vectors. + matmul_ops = 2 * b * (num_spatial ** 2) * c + model.total_ops += th.DoubleTensor([matmul_ops]) + + +class QKVAttentionLegacy(nn.Module): + """ + A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv): + """ + Apply QKV attention. + :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. + """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = th.einsum( + "bct,bcs->bts", q * scale, k * scale + ) # More stable with f16 than dividing afterwards + weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) + a = th.einsum("bts,bcs->bct", weight, v) + return a.reshape(bs, -1, length) + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + + +class QKVAttention(nn.Module): + """ + A module which performs QKV attention and splits in a different order. + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv): + """ + Apply QKV attention. + :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. + """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.chunk(3, dim=1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = th.einsum( + "bct,bcs->bts", + (q * scale).view(bs * self.n_heads, ch, length), + (k * scale).view(bs * self.n_heads, ch, length), + ) # More stable with f16 than dividing afterwards + weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) + a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length)) + return a.reshape(bs, -1, length) + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + + +class UNetModel(nn.Module): + """ + The full UNet model with attention and timestep embedding. + :param in_channels: channels in the input Tensor. + :param model_channels: base channel count for the model. + :param out_channels: channels in the output Tensor. + :param num_res_blocks: number of residual blocks per downsample. + :param attention_resolutions: a collection of downsample rates at which + attention will take place. May be a set, list, or tuple. + For example, if this contains 4, then at 4x downsampling, attention + will be used. + :param dropout: the dropout probability. + :param channel_mult: channel multiplier for each level of the UNet. + :param conv_resample: if True, use learned convolutions for upsampling and + downsampling. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param num_classes: if specified (as an int), then this model will be + class-conditional with `num_classes` classes. + :param use_checkpoint: use gradient checkpointing to reduce memory usage. + :param num_heads: the number of attention heads in each attention layer. + :param num_heads_channels: if specified, ignore num_heads and instead use + a fixed channel width per attention head. + :param num_heads_upsample: works with num_heads to set a different number + of heads for upsampling. Deprecated. + :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. + :param resblock_updown: use residual blocks for up/downsampling. + :param use_new_attention_order: use a different attention pattern for potentially + increased efficiency. + """ + + def __init__( + self, + image_size, + in_channels, + model_channels, + out_channels, + num_res_blocks, + attention_resolutions, + dropout=0, + channel_mult=(1, 2, 4, 8), + conv_resample=True, + dims=2, + num_classes=None, + use_checkpoint=False, + use_fp16=False, + num_heads=-1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + resblock_updown=False, + use_new_attention_order=False, + use_spatial_transformer=False, # custom transformer support + transformer_depth=1, # custom transformer support + context_dim=None, # custom transformer support + n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model + legacy=True, + add_conv_in_front_of_unet=False, + ): + super().__init__() + if use_spatial_transformer: + assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...' + + if context_dim is not None: + assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...' + from omegaconf.listconfig import ListConfig + if type(context_dim) == ListConfig: + context_dim = list(context_dim) + + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + if num_heads == -1: + assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set' + + if num_head_channels == -1: + assert num_heads != -1, 'Either num_heads or num_head_channels has to be set' + + self.image_size = image_size + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.num_classes = num_classes + self.use_checkpoint = use_checkpoint + self.dtype = th.float16 if use_fp16 else th.float32 + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + self.predict_codebook_ids = n_embed is not None + self.add_conv_in_front_of_unet=add_conv_in_front_of_unet + + + # save contexts + self.save_contexts = False + self.use_contexts = False + self.contexts = deque([]) + + time_embed_dim = model_channels * 4 + self.time_embed = nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + + if self.num_classes is not None: + self.label_emb = nn.Embedding(num_classes, time_embed_dim) + + + if self.add_conv_in_front_of_unet: + self.add_resbolck = nn.ModuleList( + [ + TimestepEmbedSequential( + conv_nd(dims, 9, model_channels, 3, padding=1) + ) + ] + ) + + add_layers = [ + My_ResBlock( + model_channels, + time_embed_dim, + dropout, + out_channels=model_channels, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + + self.add_resbolck.append(TimestepEmbedSequential(*add_layers)) + + + self.input_blocks = nn.ModuleList( + [ + TimestepEmbedSequential( + conv_nd(dims, in_channels, model_channels, 3, padding=1) + ) + ] + ) + self._feature_size = model_channels + input_block_chans = [model_channels] + ch = model_channels + ds = 1 + for level, mult in enumerate(channel_mult): + for _ in range(num_res_blocks): + layers = [ + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=mult * model_channels, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = mult * model_channels + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + #num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) if not use_spatial_transformer else SpatialTransformer( + ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim + ) + ) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + ) + if resblock_updown + else Downsample( + ch, conv_resample, dims=dims, out_channels=out_ch + ) + ) + ) + ch = out_ch + input_block_chans.append(ch) + ds *= 2 + self._feature_size += ch + + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + #num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + self.middle_block = TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) if not use_spatial_transformer else SpatialTransformer( + ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim + ), + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + ) + self._feature_size += ch + + self.output_blocks = nn.ModuleList([]) + for level, mult in list(enumerate(channel_mult))[::-1]: + for i in range(num_res_blocks + 1): + ich = input_block_chans.pop() + layers = [ + ResBlock( + ch + ich, + time_embed_dim, + dropout, + out_channels=model_channels * mult, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = model_channels * mult + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + #num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads_upsample, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) if not use_spatial_transformer else SpatialTransformer( + ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim + ) + ) + if level and i == num_res_blocks: + out_ch = ch + layers.append( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + up=True, + ) + if resblock_updown + else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch) + ) + ds //= 2 + self.output_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)), + ) + if self.predict_codebook_ids: + self.id_predictor = nn.Sequential( + normalization(ch), + conv_nd(dims, model_channels, n_embed, 1), + #nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits + ) + + def convert_to_fp16(self): + """ + Convert the torso of the model to float16. + """ + self.input_blocks.apply(convert_module_to_f16) + self.middle_block.apply(convert_module_to_f16) + self.output_blocks.apply(convert_module_to_f16) + + def convert_to_fp32(self): + """ + Convert the torso of the model to float32. + """ + self.input_blocks.apply(convert_module_to_f32) + self.middle_block.apply(convert_module_to_f32) + self.output_blocks.apply(convert_module_to_f32) + + def forward(self, x, timesteps=None, context=None, y=None, get_contexts=False, passed_contexts=None, corresp=None,**kwargs): + """ + Apply the model to an input batch. + :param x: an [N x C x ...] Tensor of inputs. + :param timesteps: a 1-D batch of timesteps. + :param context: conditioning plugged in via crossattn + :param y: an [N] Tensor of labels, if class-conditional. + :return: an [N x C x ...] Tensor of outputs. + """ + assert (y is not None) == ( + self.num_classes is not None + ), "must specify y if and only if the model is class-conditional" + hs = [] + t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) + emb = self.time_embed(t_emb) + + ds = [8, 16, 32, 64] + + # cur_step = len(glob.glob('/dev/shm/dumpster/steps/*')) + # os.makedirs(f'/dev/shm/dumpster/steps/{cur_step:04d}', exist_ok=False) + + og_mask = x[:, -1:] # Bx1x64x64 + batch_size = og_mask.shape[0] + masks = dict() + + for d in ds: + resized_mask = torchvision.transforms.functional.resize(og_mask, size=(d, d)) + + mask = resized_mask.reshape(batch_size, -1) + masks[d] = mask + + # if self.use_contexts: + # passed_contexts = self.contexts.popleft() + + all_kqvs = [] + + if self.num_classes is not None: + assert y.shape == (x.shape[0],) + emb = emb + self.label_emb(y) + + h = x.type(self.dtype) + + if self.add_conv_in_front_of_unet: + for module in self.add_resbolck: + h, kqv = module(h, emb, context, passed_contexts, len(all_kqvs), masks=masks, corresp=corresp) + all_kqvs.append(kqv) + + for module in self.input_blocks: + h, kqv = module(h, emb, context, passed_contexts, len(all_kqvs), masks=masks, corresp=corresp) + hs.append(h) + all_kqvs.append(kqv) + + h, kqv = self.middle_block(h, emb, context, passed_contexts, len(all_kqvs), masks=masks, corresp=corresp) + all_kqvs.append(kqv) + for module in self.output_blocks: + h = th.cat([h, hs.pop()], dim=1) + h, kqv = module(h, emb, context, passed_contexts, len(all_kqvs), masks=masks, corresp=corresp) + all_kqvs.append(kqv) + + h = h.type(x.dtype) + + # print(all_kqvs) + # for i in range(len(all_kqvs)): + # print('len of contexts at ', i, 'is ', len(all_kqvs[i])) + # for j in range(len(all_kqvs[i])): + # print('len of contexts at ', i, j, 'is ', len(all_kqvs[i][j])) + # for k in range(len(all_kqvs[i][j])): + # print(all_kqvs[i][j][k]) + + + + if self.predict_codebook_ids: + out = self.id_predictor(h) + else: + out = self.out(h) + + if self.save_contexts: + self.contexts.append(all_kqvs) + + if get_contexts: + return out, all_kqvs + else: + return out + + def get_contexts(self, x, timesteps=None, context=None, y=None,**kwargs): + """ + same as forward but saves self attention contexts + """ + assert (y is not None) == ( + self.num_classes is not None + ), "must specify y if and only if the model is class-conditional" + hs = [] + t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) + emb = self.time_embed(t_emb) + + if self.num_classes is not None: + assert y.shape == (x.shape[0],) + emb = emb + self.label_emb(y) + + h = x.type(self.dtype) + + if self.add_conv_in_front_of_unet: + for module in self.add_resbolck: + h = module(h, emb, context) + + for module in self.input_blocks: + h = module(h, emb, context) + hs.append(h) + h = self.middle_block(h, emb, context) + for module in self.output_blocks: + h = th.cat([h, hs.pop()], dim=1) + h = module(h, emb, context) + h = h.type(x.dtype) + if self.predict_codebook_ids: + return self.id_predictor(h) + else: + return self.out(h) + +class EncoderUNetModel(nn.Module): + """ + The half UNet model with attention and timestep embedding. + For usage, see UNet. + """ + + def __init__( + self, + image_size, + in_channels, + model_channels, + out_channels, + num_res_blocks, + attention_resolutions, + dropout=0, + channel_mult=(1, 2, 4, 8), + conv_resample=True, + dims=2, + use_checkpoint=False, + use_fp16=False, + num_heads=1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + resblock_updown=False, + use_new_attention_order=False, + pool="adaptive", + *args, + **kwargs + ): + super().__init__() + + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.use_checkpoint = use_checkpoint + self.dtype = th.float16 if use_fp16 else th.float32 + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + + time_embed_dim = model_channels * 4 + self.time_embed = nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + + self.input_blocks = nn.ModuleList( + [ + TimestepEmbedSequential( + conv_nd(dims, in_channels, model_channels, 3, padding=1) + ) + ] + ) + self._feature_size = model_channels + input_block_chans = [model_channels] + ch = model_channels + ds = 1 + for level, mult in enumerate(channel_mult): + for _ in range(num_res_blocks): + layers = [ + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=mult * model_channels, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = mult * model_channels + if ds in attention_resolutions: + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=num_head_channels, + use_new_attention_order=use_new_attention_order, + ) + ) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + ) + if resblock_updown + else Downsample( + ch, conv_resample, dims=dims, out_channels=out_ch + ) + ) + ) + ch = out_ch + input_block_chans.append(ch) + ds *= 2 + self._feature_size += ch + + self.middle_block = TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=num_head_channels, + use_new_attention_order=use_new_attention_order, + ), + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + ) + self._feature_size += ch + self.pool = pool + if pool == "adaptive": + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + nn.AdaptiveAvgPool2d((1, 1)), + zero_module(conv_nd(dims, ch, out_channels, 1)), + nn.Flatten(), + ) + elif pool == "attention": + assert num_head_channels != -1 + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + AttentionPool2d( + (image_size // ds), ch, num_head_channels, out_channels + ), + ) + elif pool == "spatial": + self.out = nn.Sequential( + nn.Linear(self._feature_size, 2048), + nn.ReLU(), + nn.Linear(2048, self.out_channels), + ) + elif pool == "spatial_v2": + self.out = nn.Sequential( + nn.Linear(self._feature_size, 2048), + normalization(2048), + nn.SiLU(), + nn.Linear(2048, self.out_channels), + ) + else: + raise NotImplementedError(f"Unexpected {pool} pooling") + + def convert_to_fp16(self): + """ + Convert the torso of the model to float16. + """ + self.input_blocks.apply(convert_module_to_f16) + self.middle_block.apply(convert_module_to_f16) + + def convert_to_fp32(self): + """ + Convert the torso of the model to float32. + """ + self.input_blocks.apply(convert_module_to_f32) + self.middle_block.apply(convert_module_to_f32) + + def forward(self, x, timesteps): + """ + Apply the model to an input batch. + :param x: an [N x C x ...] Tensor of inputs. + :param timesteps: a 1-D batch of timesteps. + :return: an [N x K] Tensor of outputs. + """ + emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) + + results = [] + h = x.type(self.dtype) + for module in self.input_blocks: + h = module(h, emb) + if self.pool.startswith("spatial"): + results.append(h.type(x.dtype).mean(dim=(2, 3))) + h = self.middle_block(h, emb) + if self.pool.startswith("spatial"): + results.append(h.type(x.dtype).mean(dim=(2, 3))) + h = th.cat(results, axis=-1) + return self.out(h) + else: + h = h.type(x.dtype) + return self.out(h) + diff --git a/ldm/modules/diffusionmodules/util.py b/ldm/modules/diffusionmodules/util.py new file mode 100644 index 0000000000000000000000000000000000000000..febbddd8416962ca602c18b1f7e3d28ed11572c4 --- /dev/null +++ b/ldm/modules/diffusionmodules/util.py @@ -0,0 +1,285 @@ +# This code is built from the Stable Diffusion repository: https://github.com/CompVis/stable-diffusion, and +# Paint-by-Example repo https://github.com/Fantasy-Studio/Paint-by-Example +# Copyright (c) 2022 Robin Rombach and Patrick Esser and contributors. +# CreativeML Open RAIL-M +# +# ========================================================================================== +# +# Adobe’s modifications are Copyright 2024 Adobe Research. All rights reserved. +# Adobe’s modifications are licensed under the Adobe Research License. To view a copy of the license, visit +# LICENSE.md. +# +# ========================================================================================== + +# adopted from +# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py +# and +# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py +# and +# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py +# +# thanks! + + +import os +import math +import torch +import torch.nn as nn +import numpy as np +from einops import repeat + +from ldm.util import instantiate_from_config + + +def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + if schedule == "linear": + betas = ( + torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2 + ) + + elif schedule == "cosine": + timesteps = ( + torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s + ) + alphas = timesteps / (1 + cosine_s) * np.pi / 2 + alphas = torch.cos(alphas).pow(2) + alphas = alphas / alphas[0] + betas = 1 - alphas[1:] / alphas[:-1] + betas = np.clip(betas, a_min=0, a_max=0.999) + + elif schedule == "sqrt_linear": + betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) + elif schedule == "sqrt": + betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5 + else: + raise ValueError(f"schedule '{schedule}' unknown.") + return betas.numpy() + + +def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True, steps=None): + if ddim_discr_method == 'uniform': + c = num_ddpm_timesteps // num_ddim_timesteps + ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) + elif ddim_discr_method == 'quad': + ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int) + elif ddim_discr_method == 'manual': + assert steps is not None + ddim_timesteps = np.asarray(steps) + else: + raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"') + + # assert ddim_timesteps.shape[0] == num_ddim_timesteps + # add one to get the final alpha values right (the ones from first scale to data during sampling) + steps_out = ddim_timesteps + 1 + if verbose: + print(f'Selected timesteps for ddim sampler: {steps_out}') + return steps_out + + +def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): + # select alphas for computing the variance schedule + alphas = alphacums[ddim_timesteps] + alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) + + # according the the formula provided in https://arxiv.org/abs/2010.02502 + sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)) + if verbose: + print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}') + print(f'For the chosen value of eta, which is {eta}, ' + f'this results in the following sigma_t schedule for ddim sampler {sigmas}') + return sigmas, alphas, alphas_prev + + +def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, + which defines the cumulative product of (1-beta) over time from t = [0,1]. + :param num_diffusion_timesteps: the number of betas to produce. + :param alpha_bar: a lambda that takes an argument t from 0 to 1 and + produces the cumulative product of (1-beta) up to that + part of the diffusion process. + :param max_beta: the maximum beta to use; use values lower than 1 to + prevent singularities. + """ + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return np.array(betas) + + +def extract_into_tensor(a, t, x_shape): + b, *_ = t.shape + out = a.gather(-1, t) + return out.reshape(b, *((1,) * (len(x_shape) - 1))) + + +def checkpoint(func, inputs, params, flag): + """ + Evaluate a function without caching intermediate activations, allowing for + reduced memory at the expense of extra compute in the backward pass. + :param func: the function to evaluate. + :param inputs: the argument sequence to pass to `func`. + :param params: a sequence of parameters `func` depends on but does not + explicitly take as arguments. + :param flag: if False, disable gradient checkpointing. + """ + if flag: + args = tuple(inputs) + tuple(params) + return CheckpointFunction.apply(func, len(inputs), *args) + else: + return func(*inputs) + + +class CheckpointFunction(torch.autograd.Function): + @staticmethod + # @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32) # added this for map + def forward(ctx, run_function, length, *args): + ctx.run_function = run_function + ctx.input_tensors = list(args[:length]) + ctx.input_params = list(args[length:]) + + with torch.no_grad(): + output_tensors = ctx.run_function(*ctx.input_tensors) + return output_tensors + + @staticmethod + # @torch.cuda.amp.custom_bwd # added this for map + def backward(ctx, *output_grads): + ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] + with torch.enable_grad(): + # Fixes a bug where the first op in run_function modifies the + # Tensor storage in place, which is not allowed for detach()'d + # Tensors. + shallow_copies = [x.view_as(x) for x in ctx.input_tensors] + output_tensors = ctx.run_function(*shallow_copies) + input_grads = torch.autograd.grad( + output_tensors, + ctx.input_tensors + ctx.input_params, + output_grads, + allow_unused=True, + ) + del ctx.input_tensors + del ctx.input_params + del output_tensors + return (None, None) + input_grads + + +def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): + """ + Create sinusoidal timestep embeddings. + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. + """ + if not repeat_only: + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half + ).to(device=timesteps.device) + args = timesteps[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + else: + embedding = repeat(timesteps, 'b -> b d', d=dim) + return embedding + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def scale_module(module, scale): + """ + Scale the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().mul_(scale) + return module + + +def mean_flat(tensor): + """ + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +def normalization(channels): + """ + Make a standard normalization layer. + :param channels: number of input channels. + :return: an nn.Module for normalization. + """ + return GroupNorm32(32, channels) + + +# PyTorch 1.7 has SiLU, but we support PyTorch 1.5. +class SiLU(nn.Module): + def forward(self, x): + return x * torch.sigmoid(x) + + +class GroupNorm32(nn.GroupNorm): + def forward(self, x): + return super().forward(x.float()).type(x.dtype) + +def conv_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D convolution module. + """ + if dims == 1: + return nn.Conv1d(*args, **kwargs) + elif dims == 2: + return nn.Conv2d(*args, **kwargs) + elif dims == 3: + return nn.Conv3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def linear(*args, **kwargs): + """ + Create a linear module. + """ + return nn.Linear(*args, **kwargs) + + +def avg_pool_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D average pooling module. + """ + if dims == 1: + return nn.AvgPool1d(*args, **kwargs) + elif dims == 2: + return nn.AvgPool2d(*args, **kwargs) + elif dims == 3: + return nn.AvgPool3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +class HybridConditioner(nn.Module): + + def __init__(self, c_concat_config, c_crossattn_config): + super().__init__() + self.concat_conditioner = instantiate_from_config(c_concat_config) + self.crossattn_conditioner = instantiate_from_config(c_crossattn_config) + + def forward(self, c_concat, c_crossattn): + c_concat = self.concat_conditioner(c_concat) + c_crossattn = self.crossattn_conditioner(c_crossattn) + return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]} + + +def noise_like(shape, device, repeat=False): + repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) + noise = lambda: torch.randn(shape, device=device) + return repeat_noise() if repeat else noise() \ No newline at end of file diff --git a/ldm/modules/distributions/__init__.py b/ldm/modules/distributions/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ldm/modules/distributions/distributions.py b/ldm/modules/distributions/distributions.py new file mode 100644 index 0000000000000000000000000000000000000000..2365b6e583e8248d6b4869f2ff2ca1d89a5835b6 --- /dev/null +++ b/ldm/modules/distributions/distributions.py @@ -0,0 +1,105 @@ +# This code is built from the Stable Diffusion repository: https://github.com/CompVis/stable-diffusion, and +# Paint-by-Example repo https://github.com/Fantasy-Studio/Paint-by-Example +# Copyright (c) 2022 Robin Rombach and Patrick Esser and contributors. +# CreativeML Open RAIL-M +# +# ========================================================================================== +# +# Adobe’s modifications are Copyright 2024 Adobe Research. All rights reserved. +# Adobe’s modifications are licensed under the Adobe Research License. To view a copy of the license, visit +# LICENSE.md. +# +# ========================================================================================== + +import torch +import numpy as np + + +class AbstractDistribution: + def sample(self): + raise NotImplementedError() + + def mode(self): + raise NotImplementedError() + + +class DiracDistribution(AbstractDistribution): + def __init__(self, value): + self.value = value + + def sample(self): + return self.value + + def mode(self): + return self.value + + +class DiagonalGaussianDistribution(object): + def __init__(self, parameters, deterministic=False): + self.parameters = parameters + self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) + self.logvar = torch.clamp(self.logvar, -30.0, 20.0) + self.deterministic = deterministic + self.std = torch.exp(0.5 * self.logvar) + self.var = torch.exp(self.logvar) + if self.deterministic: + self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) + + def sample(self): + x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) + return x + + def kl(self, other=None): + if self.deterministic: + return torch.Tensor([0.]) + else: + if other is None: + return 0.5 * torch.sum(torch.pow(self.mean, 2) + + self.var - 1.0 - self.logvar, + dim=[1, 2, 3]) + else: + return 0.5 * torch.sum( + torch.pow(self.mean - other.mean, 2) / other.var + + self.var / other.var - 1.0 - self.logvar + other.logvar, + dim=[1, 2, 3]) + + def nll(self, sample, dims=[1,2,3]): + if self.deterministic: + return torch.Tensor([0.]) + logtwopi = np.log(2.0 * np.pi) + return 0.5 * torch.sum( + logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, + dim=dims) + + def mode(self): + return self.mean + + +def normal_kl(mean1, logvar1, mean2, logvar2): + """ + source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 + Compute the KL divergence between two gaussians. + Shapes are automatically broadcasted, so batches can be compared to + scalars, among other use cases. + """ + tensor = None + for obj in (mean1, logvar1, mean2, logvar2): + if isinstance(obj, torch.Tensor): + tensor = obj + break + assert tensor is not None, "at least one argument must be a Tensor" + + # Force variances to be Tensors. Broadcasting helps convert scalars to + # Tensors, but it does not work for torch.exp(). + logvar1, logvar2 = [ + x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) + for x in (logvar1, logvar2) + ] + + return 0.5 * ( + -1.0 + + logvar2 + - logvar1 + + torch.exp(logvar1 - logvar2) + + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) + ) diff --git a/ldm/modules/ema.py b/ldm/modules/ema.py new file mode 100644 index 0000000000000000000000000000000000000000..7e6bd5227e57e6f798a0c523f0b5bdaef00b3dd8 --- /dev/null +++ b/ldm/modules/ema.py @@ -0,0 +1,89 @@ +# This code is built from the Stable Diffusion repository: https://github.com/CompVis/stable-diffusion, and +# Paint-by-Example repo https://github.com/Fantasy-Studio/Paint-by-Example +# Copyright (c) 2022 Robin Rombach and Patrick Esser and contributors. +# CreativeML Open RAIL-M +# +# ========================================================================================== +# +# Adobe’s modifications are Copyright 2024 Adobe Research. All rights reserved. +# Adobe’s modifications are licensed under the Adobe Research License. To view a copy of the license, visit +# LICENSE.md. +# +# ========================================================================================== + +import torch +from torch import nn + + +class LitEma(nn.Module): + def __init__(self, model, decay=0.9999, use_num_upates=True): + super().__init__() + if decay < 0.0 or decay > 1.0: + raise ValueError('Decay must be between 0 and 1') + + self.m_name2s_name = {} + self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) + self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates + else torch.tensor(-1,dtype=torch.int)) + + for name, p in model.named_parameters(): + if p.requires_grad: + #remove as '.'-character is not allowed in buffers + s_name = name.replace('.','') + self.m_name2s_name.update({name:s_name}) + self.register_buffer(s_name,p.clone().detach().data) + + self.collected_params = [] + + def forward(self,model): + decay = self.decay + + if self.num_updates >= 0: + self.num_updates += 1 + decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates)) + + one_minus_decay = 1.0 - decay + + with torch.no_grad(): + m_param = dict(model.named_parameters()) + shadow_params = dict(self.named_buffers()) + + for key in m_param: + if m_param[key].requires_grad: + sname = self.m_name2s_name[key] + shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) + shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) + else: + assert not key in self.m_name2s_name + + def copy_to(self, model): + m_param = dict(model.named_parameters()) + shadow_params = dict(self.named_buffers()) + for key in m_param: + if m_param[key].requires_grad: + m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) + else: + assert not key in self.m_name2s_name + + def store(self, parameters): + """ + Save the current parameters for restoring later. + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + temporarily stored. + """ + self.collected_params = [param.clone() for param in parameters] + + def restore(self, parameters): + """ + Restore the parameters stored with the `store` method. + Useful to validate the model with EMA parameters without affecting the + original optimization process. Store the parameters before the + `copy_to` method. After validation (or model saving), use this to + restore the former parameters. + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + updated with the stored parameters. + """ + for c_param, param in zip(self.collected_params, parameters): + param.data.copy_(c_param.data) diff --git a/ldm/modules/encoders/__init__.py b/ldm/modules/encoders/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ldm/modules/encoders/modules.py b/ldm/modules/encoders/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..9118c9684533ed623047eccdf3ee3911a672a411 --- /dev/null +++ b/ldm/modules/encoders/modules.py @@ -0,0 +1,309 @@ +# This code is built from the Stable Diffusion repository: https://github.com/CompVis/stable-diffusion, and +# Paint-by-Example repo https://github.com/Fantasy-Studio/Paint-by-Example +# Copyright (c) 2022 Robin Rombach and Patrick Esser and contributors. +# CreativeML Open RAIL-M +# +# ========================================================================================== +# +# Adobe’s modifications are Copyright 2024 Adobe Research. All rights reserved. +# Adobe’s modifications are licensed under the Adobe Research License. To view a copy of the license, visit +# LICENSE.md. +# +# ========================================================================================== + +import torch +import torch.nn as nn +from functools import partial +import clip +from einops import rearrange, repeat +from transformers import CLIPTokenizer, CLIPTextModel,CLIPVisionModel,CLIPModel +import kornia +from ldm.modules.x_transformer import Encoder, TransformerWrapper # TODO: can we directly rely on lucidrains code and simply add this as a reuirement? --> test +from .xf import LayerNorm, Transformer +import math + +class AbstractEncoder(nn.Module): + def __init__(self): + super().__init__() + + def encode(self, *args, **kwargs): + raise NotImplementedError + + + +class ClassEmbedder(nn.Module): + def __init__(self, embed_dim, n_classes=1000, key='class'): + super().__init__() + self.key = key + self.embedding = nn.Embedding(n_classes, embed_dim) + + def forward(self, batch, key=None): + if key is None: + key = self.key + # this is for use in crossattn + c = batch[key][:, None] + c = self.embedding(c) + return c + + +class TransformerEmbedder(AbstractEncoder): + """Some transformer encoder layers""" + def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device="cuda"): + super().__init__() + self.device = device + self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len, + attn_layers=Encoder(dim=n_embed, depth=n_layer)) + + def forward(self, tokens): + tokens = tokens.to(self.device) # meh + z = self.transformer(tokens, return_embeddings=True) + return z + + def encode(self, x): + return self(x) + + +class BERTTokenizer(AbstractEncoder): + """ Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)""" + def __init__(self, device="cuda", vq_interface=True, max_length=77): + super().__init__() + from transformers import BertTokenizerFast # TODO: add to reuquirements + self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased") + self.device = device + self.vq_interface = vq_interface + self.max_length = max_length + + def forward(self, text): + batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, + return_overflowing_tokens=False, padding="max_length", return_tensors="pt") + tokens = batch_encoding["input_ids"].to(self.device) + return tokens + + @torch.no_grad() + def encode(self, text): + tokens = self(text) + if not self.vq_interface: + return tokens + return None, None, [None, None, tokens] + + def decode(self, text): + return text + + +class BERTEmbedder(AbstractEncoder): + """Uses the BERT tokenizr model and add some transformer encoder layers""" + def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77, + device="cuda",use_tokenizer=True, embedding_dropout=0.0): + super().__init__() + self.use_tknz_fn = use_tokenizer + if self.use_tknz_fn: + self.tknz_fn = BERTTokenizer(vq_interface=False, max_length=max_seq_len) + self.device = device + self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len, + attn_layers=Encoder(dim=n_embed, depth=n_layer), + emb_dropout=embedding_dropout) + + def forward(self, text): + if self.use_tknz_fn: + tokens = self.tknz_fn(text)#.to(self.device) + else: + tokens = text + z = self.transformer(tokens, return_embeddings=True) + return z + + def encode(self, text): + # output of length 77 + return self(text) + +class FrozenCLIPEmbedder(AbstractEncoder): + """Uses the CLIP transformer encoder for text (from Hugging Face)""" + def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77): + super().__init__() + self.tokenizer = CLIPTokenizer.from_pretrained(version) + self.transformer = CLIPTextModel.from_pretrained(version) + self.device = device + self.max_length = max_length + self.freeze() + + def freeze(self): + self.transformer = self.transformer.eval() + for param in self.parameters(): + param.requires_grad = False + + def forward(self, text): + batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, + return_overflowing_tokens=False, padding="max_length", return_tensors="pt") + tokens = batch_encoding["input_ids"].to(self.device) + outputs = self.transformer(input_ids=tokens) + + z = outputs.last_hidden_state + return z + + def encode(self, text): + return self(text) + + +class SpatialRescaler(nn.Module): + def __init__(self, + n_stages=1, + method='bilinear', + multiplier=0.5, + in_channels=3, + out_channels=None, + bias=False): + super().__init__() + self.n_stages = n_stages + assert self.n_stages >= 0 + assert method in ['nearest','linear','bilinear','trilinear','bicubic','area'] + self.multiplier = multiplier + self.interpolator = partial(torch.nn.functional.interpolate, mode=method) + self.remap_output = out_channels is not None + if self.remap_output: + print(f'Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing.') + self.channel_mapper = nn.Conv2d(in_channels,out_channels,1,bias=bias) + + def forward(self,x): + for stage in range(self.n_stages): + x = self.interpolator(x, scale_factor=self.multiplier) + + + if self.remap_output: + x = self.channel_mapper(x) + return x + + def encode(self, x): + return self(x) + +class FrozenCLIPTextEmbedder(nn.Module): + """ + Uses the CLIP transformer encoder for text. + """ + def __init__(self, version='ViT-L/14', device="cuda", max_length=77, n_repeat=1, normalize=True): + super().__init__() + self.model, _ = clip.load(version, jit=False, device="cpu") + self.device = device + self.max_length = max_length + self.n_repeat = n_repeat + self.normalize = normalize + + def freeze(self): + self.model = self.model.eval() + for param in self.parameters(): + param.requires_grad = False + + def forward(self, text): + tokens = clip.tokenize(text).to(self.device) + z = self.model.encode_text(tokens) + if self.normalize: + z = z / torch.linalg.norm(z, dim=1, keepdim=True) + return z + + def encode(self, text): + z = self(text) + if z.ndim==2: + z = z[:, None, :] + z = repeat(z, 'b 1 d -> b k d', k=self.n_repeat) + return z + +class FrozenCLIPImageEmbedder(AbstractEncoder): + """Uses the CLIP transformer encoder for text (from Hugging Face)""" + def __init__(self, version="openai/clip-vit-large-patch14"): + super().__init__() + self.transformer = CLIPVisionModel.from_pretrained(version) + self.final_ln = LayerNorm(1024) + self.mapper = Transformer( + 1, + 1024, + 5, + 1, + ) + + self.freeze() + + def freeze(self): + self.transformer = self.transformer.eval() + for param in self.parameters(): + param.requires_grad = False + for param in self.mapper.parameters(): + param.requires_grad = True + for param in self.final_ln.parameters(): + param.requires_grad = True + + def forward(self, image): + outputs = self.transformer(pixel_values=image) + z = outputs.pooler_output + z = z.unsqueeze(1) + z = self.mapper(z) + z = self.final_ln(z) + return z + + def encode(self, image): + return self(image) + + + +class DINOEmbedder(AbstractEncoder): + """Uses the CLIP transformer encoder for text (from Hugging Face)""" + def __init__(self, dino_version): # small, large, huge, gigantic + super().__init__() + assert dino_version in ['small', 'big', 'large', 'huge'] + letter_map = { + 'small': 's', + 'big': 'b', + 'large': 'l', + 'huge': 'g' + } + + self.final_ln = LayerNorm(32) # unused -- remove later + self.mapper = LayerNorm(32) # unused -- remove later + # embedding_sizes = { + # 'small': 384, + # 'big': 768, + # 'large': 1024, + # 'huge': 1536 + # } + + # embedding_size = embedding_sizes[dino_version] + letter = letter_map[dino_version] + # self.transformer = CLIPVisionModel.from_pretrained(version) + self.dino_model = torch.hub.load('facebookresearch/dinov2', f'dinov2_vit{letter}14_reg').cuda() + + + self.freeze() + + def freeze(self): + for param in self.parameters(): + param.requires_grad = False + + def forward(self, image): + with torch.no_grad(): + outputs = self.dino_model.forward_features(image) + patch_tokens = outputs['x_norm_patchtokens'] + global_token = outputs['x_norm_clstoken'].unsqueeze(1) + features = torch.concat([patch_tokens, global_token], dim=1) + return torch.zeros_like(features) + + def encode(self, image): + return self(image) + + +class FixedVector(AbstractEncoder): + """Uses the CLIP transformer encoder for text (from Hugging Face)""" + def __init__(self): # small, large, huge, gigantic + super().__init__() + self.final_ln = LayerNorm(32) + self.mapper = LayerNorm(32) + self.fixed_vector = nn.Parameter(torch.randn((1,1,768)), requires_grad=True).cuda() + def forward(self, image): + return self.fixed_vector.repeat(image.shape[0],1,1).to(image.device) * 0.0 + + def encode(self, image): + return self(image) + + + + +if __name__ == "__main__": + from ldm.util import count_params + model = FrozenCLIPEmbedder() + count_params(model, verbose=True) \ No newline at end of file diff --git a/ldm/modules/encoders/xf.py b/ldm/modules/encoders/xf.py new file mode 100644 index 0000000000000000000000000000000000000000..972f5e2fdc6adaea630451d3b34938f83db211d9 --- /dev/null +++ b/ldm/modules/encoders/xf.py @@ -0,0 +1,143 @@ +# This code is built from the Stable Diffusion repository: https://github.com/CompVis/stable-diffusion, and +# Paint-by-Example repo https://github.com/Fantasy-Studio/Paint-by-Example +# Copyright (c) 2022 Robin Rombach and Patrick Esser and contributors. +# CreativeML Open RAIL-M +# +# ========================================================================================== +# +# Adobe’s modifications are Copyright 2024 Adobe Research. All rights reserved. +# Adobe’s modifications are licensed under the Adobe Research License. To view a copy of the license, visit +# LICENSE.md. +# +# ========================================================================================== + +""" +Transformer implementation adapted from CLIP ViT: +https://github.com/openai/CLIP/blob/4c0275784d6d9da97ca1f47eaaee31de1867da91/clip/model.py +""" + +import math + +import torch as th +import torch.nn as nn + + +def convert_module_to_f16(l): + """ + Convert primitive modules to float16. + """ + if isinstance(l, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)): + l.weight.data = l.weight.data.half() + if l.bias is not None: + l.bias.data = l.bias.data.half() + + +class LayerNorm(nn.LayerNorm): + """ + Implementation that supports fp16 inputs but fp32 gains/biases. + """ + + def forward(self, x: th.Tensor): + return super().forward(x.float()).to(x.dtype) + + +class MultiheadAttention(nn.Module): + def __init__(self, n_ctx, width, heads): + super().__init__() + self.n_ctx = n_ctx + self.width = width + self.heads = heads + self.c_qkv = nn.Linear(width, width * 3) + self.c_proj = nn.Linear(width, width) + self.attention = QKVMultiheadAttention(heads, n_ctx) + + def forward(self, x): + x = self.c_qkv(x) + x = self.attention(x) + x = self.c_proj(x) + return x + + +class MLP(nn.Module): + def __init__(self, width): + super().__init__() + self.width = width + self.c_fc = nn.Linear(width, width * 4) + self.c_proj = nn.Linear(width * 4, width) + self.gelu = nn.GELU() + + def forward(self, x): + return self.c_proj(self.gelu(self.c_fc(x))) + + +class QKVMultiheadAttention(nn.Module): + def __init__(self, n_heads: int, n_ctx: int): + super().__init__() + self.n_heads = n_heads + self.n_ctx = n_ctx + + def forward(self, qkv): + bs, n_ctx, width = qkv.shape + attn_ch = width // self.n_heads // 3 + scale = 1 / math.sqrt(math.sqrt(attn_ch)) + qkv = qkv.view(bs, n_ctx, self.n_heads, -1) + q, k, v = th.split(qkv, attn_ch, dim=-1) + weight = th.einsum( + "bthc,bshc->bhts", q * scale, k * scale + ) # More stable with f16 than dividing afterwards + wdtype = weight.dtype + weight = th.softmax(weight.float(), dim=-1).type(wdtype) + return th.einsum("bhts,bshc->bthc", weight, v).reshape(bs, n_ctx, -1) + + +class ResidualAttentionBlock(nn.Module): + def __init__( + self, + n_ctx: int, + width: int, + heads: int, + ): + super().__init__() + + self.attn = MultiheadAttention( + n_ctx, + width, + heads, + ) + self.ln_1 = LayerNorm(width) + self.mlp = MLP(width) + self.ln_2 = LayerNorm(width) + + def forward(self, x: th.Tensor): + x = x + self.attn(self.ln_1(x)) + x = x + self.mlp(self.ln_2(x)) + return x + + +class Transformer(nn.Module): + def __init__( + self, + n_ctx: int, + width: int, + layers: int, + heads: int, + ): + super().__init__() + self.n_ctx = n_ctx + self.width = width + self.layers = layers + self.resblocks = nn.ModuleList( + [ + ResidualAttentionBlock( + n_ctx, + width, + heads, + ) + for _ in range(layers) + ] + ) + + def forward(self, x: th.Tensor): + for block in self.resblocks: + x = block(x) + return x diff --git a/ldm/modules/losses/__init__.py b/ldm/modules/losses/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..876d7c5bd6e3245ee77feb4c482b7a8143604ad5 --- /dev/null +++ b/ldm/modules/losses/__init__.py @@ -0,0 +1 @@ +from ldm.modules.losses.contperceptual import LPIPSWithDiscriminator \ No newline at end of file diff --git a/ldm/modules/losses/contperceptual.py b/ldm/modules/losses/contperceptual.py new file mode 100644 index 0000000000000000000000000000000000000000..d17074d102e97ecb5c0c94a8532681608b73efd7 --- /dev/null +++ b/ldm/modules/losses/contperceptual.py @@ -0,0 +1,124 @@ +# This code is built from the Stable Diffusion repository: https://github.com/CompVis/stable-diffusion, and +# Paint-by-Example repo https://github.com/Fantasy-Studio/Paint-by-Example +# Copyright (c) 2022 Robin Rombach and Patrick Esser and contributors. +# CreativeML Open RAIL-M +# +# ========================================================================================== +# +# Adobe’s modifications are Copyright 2024 Adobe Research. All rights reserved. +# Adobe’s modifications are licensed under the Adobe Research License. To view a copy of the license, visit +# LICENSE.md. +# +# ========================================================================================== + +import torch +import torch.nn as nn + +from taming.modules.losses.vqperceptual import * # TODO: taming dependency yes/no? + + +class LPIPSWithDiscriminator(nn.Module): + def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0, + disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, + perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, + disc_loss="hinge"): + + super().__init__() + assert disc_loss in ["hinge", "vanilla"] + self.kl_weight = kl_weight + self.pixel_weight = pixelloss_weight + self.perceptual_loss = LPIPS().eval() + self.perceptual_weight = perceptual_weight + # output log variance + self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init) + + self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, + n_layers=disc_num_layers, + use_actnorm=use_actnorm + ).apply(weights_init) + self.discriminator_iter_start = disc_start + self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss + self.disc_factor = disc_factor + self.discriminator_weight = disc_weight + self.disc_conditional = disc_conditional + + def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): + if last_layer is not None: + nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] + g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] + else: + nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] + g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] + + d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) + d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() + d_weight = d_weight * self.discriminator_weight + return d_weight + + def forward(self, inputs, reconstructions, posteriors, optimizer_idx, + global_step, last_layer=None, cond=None, split="train", + weights=None): + rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) + if self.perceptual_weight > 0: + p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) + rec_loss = rec_loss + self.perceptual_weight * p_loss + + nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar + weighted_nll_loss = nll_loss + if weights is not None: + weighted_nll_loss = weights*nll_loss + weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0] + nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] + kl_loss = posteriors.kl() + kl_loss = torch.sum(kl_loss) / kl_loss.shape[0] + + # now the GAN part + if optimizer_idx == 0: + # generator update + if cond is None: + assert not self.disc_conditional + logits_fake = self.discriminator(reconstructions.contiguous()) + else: + assert self.disc_conditional + logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) + g_loss = -torch.mean(logits_fake) + + if self.disc_factor > 0.0: + try: + d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) + except RuntimeError: + assert not self.training + d_weight = torch.tensor(0.0) + else: + d_weight = torch.tensor(0.0) + + disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) + loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss + + log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(), + "{}/kl_loss".format(split): kl_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(), + "{}/rec_loss".format(split): rec_loss.detach().mean(), + "{}/d_weight".format(split): d_weight.detach(), + "{}/disc_factor".format(split): torch.tensor(disc_factor), + "{}/g_loss".format(split): g_loss.detach().mean(), + } + return loss, log + + if optimizer_idx == 1: + # second pass for discriminator update + if cond is None: + logits_real = self.discriminator(inputs.contiguous().detach()) + logits_fake = self.discriminator(reconstructions.contiguous().detach()) + else: + logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) + logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) + + disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) + d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) + + log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), + "{}/logits_real".format(split): logits_real.detach().mean(), + "{}/logits_fake".format(split): logits_fake.detach().mean() + } + return d_loss, log + diff --git a/ldm/modules/losses/vqperceptual.py b/ldm/modules/losses/vqperceptual.py new file mode 100644 index 0000000000000000000000000000000000000000..e4f129500f3386497a3d23b5b7554d4eecab99f3 --- /dev/null +++ b/ldm/modules/losses/vqperceptual.py @@ -0,0 +1,180 @@ +# This code is built from the Stable Diffusion repository: https://github.com/CompVis/stable-diffusion, and +# Paint-by-Example repo https://github.com/Fantasy-Studio/Paint-by-Example +# Copyright (c) 2022 Robin Rombach and Patrick Esser and contributors. +# CreativeML Open RAIL-M +# +# ========================================================================================== +# +# Adobe’s modifications are Copyright 2024 Adobe Research. All rights reserved. +# Adobe’s modifications are licensed under the Adobe Research License. To view a copy of the license, visit +# LICENSE.md. +# +# ========================================================================================== + +import torch +from torch import nn +import torch.nn.functional as F +from einops import repeat + +from taming.modules.discriminator.model import NLayerDiscriminator, weights_init +from taming.modules.losses.lpips import LPIPS +from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss + + +def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights): + assert weights.shape[0] == logits_real.shape[0] == logits_fake.shape[0] + loss_real = torch.mean(F.relu(1. - logits_real), dim=[1,2,3]) + loss_fake = torch.mean(F.relu(1. + logits_fake), dim=[1,2,3]) + loss_real = (weights * loss_real).sum() / weights.sum() + loss_fake = (weights * loss_fake).sum() / weights.sum() + d_loss = 0.5 * (loss_real + loss_fake) + return d_loss + +def adopt_weight(weight, global_step, threshold=0, value=0.): + if global_step < threshold: + weight = value + return weight + + +def measure_perplexity(predicted_indices, n_embed): + # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py + # eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally + encodings = F.one_hot(predicted_indices, n_embed).float().reshape(-1, n_embed) + avg_probs = encodings.mean(0) + perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp() + cluster_use = torch.sum(avg_probs > 0) + return perplexity, cluster_use + +def l1(x, y): + return torch.abs(x-y) + + +def l2(x, y): + return torch.pow((x-y), 2) + + +class VQLPIPSWithDiscriminator(nn.Module): + def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0, + disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, + perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, + disc_ndf=64, disc_loss="hinge", n_classes=None, perceptual_loss="lpips", + pixel_loss="l1"): + super().__init__() + assert disc_loss in ["hinge", "vanilla"] + assert perceptual_loss in ["lpips", "clips", "dists"] + assert pixel_loss in ["l1", "l2"] + self.codebook_weight = codebook_weight + self.pixel_weight = pixelloss_weight + if perceptual_loss == "lpips": + print(f"{self.__class__.__name__}: Running with LPIPS.") + self.perceptual_loss = LPIPS().eval() + else: + raise ValueError(f"Unknown perceptual loss: >> {perceptual_loss} <<") + self.perceptual_weight = perceptual_weight + + if pixel_loss == "l1": + self.pixel_loss = l1 + else: + self.pixel_loss = l2 + + self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, + n_layers=disc_num_layers, + use_actnorm=use_actnorm, + ndf=disc_ndf + ).apply(weights_init) + self.discriminator_iter_start = disc_start + if disc_loss == "hinge": + self.disc_loss = hinge_d_loss + elif disc_loss == "vanilla": + self.disc_loss = vanilla_d_loss + else: + raise ValueError(f"Unknown GAN loss '{disc_loss}'.") + print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.") + self.disc_factor = disc_factor + self.discriminator_weight = disc_weight + self.disc_conditional = disc_conditional + self.n_classes = n_classes + + def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): + if last_layer is not None: + nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] + g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] + else: + nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] + g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] + + d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) + d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() + d_weight = d_weight * self.discriminator_weight + return d_weight + + def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx, + global_step, last_layer=None, cond=None, split="train", predicted_indices=None): + if not exists(codebook_loss): + codebook_loss = torch.tensor([0.]).to(inputs.device) + #rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) + rec_loss = self.pixel_loss(inputs.contiguous(), reconstructions.contiguous()) + if self.perceptual_weight > 0: + p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) + rec_loss = rec_loss + self.perceptual_weight * p_loss + else: + p_loss = torch.tensor([0.0]) + + nll_loss = rec_loss + #nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] + nll_loss = torch.mean(nll_loss) + + # now the GAN part + if optimizer_idx == 0: + # generator update + if cond is None: + assert not self.disc_conditional + logits_fake = self.discriminator(reconstructions.contiguous()) + else: + assert self.disc_conditional + logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) + g_loss = -torch.mean(logits_fake) + + try: + d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) + except RuntimeError: + assert not self.training + d_weight = torch.tensor(0.0) + + disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) + loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean() + + log = {"{}/total_loss".format(split): loss.clone().detach().mean(), + "{}/quant_loss".format(split): codebook_loss.detach().mean(), + "{}/nll_loss".format(split): nll_loss.detach().mean(), + "{}/rec_loss".format(split): rec_loss.detach().mean(), + "{}/p_loss".format(split): p_loss.detach().mean(), + "{}/d_weight".format(split): d_weight.detach(), + "{}/disc_factor".format(split): torch.tensor(disc_factor), + "{}/g_loss".format(split): g_loss.detach().mean(), + } + if predicted_indices is not None: + assert self.n_classes is not None + with torch.no_grad(): + perplexity, cluster_usage = measure_perplexity(predicted_indices, self.n_classes) + log[f"{split}/perplexity"] = perplexity + log[f"{split}/cluster_usage"] = cluster_usage + return loss, log + + if optimizer_idx == 1: + # second pass for discriminator update + if cond is None: + logits_real = self.discriminator(inputs.contiguous().detach()) + logits_fake = self.discriminator(reconstructions.contiguous().detach()) + else: + logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) + logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) + + disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) + d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) + + log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), + "{}/logits_real".format(split): logits_real.detach().mean(), + "{}/logits_fake".format(split): logits_fake.detach().mean() + } + return d_loss, log diff --git a/ldm/modules/x_transformer.py b/ldm/modules/x_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..2616d387c97d4516790b83356000578a2c596a08 --- /dev/null +++ b/ldm/modules/x_transformer.py @@ -0,0 +1,654 @@ +# This code is built from the Stable Diffusion repository: https://github.com/CompVis/stable-diffusion, and +# Paint-by-Example repo https://github.com/Fantasy-Studio/Paint-by-Example +# Copyright (c) 2022 Robin Rombach and Patrick Esser and contributors. +# CreativeML Open RAIL-M +# +# ========================================================================================== +# +# Adobe’s modifications are Copyright 2024 Adobe Research. All rights reserved. +# Adobe’s modifications are licensed under the Adobe Research License. To view a copy of the license, visit +# LICENSE.md. +# +# ========================================================================================== + +"""shout-out to https://github.com/lucidrains/x-transformers/tree/main/x_transformers""" +import torch +from torch import nn, einsum +import torch.nn.functional as F +from functools import partial +from inspect import isfunction +from collections import namedtuple +from einops import rearrange, repeat, reduce + +# constants + +DEFAULT_DIM_HEAD = 64 + +Intermediates = namedtuple('Intermediates', [ + 'pre_softmax_attn', + 'post_softmax_attn' +]) + +LayerIntermediates = namedtuple('Intermediates', [ + 'hiddens', + 'attn_intermediates' +]) + + +class AbsolutePositionalEmbedding(nn.Module): + def __init__(self, dim, max_seq_len): + super().__init__() + self.emb = nn.Embedding(max_seq_len, dim) + self.init_() + + def init_(self): + nn.init.normal_(self.emb.weight, std=0.02) + + def forward(self, x): + n = torch.arange(x.shape[1], device=x.device) + return self.emb(n)[None, :, :] + + +class FixedPositionalEmbedding(nn.Module): + def __init__(self, dim): + super().__init__() + inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim)) + self.register_buffer('inv_freq', inv_freq) + + def forward(self, x, seq_dim=1, offset=0): + t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq) + offset + sinusoid_inp = torch.einsum('i , j -> i j', t, self.inv_freq) + emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1) + return emb[None, :, :] + + +# helpers + +def exists(val): + return val is not None + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def always(val): + def inner(*args, **kwargs): + return val + return inner + + +def not_equals(val): + def inner(x): + return x != val + return inner + + +def equals(val): + def inner(x): + return x == val + return inner + + +def max_neg_value(tensor): + return -torch.finfo(tensor.dtype).max + + +# keyword argument helpers + +def pick_and_pop(keys, d): + values = list(map(lambda key: d.pop(key), keys)) + return dict(zip(keys, values)) + + +def group_dict_by_key(cond, d): + return_val = [dict(), dict()] + for key in d.keys(): + match = bool(cond(key)) + ind = int(not match) + return_val[ind][key] = d[key] + return (*return_val,) + + +def string_begins_with(prefix, str): + return str.startswith(prefix) + + +def group_by_key_prefix(prefix, d): + return group_dict_by_key(partial(string_begins_with, prefix), d) + + +def groupby_prefix_and_trim(prefix, d): + kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d) + kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items()))) + return kwargs_without_prefix, kwargs + + +# classes +class Scale(nn.Module): + def __init__(self, value, fn): + super().__init__() + self.value = value + self.fn = fn + + def forward(self, x, **kwargs): + x, *rest = self.fn(x, **kwargs) + return (x * self.value, *rest) + + +class Rezero(nn.Module): + def __init__(self, fn): + super().__init__() + self.fn = fn + self.g = nn.Parameter(torch.zeros(1)) + + def forward(self, x, **kwargs): + x, *rest = self.fn(x, **kwargs) + return (x * self.g, *rest) + + +class ScaleNorm(nn.Module): + def __init__(self, dim, eps=1e-5): + super().__init__() + self.scale = dim ** -0.5 + self.eps = eps + self.g = nn.Parameter(torch.ones(1)) + + def forward(self, x): + norm = torch.norm(x, dim=-1, keepdim=True) * self.scale + return x / norm.clamp(min=self.eps) * self.g + + +class RMSNorm(nn.Module): + def __init__(self, dim, eps=1e-8): + super().__init__() + self.scale = dim ** -0.5 + self.eps = eps + self.g = nn.Parameter(torch.ones(dim)) + + def forward(self, x): + norm = torch.norm(x, dim=-1, keepdim=True) * self.scale + return x / norm.clamp(min=self.eps) * self.g + + +class Residual(nn.Module): + def forward(self, x, residual): + return x + residual + + +class GRUGating(nn.Module): + def __init__(self, dim): + super().__init__() + self.gru = nn.GRUCell(dim, dim) + + def forward(self, x, residual): + gated_output = self.gru( + rearrange(x, 'b n d -> (b n) d'), + rearrange(residual, 'b n d -> (b n) d') + ) + + return gated_output.reshape_as(x) + + +# feedforward + +class GEGLU(nn.Module): + def __init__(self, dim_in, dim_out): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2) + + def forward(self, x): + x, gate = self.proj(x).chunk(2, dim=-1) + return x * F.gelu(gate) + + +class FeedForward(nn.Module): + def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): + super().__init__() + inner_dim = int(dim * mult) + dim_out = default(dim_out, dim) + project_in = nn.Sequential( + nn.Linear(dim, inner_dim), + nn.GELU() + ) if not glu else GEGLU(dim, inner_dim) + + self.net = nn.Sequential( + project_in, + nn.Dropout(dropout), + nn.Linear(inner_dim, dim_out) + ) + + def forward(self, x): + return self.net(x) + + +# attention. +class Attention(nn.Module): + def __init__( + self, + dim, + dim_head=DEFAULT_DIM_HEAD, + heads=8, + causal=False, + mask=None, + talking_heads=False, + sparse_topk=None, + use_entmax15=False, + num_mem_kv=0, + dropout=0., + on_attn=False + ): + super().__init__() + if use_entmax15: + raise NotImplementedError("Check out entmax activation instead of softmax activation!") + self.scale = dim_head ** -0.5 + self.heads = heads + self.causal = causal + self.mask = mask + + inner_dim = dim_head * heads + + self.to_q = nn.Linear(dim, inner_dim, bias=False) + self.to_k = nn.Linear(dim, inner_dim, bias=False) + self.to_v = nn.Linear(dim, inner_dim, bias=False) + self.dropout = nn.Dropout(dropout) + + # talking heads + self.talking_heads = talking_heads + if talking_heads: + self.pre_softmax_proj = nn.Parameter(torch.randn(heads, heads)) + self.post_softmax_proj = nn.Parameter(torch.randn(heads, heads)) + + # explicit topk sparse attention + self.sparse_topk = sparse_topk + + # entmax + #self.attn_fn = entmax15 if use_entmax15 else F.softmax + self.attn_fn = F.softmax + + # add memory key / values + self.num_mem_kv = num_mem_kv + if num_mem_kv > 0: + self.mem_k = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head)) + self.mem_v = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head)) + + # attention on attention + self.attn_on_attn = on_attn + self.to_out = nn.Sequential(nn.Linear(inner_dim, dim * 2), nn.GLU()) if on_attn else nn.Linear(inner_dim, dim) + + def forward( + self, + x, + context=None, + mask=None, + context_mask=None, + rel_pos=None, + sinusoidal_emb=None, + prev_attn=None, + mem=None + ): + b, n, _, h, talking_heads, device = *x.shape, self.heads, self.talking_heads, x.device + kv_input = default(context, x) + + q_input = x + k_input = kv_input + v_input = kv_input + + if exists(mem): + k_input = torch.cat((mem, k_input), dim=-2) + v_input = torch.cat((mem, v_input), dim=-2) + + if exists(sinusoidal_emb): + # in shortformer, the query would start at a position offset depending on the past cached memory + offset = k_input.shape[-2] - q_input.shape[-2] + q_input = q_input + sinusoidal_emb(q_input, offset=offset) + k_input = k_input + sinusoidal_emb(k_input) + + q = self.to_q(q_input) + k = self.to_k(k_input) + v = self.to_v(v_input) + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v)) + + input_mask = None + if any(map(exists, (mask, context_mask))): + q_mask = default(mask, lambda: torch.ones((b, n), device=device).bool()) + k_mask = q_mask if not exists(context) else context_mask + k_mask = default(k_mask, lambda: torch.ones((b, k.shape[-2]), device=device).bool()) + q_mask = rearrange(q_mask, 'b i -> b () i ()') + k_mask = rearrange(k_mask, 'b j -> b () () j') + input_mask = q_mask * k_mask + + if self.num_mem_kv > 0: + mem_k, mem_v = map(lambda t: repeat(t, 'h n d -> b h n d', b=b), (self.mem_k, self.mem_v)) + k = torch.cat((mem_k, k), dim=-2) + v = torch.cat((mem_v, v), dim=-2) + if exists(input_mask): + input_mask = F.pad(input_mask, (self.num_mem_kv, 0), value=True) + + dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale + mask_value = max_neg_value(dots) + + if exists(prev_attn): + dots = dots + prev_attn + + pre_softmax_attn = dots + + if talking_heads: + dots = einsum('b h i j, h k -> b k i j', dots, self.pre_softmax_proj).contiguous() + + if exists(rel_pos): + dots = rel_pos(dots) + + if exists(input_mask): + dots.masked_fill_(~input_mask, mask_value) + del input_mask + + if self.causal: + i, j = dots.shape[-2:] + r = torch.arange(i, device=device) + mask = rearrange(r, 'i -> () () i ()') < rearrange(r, 'j -> () () () j') + mask = F.pad(mask, (j - i, 0), value=False) + dots.masked_fill_(mask, mask_value) + del mask + + if exists(self.sparse_topk) and self.sparse_topk < dots.shape[-1]: + top, _ = dots.topk(self.sparse_topk, dim=-1) + vk = top[..., -1].unsqueeze(-1).expand_as(dots) + mask = dots < vk + dots.masked_fill_(mask, mask_value) + del mask + + attn = self.attn_fn(dots, dim=-1) + post_softmax_attn = attn + + attn = self.dropout(attn) + + if talking_heads: + attn = einsum('b h i j, h k -> b k i j', attn, self.post_softmax_proj).contiguous() + + out = einsum('b h i j, b h j d -> b h i d', attn, v) + out = rearrange(out, 'b h n d -> b n (h d)') + + intermediates = Intermediates( + pre_softmax_attn=pre_softmax_attn, + post_softmax_attn=post_softmax_attn + ) + + return self.to_out(out), intermediates + + +class AttentionLayers(nn.Module): + def __init__( + self, + dim, + depth, + heads=8, + causal=False, + cross_attend=False, + only_cross=False, + use_scalenorm=False, + use_rmsnorm=False, + use_rezero=False, + rel_pos_num_buckets=32, + rel_pos_max_distance=128, + position_infused_attn=False, + custom_layers=None, + sandwich_coef=None, + par_ratio=None, + residual_attn=False, + cross_residual_attn=False, + macaron=False, + pre_norm=True, + gate_residual=False, + **kwargs + ): + super().__init__() + ff_kwargs, kwargs = groupby_prefix_and_trim('ff_', kwargs) + attn_kwargs, _ = groupby_prefix_and_trim('attn_', kwargs) + + dim_head = attn_kwargs.get('dim_head', DEFAULT_DIM_HEAD) + + self.dim = dim + self.depth = depth + self.layers = nn.ModuleList([]) + + self.has_pos_emb = position_infused_attn + self.pia_pos_emb = FixedPositionalEmbedding(dim) if position_infused_attn else None + self.rotary_pos_emb = always(None) + + assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must be less than the relative position max distance' + self.rel_pos = None + + self.pre_norm = pre_norm + + self.residual_attn = residual_attn + self.cross_residual_attn = cross_residual_attn + + norm_class = ScaleNorm if use_scalenorm else nn.LayerNorm + norm_class = RMSNorm if use_rmsnorm else norm_class + norm_fn = partial(norm_class, dim) + + norm_fn = nn.Identity if use_rezero else norm_fn + branch_fn = Rezero if use_rezero else None + + if cross_attend and not only_cross: + default_block = ('a', 'c', 'f') + elif cross_attend and only_cross: + default_block = ('c', 'f') + else: + default_block = ('a', 'f') + + if macaron: + default_block = ('f',) + default_block + + if exists(custom_layers): + layer_types = custom_layers + elif exists(par_ratio): + par_depth = depth * len(default_block) + assert 1 < par_ratio <= par_depth, 'par ratio out of range' + default_block = tuple(filter(not_equals('f'), default_block)) + par_attn = par_depth // par_ratio + depth_cut = par_depth * 2 // 3 # 2 / 3 attention layer cutoff suggested by PAR paper + par_width = (depth_cut + depth_cut // par_attn) // par_attn + assert len(default_block) <= par_width, 'default block is too large for par_ratio' + par_block = default_block + ('f',) * (par_width - len(default_block)) + par_head = par_block * par_attn + layer_types = par_head + ('f',) * (par_depth - len(par_head)) + elif exists(sandwich_coef): + assert sandwich_coef > 0 and sandwich_coef <= depth, 'sandwich coefficient should be less than the depth' + layer_types = ('a',) * sandwich_coef + default_block * (depth - sandwich_coef) + ('f',) * sandwich_coef + else: + layer_types = default_block * depth + + self.layer_types = layer_types + self.num_attn_layers = len(list(filter(equals('a'), layer_types))) + + for layer_type in self.layer_types: + if layer_type == 'a': + layer = Attention(dim, heads=heads, causal=causal, **attn_kwargs) + elif layer_type == 'c': + layer = Attention(dim, heads=heads, **attn_kwargs) + elif layer_type == 'f': + layer = FeedForward(dim, **ff_kwargs) + layer = layer if not macaron else Scale(0.5, layer) + else: + raise Exception(f'invalid layer type {layer_type}') + + if isinstance(layer, Attention) and exists(branch_fn): + layer = branch_fn(layer) + + if gate_residual: + residual_fn = GRUGating(dim) + else: + residual_fn = Residual() + + self.layers.append(nn.ModuleList([ + norm_fn(), + layer, + residual_fn + ])) + + def forward( + self, + x, + context=None, + mask=None, + context_mask=None, + mems=None, + return_hiddens=False + ): + hiddens = [] + intermediates = [] + prev_attn = None + prev_cross_attn = None + + mems = mems.copy() if exists(mems) else [None] * self.num_attn_layers + + for ind, (layer_type, (norm, block, residual_fn)) in enumerate(zip(self.layer_types, self.layers)): + is_last = ind == (len(self.layers) - 1) + + if layer_type == 'a': + hiddens.append(x) + layer_mem = mems.pop(0) + + residual = x + + if self.pre_norm: + x = norm(x) + + if layer_type == 'a': + out, inter = block(x, mask=mask, sinusoidal_emb=self.pia_pos_emb, rel_pos=self.rel_pos, + prev_attn=prev_attn, mem=layer_mem) + elif layer_type == 'c': + out, inter = block(x, context=context, mask=mask, context_mask=context_mask, prev_attn=prev_cross_attn) + elif layer_type == 'f': + out = block(x) + + x = residual_fn(out, residual) + + if layer_type in ('a', 'c'): + intermediates.append(inter) + + if layer_type == 'a' and self.residual_attn: + prev_attn = inter.pre_softmax_attn + elif layer_type == 'c' and self.cross_residual_attn: + prev_cross_attn = inter.pre_softmax_attn + + if not self.pre_norm and not is_last: + x = norm(x) + + if return_hiddens: + intermediates = LayerIntermediates( + hiddens=hiddens, + attn_intermediates=intermediates + ) + + return x, intermediates + + return x + + +class Encoder(AttentionLayers): + def __init__(self, **kwargs): + assert 'causal' not in kwargs, 'cannot set causality on encoder' + super().__init__(causal=False, **kwargs) + + + +class TransformerWrapper(nn.Module): + def __init__( + self, + *, + num_tokens, + max_seq_len, + attn_layers, + emb_dim=None, + max_mem_len=0., + emb_dropout=0., + num_memory_tokens=None, + tie_embedding=False, + use_pos_emb=True + ): + super().__init__() + assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder' + + dim = attn_layers.dim + emb_dim = default(emb_dim, dim) + + self.max_seq_len = max_seq_len + self.max_mem_len = max_mem_len + self.num_tokens = num_tokens + + self.token_emb = nn.Embedding(num_tokens, emb_dim) + self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len) if ( + use_pos_emb and not attn_layers.has_pos_emb) else always(0) + self.emb_dropout = nn.Dropout(emb_dropout) + + self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity() + self.attn_layers = attn_layers + self.norm = nn.LayerNorm(dim) + + self.init_() + + self.to_logits = nn.Linear(dim, num_tokens) if not tie_embedding else lambda t: t @ self.token_emb.weight.t() + + # memory tokens (like [cls]) from Memory Transformers paper + num_memory_tokens = default(num_memory_tokens, 0) + self.num_memory_tokens = num_memory_tokens + if num_memory_tokens > 0: + self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim)) + + # let funnel encoder know number of memory tokens, if specified + if hasattr(attn_layers, 'num_memory_tokens'): + attn_layers.num_memory_tokens = num_memory_tokens + + def init_(self): + nn.init.normal_(self.token_emb.weight, std=0.02) + + def forward( + self, + x, + return_embeddings=False, + mask=None, + return_mems=False, + return_attn=False, + mems=None, + **kwargs + ): + b, n, device, num_mem = *x.shape, x.device, self.num_memory_tokens + x = self.token_emb(x) + x += self.pos_emb(x) + x = self.emb_dropout(x) + + x = self.project_emb(x) + + if num_mem > 0: + mem = repeat(self.memory_tokens, 'n d -> b n d', b=b) + x = torch.cat((mem, x), dim=1) + + # auto-handle masking after appending memory tokens + if exists(mask): + mask = F.pad(mask, (num_mem, 0), value=True) + + x, intermediates = self.attn_layers(x, mask=mask, mems=mems, return_hiddens=True, **kwargs) + x = self.norm(x) + + mem, x = x[:, :num_mem], x[:, num_mem:] + + out = self.to_logits(x) if not return_embeddings else x + + if return_mems: + hiddens = intermediates.hiddens + new_mems = list(map(lambda pair: torch.cat(pair, dim=-2), zip(mems, hiddens))) if exists(mems) else hiddens + new_mems = list(map(lambda t: t[..., -self.max_mem_len:, :].detach(), new_mems)) + return out, new_mems + + if return_attn: + attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates)) + return out, attn_maps + + return out + diff --git a/ldm/util.py b/ldm/util.py new file mode 100644 index 0000000000000000000000000000000000000000..eafc2851d1e1e677b94aa0470ffbc451c37af90b --- /dev/null +++ b/ldm/util.py @@ -0,0 +1,216 @@ +# This code is built from the Stable Diffusion repository: https://github.com/CompVis/stable-diffusion, and +# Paint-by-Example repo https://github.com/Fantasy-Studio/Paint-by-Example +# Copyright (c) 2022 Robin Rombach and Patrick Esser and contributors. +# CreativeML Open RAIL-M +# +# ========================================================================================== +# +# Adobe’s modifications are Copyright 2024 Adobe Research. All rights reserved. +# Adobe’s modifications are licensed under the Adobe Research License. To view a copy of the license, visit +# LICENSE.md. +# +# ========================================================================================== + +import importlib + +import torch +import numpy as np +from collections import abc +from einops import rearrange +from functools import partial + +import multiprocessing as mp +from threading import Thread +from queue import Queue + +from inspect import isfunction +from PIL import Image, ImageDraw, ImageFont + + +def log_txt_as_img(wh, xc, size=10): + # wh a tuple of (width, height) + # xc a list of captions to plot + b = len(xc) + txts = list() + for bi in range(b): + txt = Image.new("RGB", wh, color="white") + draw = ImageDraw.Draw(txt) + font = ImageFont.truetype('data/DejaVuSans.ttf', size=size) + nc = int(40 * (wh[0] / 256)) + lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc)) + + try: + draw.text((0, 0), lines, fill="black", font=font) + except UnicodeEncodeError: + print("Cant encode string for logging. Skipping.") + + txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0 + txts.append(txt) + txts = np.stack(txts) + txts = torch.tensor(txts) + return txts + + +def ismap(x): + if not isinstance(x, torch.Tensor): + return False + return (len(x.shape) == 4) and (x.shape[1] > 3) + + +def isimage(x): + if not isinstance(x, torch.Tensor): + return False + return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) + + +def exists(x): + return x is not None + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def mean_flat(tensor): + """ + https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86 + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +def count_params(model, verbose=False): + total_params = sum(p.numel() for p in model.parameters()) + if verbose: + print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.") + return total_params + + +def instantiate_from_config(config): + if not "target" in config: + if config == '__is_first_stage__': + return None + elif config == "__is_unconditional__": + return None + raise KeyError("Expected key `target` to instantiate.") + return get_obj_from_str(config["target"])(**config.get("params", dict())) + + +def get_obj_from_str(string, reload=False): + module, cls = string.rsplit(".", 1) + if reload: + module_imp = importlib.import_module(module) + importlib.reload(module_imp) + return getattr(importlib.import_module(module, package=None), cls) + + +def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False): + # create dummy dataset instance + + # run prefetching + if idx_to_fn: + res = func(data, worker_id=idx) + else: + res = func(data) + Q.put([idx, res]) + Q.put("Done") + + +def parallel_data_prefetch( + func: callable, data, n_proc, target_data_type="ndarray", cpu_intensive=True, use_worker_id=False +): + # if target_data_type not in ["ndarray", "list"]: + # raise ValueError( + # "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray." + # ) + if isinstance(data, np.ndarray) and target_data_type == "list": + raise ValueError("list expected but function got ndarray.") + elif isinstance(data, abc.Iterable): + if isinstance(data, dict): + print( + f'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.' + ) + data = list(data.values()) + if target_data_type == "ndarray": + data = np.asarray(data) + else: + data = list(data) + else: + raise TypeError( + f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}." + ) + + if cpu_intensive: + Q = mp.Queue(1000) + proc = mp.Process + else: + Q = Queue(1000) + proc = Thread + # spawn processes + if target_data_type == "ndarray": + arguments = [ + [func, Q, part, i, use_worker_id] + for i, part in enumerate(np.array_split(data, n_proc)) + ] + else: + step = ( + int(len(data) / n_proc + 1) + if len(data) % n_proc != 0 + else int(len(data) / n_proc) + ) + arguments = [ + [func, Q, part, i, use_worker_id] + for i, part in enumerate( + [data[i: i + step] for i in range(0, len(data), step)] + ) + ] + processes = [] + for i in range(n_proc): + p = proc(target=_do_parallel_data_prefetch, args=arguments[i]) + processes += [p] + + # start processes + print(f"Start prefetching...") + import time + + start = time.time() + gather_res = [[] for _ in range(n_proc)] + try: + for p in processes: + p.start() + + k = 0 + while k < n_proc: + # get result + res = Q.get() + if res == "Done": + k += 1 + else: + gather_res[res[0]] = res[1] + + except Exception as e: + print("Exception: ", e) + for p in processes: + p.terminate() + + raise e + finally: + for p in processes: + p.join() + print(f"Prefetching complete. [{time.time() - start} sec.]") + + if target_data_type == 'ndarray': + if not isinstance(gather_res[0], np.ndarray): + return np.concatenate([np.asarray(r) for r in gather_res], axis=0) + + # order outputs + return np.concatenate(gather_res, axis=0) + elif target_data_type == 'list': + out = [] + for r in gather_res: + out.extend(r) + return out + else: + return gather_res diff --git a/main.py b/main.py new file mode 100644 index 0000000000000000000000000000000000000000..f55236db4793f0b8f336bc2d5b945afce01edc84 --- /dev/null +++ b/main.py @@ -0,0 +1,818 @@ +# This code is built from the Stable Diffusion repository: https://github.com/CompVis/stable-diffusion, and +# Paint-by-Example repo https://github.com/Fantasy-Studio/Paint-by-Example +# Copyright (c) 2022 Robin Rombach and Patrick Esser and contributors. +# CreativeML Open RAIL-M +# +# ========================================================================================== +# +# Adobe’s modifications are Copyright 2024 Adobe Research. All rights reserved. +# Adobe’s modifications are licensed under the Adobe Research License. To view a copy of the license, visit +# LICENSE.md. +# +# ========================================================================================== + +import argparse, os, sys, datetime, glob, importlib, csv +import numpy as np +import time +import torch +import torchvision +import pytorch_lightning as pl + +from packaging import version +from omegaconf import OmegaConf +from torch.utils.data import random_split, DataLoader, Dataset, Subset +from functools import partial +from PIL import Image + +from pytorch_lightning import seed_everything +from pytorch_lightning.trainer import Trainer +from pytorch_lightning.callbacks import ModelCheckpoint, Callback, LearningRateMonitor +# from pytorch_lightning.utilities.distributed import rank_zero_only +from pytorch_lightning.utilities.rank_zero import rank_zero_only +from pytorch_lightning.utilities import rank_zero_info + +from ldm.data.base import Txt2ImgIterableBaseDataset +from ldm.util import instantiate_from_config +import socket +from pytorch_lightning.plugins.environments import ClusterEnvironment,SLURMEnvironment + +def get_parser(**parser_kwargs): + def str2bool(v): + if isinstance(v, bool): + return v + if v.lower() in ("yes", "true", "t", "y", "1"): + return True + elif v.lower() in ("no", "false", "f", "n", "0"): + return False + else: + raise argparse.ArgumentTypeError("Boolean value expected.") + + parser = argparse.ArgumentParser(**parser_kwargs) + parser.add_argument( + "-n", + "--name", + type=str, + const=True, + default="", + nargs="?", + help="postfix for logdir", + ) + parser.add_argument( + "-r", + "--resume", + type=str, + const=True, + default="", + nargs="?", + help="resume from logdir or checkpoint in logdir", + ) + parser.add_argument( + "-b", + "--base", + nargs="*", + metavar="base_config.yaml", + help="paths to base configs. Loaded from left-to-right. " + "Parameters can be overwritten or added with command-line options of the form `--key value`.", + default=["configs/stable-diffusion/v1-inference-inpaint.yaml"], + ) + parser.add_argument( + "-t", + "--train", + type=str2bool, + const=True, + default=True, + nargs="?", + help="train", + ) + parser.add_argument( + "--no-test", + type=str2bool, + const=True, + default=False, + nargs="?", + help="disable test", + ) + parser.add_argument( + "-p", + "--project", + help="name of new or path to existing project" + ) + parser.add_argument( + "-d", + "--debug", + type=str2bool, + nargs="?", + const=True, + default=False, + help="enable post-mortem debugging", + ) + parser.add_argument( + "-s", + "--seed", + type=int, + default=24, + help="seed for seed_everything", + ) + parser.add_argument( + "-f", + "--postfix", + type=str, + default="", + help="post-postfix for default name", + ) + parser.add_argument( + "-l", + "--logdir", + type=str, + default="logs", + help="directory for logging dat shit", + ) + parser.add_argument( + "--pretrained_model", + type=str, + default="", + help="path to pretrained model", + ) + parser.add_argument( + "--scale_lr", + type=str2bool, + nargs="?", + const=True, + default=True, + help="scale base-lr by ngpu * batch_size * n_accumulate", + ) + parser.add_argument( + "--train_from_scratch", + type=str2bool, + nargs="?", + const=True, + default=False, + help="Train from scratch", + ) + return parser + + +def nondefault_trainer_args(opt): + parser = argparse.ArgumentParser() + # parser = Trainer.add_argparse_args(parser) + args = parser.parse_args([]) + return sorted(k for k in vars(args) if getattr(opt, k) != getattr(args, k)) + + +class WrappedDataset(Dataset): + """Wraps an arbitrary object with __len__ and __getitem__ into a pytorch dataset""" + + def __init__(self, dataset): + self.data = dataset + + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + return self.data[idx] + + +def worker_init_fn(_): + worker_info = torch.utils.data.get_worker_info() + + dataset = worker_info.dataset + worker_id = worker_info.id + + if isinstance(dataset, Txt2ImgIterableBaseDataset): + split_size = dataset.num_records // worker_info.num_workers + # reset num_records to the true number to retain reliable length information + dataset.sample_ids = dataset.valid_ids[worker_id * split_size:(worker_id + 1) * split_size] + current_id = np.random.choice(len(np.random.get_state()[1]), 1) + return np.random.seed(np.random.get_state()[1][current_id] + worker_id) + else: + return np.random.seed(np.random.get_state()[1][0] + worker_id) + + +class DataModuleFromConfig(pl.LightningDataModule): + def __init__(self, batch_size, train=None, validation=None, test=None, predict=None, + wrap=False, num_workers=None, shuffle_test_loader=False, use_worker_init_fn=False, + shuffle_val_dataloader=False): + super().__init__() + self.batch_size = batch_size + self.dataset_configs = dict() + self.num_workers = num_workers if num_workers is not None else batch_size * 2 + self.use_worker_init_fn = use_worker_init_fn + if train is not None: + self.dataset_configs["train"] = train + self.train_dataloader = self._train_dataloader + if validation is not None: + self.dataset_configs["validation"] = validation + self.val_dataloader = partial(self._val_dataloader, shuffle=shuffle_val_dataloader) + if test is not None: + self.dataset_configs["test"] = test + self.test_dataloader = partial(self._test_dataloader, shuffle=shuffle_test_loader) + if predict is not None: + self.dataset_configs["predict"] = predict + self.predict_dataloader = self._predict_dataloader + self.wrap = wrap + + def prepare_data(self): + for data_cfg in self.dataset_configs.values(): + instantiate_from_config(data_cfg) + + def setup(self, stage=None): + self.datasets = dict( + (k, instantiate_from_config(self.dataset_configs[k])) + for k in self.dataset_configs) + if self.wrap: + for k in self.datasets: + self.datasets[k] = WrappedDataset(self.datasets[k]) + + def _train_dataloader(self): + is_iterable_dataset = isinstance(self.datasets['train'], Txt2ImgIterableBaseDataset) + if is_iterable_dataset or self.use_worker_init_fn: + init_fn = worker_init_fn + else: + init_fn = None + return DataLoader(self.datasets["train"], batch_size=self.batch_size, + num_workers=self.num_workers, shuffle=False if is_iterable_dataset else True, + worker_init_fn=init_fn) + + def _val_dataloader(self, shuffle=False): + if isinstance(self.datasets['validation'], Txt2ImgIterableBaseDataset) or self.use_worker_init_fn: + init_fn = worker_init_fn + else: + init_fn = None + return DataLoader(self.datasets["validation"], + batch_size=self.batch_size, + num_workers=self.num_workers, + worker_init_fn=init_fn, + shuffle=shuffle) + + def _test_dataloader(self, shuffle=False): + is_iterable_dataset = isinstance(self.datasets['train'], Txt2ImgIterableBaseDataset) + if is_iterable_dataset or self.use_worker_init_fn: + init_fn = worker_init_fn + else: + init_fn = None + + # do not shuffle dataloader for iterable dataset + shuffle = shuffle and (not is_iterable_dataset) + + return DataLoader(self.datasets["test"], batch_size=self.batch_size, + num_workers=self.num_workers, worker_init_fn=init_fn, shuffle=shuffle) + + def _predict_dataloader(self, shuffle=False): + if isinstance(self.datasets['predict'], Txt2ImgIterableBaseDataset) or self.use_worker_init_fn: + init_fn = worker_init_fn + else: + init_fn = None + return DataLoader(self.datasets["predict"], batch_size=self.batch_size, + num_workers=self.num_workers, worker_init_fn=init_fn) + + +class SetupCallback(Callback): + def __init__(self, resume, now, logdir, ckptdir, cfgdir, config, lightning_config): + super().__init__() + self.resume = resume + self.now = now + self.logdir = logdir + self.ckptdir = ckptdir + self.cfgdir = cfgdir + self.config = config + self.lightning_config = lightning_config + + def on_keyboard_interrupt(self, trainer, pl_module): + if trainer.global_rank == 0: + print("Summoning checkpoint.") + ckpt_path = os.path.join(self.ckptdir, "last.ckpt") + trainer.save_checkpoint(ckpt_path) + + def on_pretrain_routine_start(self, trainer, pl_module): + if trainer.global_rank == 0: + # Create logdirs and save configs + os.makedirs(self.logdir, exist_ok=True) + os.makedirs(self.ckptdir, exist_ok=True) + os.makedirs(self.cfgdir, exist_ok=True) + + if "callbacks" in self.lightning_config: + if 'metrics_over_trainsteps_checkpoint' in self.lightning_config['callbacks']: + os.makedirs(os.path.join(self.ckptdir, 'trainstep_checkpoints'), exist_ok=True) + print("Project config") + print(OmegaConf.to_yaml(self.config)) + OmegaConf.save(self.config, + os.path.join(self.cfgdir, "{}-project.yaml".format(self.now))) + + print("Lightning config") + print(OmegaConf.to_yaml(self.lightning_config)) + OmegaConf.save(OmegaConf.create({"lightning": self.lightning_config}), + os.path.join(self.cfgdir, "{}-lightning.yaml".format(self.now))) + + else: + # ModelCheckpoint callback created log directory --- remove it + if not self.resume and os.path.exists(self.logdir): + dst, name = os.path.split(self.logdir) + dst = os.path.join(dst, "child_runs", name) + os.makedirs(os.path.split(dst)[0], exist_ok=True) + try: + os.rename(self.logdir, dst) + except FileNotFoundError: + pass + + +class ImageLogger(Callback): + def __init__(self, batch_frequency, max_images, clamp=True, increase_log_steps=True, + rescale=True, disabled=False, log_on_batch_idx=False, log_first_step=False, + log_images_kwargs=None): + super().__init__() + self.rescale = rescale + self.batch_freq = batch_frequency + self.max_images = max_images + self.logger_log_images = { + pl.loggers.CSVLogger: self._testtube, + } + self.log_steps = [2 ** n for n in range(8, int(np.log2(self.batch_freq)) + 1)] + if not increase_log_steps: + self.log_steps = [self.batch_freq] + self.clamp = clamp + self.disabled = disabled + self.log_on_batch_idx = log_on_batch_idx + self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {} + self.log_first_step = log_first_step + self.validate_next = 0 + + @rank_zero_only + def _testtube(self, pl_module, images, batch_idx, split): + pass + # for k in images: + # grid = torchvision.utils.make_grid(images[k]) + # grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w + + # tag = f"{split}/{k}" + # pl_module.logger.experiment.add_image( + # tag, grid, + # global_step=pl_module.global_step) + + @rank_zero_only + def log_local(self, save_dir, split, images, + global_step, current_epoch, batch_idx): + root = os.path.join(save_dir, "images", split) + if split != 'train': + self.validate_next = 0 + for k in images: + grid = torchvision.utils.make_grid(images[k], nrow=4) + if self.rescale: + grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w + grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1) + grid = grid.numpy() + grid = (grid * 255).astype(np.uint8) + filename = "gs-{:06}_e-{:06}_b-{:06}_{}.jpg".format( + global_step, + current_epoch, + batch_idx, + k) + path = os.path.join(root, filename) + os.makedirs(os.path.split(path)[0], exist_ok=True) + Image.fromarray(grid).save(path) + + def log_img(self, pl_module, batch, batch_idx, split="train"): + check_idx = batch_idx if self.log_on_batch_idx else pl_module.global_step + if (self.check_frequency(check_idx) and # batch_idx % self.batch_freq == 0 + hasattr(pl_module, "log_images") and + callable(pl_module.log_images) and + self.max_images > 0) or (self.validate_next < 5 and split != "train"): + logger = type(pl_module.logger) + if split != "train": + self.validate_next += 1 + + is_train = pl_module.training + if is_train: + pl_module.eval() + + with torch.no_grad(): + images = pl_module.log_images(batch, split=split, **self.log_images_kwargs) + # print('images keys', images.keys()) + + for k in images: + N = min(images[k].shape[0], self.max_images) + images[k] = images[k][:N] + if isinstance(images[k], torch.Tensor): + images[k] = images[k].detach().cpu() + if self.clamp: + images[k] = torch.clamp(images[k], -1., 1.) + + self.log_local(pl_module.logger.save_dir, split, images, + pl_module.global_step, pl_module.current_epoch, batch_idx) + + logger_log_images = self.logger_log_images.get(logger, lambda *args, **kwargs: None) + logger_log_images(pl_module, images, pl_module.global_step, split) + + if is_train: + pl_module.train() + + def check_frequency(self, check_idx): + if ((check_idx % self.batch_freq) == 0 or (check_idx in self.log_steps)) and ( + check_idx > 0 or self.log_first_step): + try: + self.log_steps.pop(0) + except IndexError as e: + print(e) + pass + return True + return False + + def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): + if not self.disabled and (pl_module.global_step > 0 or self.log_first_step): + self.log_img(pl_module, batch, batch_idx, split="train") + + def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): + if not self.disabled and pl_module.global_step > 0: + self.log_img(pl_module, batch, batch_idx, split="val") + if hasattr(pl_module, 'calibrate_grad_norm'): + if (pl_module.calibrate_grad_norm and batch_idx % 25 == 0) and batch_idx > 0: + self.log_gradients(trainer, pl_module, batch_idx=batch_idx) + + +class CUDACallback(Callback): + # see https://github.com/SeanNaren/minGPT/blob/master/mingpt/callback.py + def on_train_epoch_start(self, trainer, pl_module): + # Reset the memory use counter + torch.cuda.reset_peak_memory_stats(trainer.strategy.root_device) + torch.cuda.synchronize(trainer.strategy.root_device) + self.start_time = time.time() + + def on_train_epoch_end(self, trainer, pl_module): + torch.cuda.synchronize(trainer.strategy.root_device) + max_memory = torch.cuda.max_memory_allocated(trainer.strategy.root_device) / 2 ** 20 + epoch_time = time.time() - self.start_time + + try: + max_memory = trainer.training_type_plugin.reduce(max_memory) + epoch_time = trainer.training_type_plugin.reduce(epoch_time) + + rank_zero_info(f"Average Epoch time: {epoch_time:.2f} seconds") + rank_zero_info(f"Average Peak memory {max_memory:.2f}MiB") + except AttributeError: + pass + + +if __name__ == "__main__": + + now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") + sys.path.append(os.getcwd()) + + torch.set_float32_matmul_precision('medium') + + parser = get_parser() + # parser = Trainer.add_argparse_args(parser) + + opt, unknown = parser.parse_known_args() + if opt.name and opt.resume: + raise ValueError( + "-n/--name and -r/--resume cannot be specified both." + "If you want to resume training in a new log folder, " + "use -n/--name in combination with --resume_from_checkpoint" + ) + if opt.resume: + if not os.path.exists(opt.resume): + raise ValueError("Cannot find {}".format(opt.resume)) + if os.path.isfile(opt.resume): + paths = opt.resume.split("/") + # idx = len(paths)-paths[::-1].index("logs")+1 + # logdir = "/".join(paths[:idx]) + logdir = "/".join(paths[:-2]) + ckpt = opt.resume + else: + assert os.path.isdir(opt.resume), opt.resume + logdir = opt.resume.rstrip("/") + ckpt = os.path.join(logdir, "checkpoints", "last.ckpt") + + opt.resume_from_checkpoint = ckpt + base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*.yaml"))) + opt.base = base_configs + opt.base + _tmp = logdir.split("/") + nowname = _tmp[-1] + else: + if opt.name: + name = "_" + opt.name + elif opt.base: + cfg_fname = os.path.split(opt.base[0])[-1] + cfg_name = os.path.splitext(cfg_fname)[0] + name = "_" + cfg_name + else: + name = "" + nowname = now + name + opt.postfix + logdir = os.path.join(opt.logdir, nowname) + + ckptdir = os.path.join(logdir, "checkpoints") + cfgdir = os.path.join(logdir, "configs") + seed_everything(opt.seed) + + # try: + # init and save configs + configs = [OmegaConf.load(cfg) for cfg in opt.base] + cli = OmegaConf.from_dotlist(unknown) + config = OmegaConf.merge(*configs, cli) + lightning_config = config.pop("lightning", OmegaConf.create()) + # merge trainer cli with config + trainer_config = lightning_config.get("trainer", OmegaConf.create()) + # default to ddp + trainer_config["accelerator"] = "ddp" + for k in nondefault_trainer_args(opt): + trainer_config[k] = getattr(opt, k) + if not "gpus" in trainer_config: + del trainer_config["accelerator"] + cpu = True + else: + gpuinfo = trainer_config["gpus"] + print(f"Running on GPUs {gpuinfo}") + cpu = False + trainer_opt = argparse.Namespace(**trainer_config) + lightning_config.trainer = trainer_config + + # model + def load_model_from_config(config, ckpt, verbose=False): + model = instantiate_from_config(config.model) + # print('NOTE: NO CHECKPOINT IS LOADED') + + if ckpt is not None: + print(f"Loading model from {ckpt}") + pl_sd = torch.load(ckpt, map_location="cpu") + if "global_step" in pl_sd: + print(f"Global Step: {pl_sd['global_step']}") + # sd = pl_sd["state_dict"] + + m, u = model.load_state_dict(sd, strict=False) + if len(m) > 0 and verbose: + print("missing keys:") + print(m) + if len(u) > 0 and verbose: + print("unexpected keys:") + print(u) + else: + print('NO CHECKPOINT LOADED') + + model.train() + return model + + + def get_model(config_path, ckpt_path): + config = OmegaConf.load(f"{config_path}") + model = load_model_from_config(config, None) + model.load_state_dict(torch.load('pretrained_models/sd-v1-4-modified-9channel.ckpt',map_location='cpu')['state_dict'],strict=False) + print(model.training) + # raise ValueError + pl_sd = torch.load(ckpt_path, map_location="cpu") + wrapped_state_dict = pl_sd #self.lightning_module.trainer.model.state_dict() + new_sd = {k.replace("_forward_module.", ""): wrapped_state_dict[k] for k in wrapped_state_dict} + + m, u = model.load_state_dict(new_sd, strict=False) + if len(m) > 0: + print("missing keys:") + print(m) + if len(u) > 0: + print("unexpected keys:") + print(u) + + + # model = model.to(device) + return model + config_path = 'configs/collage_flow_train.yaml' + + # model = get_model('configs/collage_mix_train.yaml', '/sensei-fs/users/halzayer/collage2photo/Paint-by-Example/official_checkpoint_image_attn_80k.pt') + + model = instantiate_from_config(config.model) + if not opt.resume: + if opt.train_from_scratch: + ckpt_file=torch.load(opt.pretrained_model,map_location='cpu')['state_dict'] + ckpt_file={key:value for key,value in ckpt_file.items() if not ( key[:6]=='model.')} + model.load_state_dict(ckpt_file,strict=False) + print("Train from scratch!") + else: + sd = torch.load(opt.pretrained_model,map_location='cpu')['state_dict'] + + new_dict = dict() + for key in sd: + if 'attn1' in key: + new_dict[key.replace('attn1', 'attn3')] = sd[key] + for key in sd: + if 'diffusion_model' in key: + new_dict[key.replace('diffusion_model', 'reference_model')] = sd[key] + for new_key in new_dict: + sd[new_key] = new_dict[new_key] + model.load_state_dict(sd ,strict=False) + print("Load Stable Diffusion v1-4!") + + # trainer and callbacks + trainer_kwargs = dict() + + # default logger configs + default_logger_cfgs = { + "wandb": { + "target": "pytorch_lightning.loggers.WandbLogger", + "params": { + "name": nowname, + "save_dir": logdir, + "offline": opt.debug, + "id": nowname, + } + }, + "testtube": { + "target": "pytorch_lightning.loggers.CSVLogger", + "params": { + "name": "testtube", + "save_dir": logdir, + } + }, + } + default_logger_cfg = default_logger_cfgs["testtube"] + if "logger" in lightning_config: + logger_cfg = lightning_config.logger + else: + logger_cfg = OmegaConf.create() + logger_cfg = OmegaConf.merge(default_logger_cfg, logger_cfg) + trainer_kwargs["logger"] = instantiate_from_config(logger_cfg) + + # modelcheckpoint - use TrainResult/EvalResult(checkpoint_on=metric) to + # specify which metric is used to determine best models + default_modelckpt_cfg = { + "target": "pytorch_lightning.callbacks.ModelCheckpoint", + "params": { + "dirpath": ckptdir, + "filename": "{epoch:06}", + "verbose": True, + "save_last": True, + "every_n_train_steps": 2500, + "save_weights_only":True + } + } + if hasattr(model, "monitor"): + print(f"Monitoring {model.monitor} as checkpoint metric.") + default_modelckpt_cfg["params"]["monitor"] = model.monitor + default_modelckpt_cfg["params"]["save_top_k"] = 1 + default_modelckpt_cfg["params"]["save_weights_only"] = True + + + if "modelcheckpoint" in lightning_config: + modelckpt_cfg = lightning_config.modelcheckpoint + else: + modelckpt_cfg = OmegaConf.create() + modelckpt_cfg = OmegaConf.merge(default_modelckpt_cfg, modelckpt_cfg) + print(f"Merged modelckpt-cfg: \n{modelckpt_cfg}") + if version.parse(pl.__version__) < version.parse('1.4.0'): + trainer_kwargs["checkpoint_callback"] = instantiate_from_config(modelckpt_cfg) + + # add callback which sets up log directory + default_callbacks_cfg = { + "setup_callback": { + "target": "main.SetupCallback", + "params": { + "resume": opt.resume, + "now": now, + "logdir": logdir, + "ckptdir": ckptdir, + "cfgdir": cfgdir, + "config": config, + "lightning_config": lightning_config, + } + }, + "image_logger": { + "target": "main.ImageLogger", + "params": { + "batch_frequency": 500, + "max_images": 4, + "clamp": True + } + }, + "learning_rate_logger": { + "target": "main.LearningRateMonitor", + "params": { + "logging_interval": "step", + # "log_momentum": True + } + }, + "cuda_callback": { + "target": "main.CUDACallback" + }, + } + if version.parse(pl.__version__) >= version.parse('1.4.0'): + default_callbacks_cfg.update({'checkpoint_callback': modelckpt_cfg}) + + if "callbacks" in lightning_config: + callbacks_cfg = lightning_config.callbacks + else: + callbacks_cfg = OmegaConf.create() + + if 'metrics_over_trainsteps_checkpoint' in callbacks_cfg: + print( + 'Caution: Saving checkpoints every n train steps without deleting. This might require some free space.') + default_metrics_over_trainsteps_ckpt_dict = { + 'metrics_over_trainsteps_checkpoint': + {"target": 'pytorch_lightning.callbacks.ModelCheckpoint', + 'params': { + "dirpath": os.path.join(ckptdir, 'trainstep_checkpoints'), + "filename": "{epoch:06}-{step:09}", + "verbose": True, + 'save_top_k': -1, + 'every_n_train_steps': 10000, + 'save_weights_only': True + } + } + } + default_callbacks_cfg.update(default_metrics_over_trainsteps_ckpt_dict) + + callbacks_cfg = OmegaConf.merge(default_callbacks_cfg, callbacks_cfg) + if 'ignore_keys_callback' in callbacks_cfg and hasattr(trainer_opt, 'resume_from_checkpoint'): + callbacks_cfg.ignore_keys_callback.params['ckpt_path'] = trainer_opt.resume_from_checkpoint + elif 'ignore_keys_callback' in callbacks_cfg: + del callbacks_cfg['ignore_keys_callback'] + + trainer_kwargs["callbacks"] = [instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg] + + + # max_epochs: 500 + # num_nodes: 1 + # num_sanity_val_steps: 0 + # accelerator: 'gpu' + # gpus: "0,1,2,3,4,5,6,7" + # trainer = Trainer.from_argparse_args(trainer_opt, **trainer_kwargs) + print('trainer kwargs', trainer_kwargs) + # TODO change val batches again + # trainer = Trainer(num_nodes=1, num_sanity_val_steps=1, max_epochs=300, accelerator='gpu', devices=1, + # limit_val_batches=1.0, limit_test_batches=1.0, **trainer_kwargs) + # trainer = Trainer(num_nodes=1, num_sanity_val_steps=1, max_epochs=300, accelerator='gpu', strategy='ddp_find_unused_parameters_true', devices=8, + # limit_val_batches=1.0, limit_test_batches=1.0, **trainer_kwargs) + + # this is the right one + trainer = Trainer(num_nodes=1, num_sanity_val_steps=1, max_epochs=8, accelerator='gpu', strategy="deepspeed_stage_1", precision="32", devices=8, + limit_val_batches=1.0, limit_test_batches=1.0, **trainer_kwargs) + # trainer = Trainer(num_nodes=1, num_sanity_val_steps=1, max_epochs=10000, accelerator='gpu', strategy="deepspeed_stage_1", precision="32", devices=8, + # overfit_batches=20,limit_val_batches=0.5, limit_test_batches=0.5, check_val_every_n_epoch=1000, **trainer_kwargs) + # trainer.plugins = [MyCluster()] + trainer.logdir = logdir ### + + # data + data = instantiate_from_config(config.data) + # NOTE according to https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html + # calling these ourselves should not be necessary but it is. + # lightning still takes care of proper multiprocessing though + data.prepare_data() + data.setup() + print("#### Data #####") + for k in data.datasets: + print(f"{k}, {data.datasets[k].__class__.__name__}, {len(data.datasets[k])}") + + # configure learning rate + bs, base_lr = config.data.params.batch_size, config.model.base_learning_rate + if not cpu: + ngpu = len(lightning_config.trainer.gpus.strip(",").split(',')) + else: + ngpu = 1 + if 'accumulate_grad_batches' in lightning_config.trainer: + accumulate_grad_batches = lightning_config.trainer.accumulate_grad_batches + else: + accumulate_grad_batches = 1 + # if 'num_nodes' in lightning_config.trainer: + # num_nodes = lightning_config.trainer.num_nodes + # else: + num_nodes = 1 + print(f"accumulate_grad_batches = {accumulate_grad_batches}") + lightning_config.trainer.accumulate_grad_batches = accumulate_grad_batches + if opt.scale_lr: + model.learning_rate = accumulate_grad_batches * num_nodes * ngpu * bs * base_lr + print( + "Setting learning rate to {:.2e} = {} (accumulate_grad_batches) * {} (num_nodes) * {} (num_gpus) * {} (batchsize) * {:.2e} (base_lr)".format( + model.learning_rate, accumulate_grad_batches, num_nodes, ngpu, bs, base_lr)) + else: + model.learning_rate = base_lr + print("++++ NOT USING LR SCALING ++++") + print(f"Setting learning rate to {model.learning_rate:.2e}") + + + # allow checkpointing via USR1 + def melk(*args, **kwargs): + # run all checkpoint hooks + if trainer.global_rank == 0: + print("Summoning checkpoint.") + ckpt_path = os.path.join(ckptdir, "last.ckpt") + trainer.save_checkpoint(ckpt_path) + + + def divein(*args, **kwargs): + if trainer.global_rank == 0: + import pudb; + pudb.set_trace() + + + import signal + + signal.signal(signal.SIGUSR1, melk) + signal.signal(signal.SIGUSR2, divein) + + # run + if opt.train: + try: + if opt.resume: + trainer.fit(model, data, ckpt_path=opt.resume) + else: + trainer.fit(model, data) + except Exception: + melk() + raise + if not opt.no_test and not trainer.interrupted: + trainer.test(model, data) diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index 73d01db64bd054c7d21fd0bd79b3af087c468809..0000000000000000000000000000000000000000 --- a/requirements.txt +++ /dev/null @@ -1,6 +0,0 @@ -accelerate -diffusers -invisible_watermark -torch -transformers -xformers \ No newline at end of file diff --git a/run_magicfu.py b/run_magicfu.py new file mode 100644 index 0000000000000000000000000000000000000000..c4cf0161601e3cd77fe0b571336e12ab8dec235f --- /dev/null +++ b/run_magicfu.py @@ -0,0 +1,319 @@ +# Copyright 2024 Adobe. All rights reserved. + +#%% +import cv2 +import torch +import numpy as np +from omegaconf import OmegaConf +from PIL import Image +from itertools import islice +from torch import autocast +import torchvision +from ldm.util import instantiate_from_config +from ldm.models.diffusion.ddim import DDIMSampler +from torchvision.transforms import Resize +import argparse +import os +import pathlib +import glob + + +device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + +def fix_img(test_img): + width, height = test_img.size + if width != height: + left = 0 + right = height + bottom = height + top = 0 + return test_img.crop((left, top, right, bottom)) + else: + return test_img +# util funcs +def chunk(it, size): + it = iter(it) + return iter(lambda: tuple(islice(it, size)), ()) + +def get_tensor_clip(normalize=True, toTensor=True): + transform_list = [] + if toTensor: + transform_list += [torchvision.transforms.ToTensor()] + + if normalize: + transform_list += [torchvision.transforms.Normalize((0.48145466, 0.4578275, 0.40821073), + (0.26862954, 0.26130258, 0.27577711))] + return torchvision.transforms.Compose(transform_list) + +def get_tensor_dino(normalize=True, toTensor=True): + transform_list = [torchvision.transforms.Resize((224,224))] + if toTensor: + transform_list += [torchvision.transforms.ToTensor()] + + if normalize: + transform_list += [lambda x: 255.0 * x[:3], + torchvision.transforms.Normalize( + mean=(123.675, 116.28, 103.53), + std=(58.395, 57.12, 57.375), + )] + return torchvision.transforms.Compose(transform_list) + +def get_tensor(normalize=True, toTensor=True): + transform_list = [] + if toTensor: + transform_list += [torchvision.transforms.ToTensor()] + + if normalize: + transform_list += [torchvision.transforms.Normalize((0.5, 0.5, 0.5), + (0.5, 0.5, 0.5))] + transform_list += [ + torchvision.transforms.Resize(512), + torchvision.transforms.CenterCrop(512) + ] + return torchvision.transforms.Compose(transform_list) + + +def numpy_to_pil(images): + """ + Convert a numpy image or a batch of images to a PIL image. + """ + if images.ndim == 3: + images = images[None, ...] + images = (images * 255).round().astype("uint8") + pil_images = [Image.fromarray(image) for image in images] + + return pil_images + + + +def load_model_from_config(config, ckpt, verbose=False): + model = instantiate_from_config(config.model) + # print('NOTE: NO CHECKPOINT IS LOADED') + + if ckpt is not None: + print(f"Loading model from {ckpt}") + pl_sd = torch.load(ckpt, map_location="cpu") + if "global_step" in pl_sd: + print(f"Global Step: {pl_sd['global_step']}") + # sd = pl_sd["state_dict"] + + m, u = model.load_state_dict(sd, strict=False) + if len(m) > 0 and verbose: + print("missing keys:") + print(m) + if len(u) > 0 and verbose: + print("unexpected keys:") + print(u) + + model.cuda() + model.eval() + return model + + +def get_model(config_path, ckpt_path): + config = OmegaConf.load(f"{config_path}") + model = load_model_from_config(config, None) + pl_sd = torch.load(ckpt_path, map_location="cpu") + + m, u = model.load_state_dict(pl_sd, strict=True) + if len(m) > 0: + print("WARNING: missing keys:") + print(m) + if len(u) > 0: + print("unexpected keys:") + print(u) + + + model = model.to(device) + return model + +def get_grid(size): + y = np.repeat(np.arange(size)[None, ...], size) + y = y.reshape(size, size) + x = y.transpose() + out = np.stack([y,x], -1) + return out + +def un_norm(x): + return (x+1.0)/2.0 + +class MagicFixup: + def __init__(self, model_path='/sensei-fs/users/halzayer/collage2photo/Paint-by-Example/official_checkpoint_image_attn_200k.pt'): + self.model = get_model('configs/collage_mix_train.yaml',model_path) + + + def edit_image(self, ref_image, coarse_edit, mask_tensor, start_step, steps): + # essentially sample + sampler = DDIMSampler(self.model) + + start_code = None + + transformed_grid = torch.zeros((2, 64, 64)) + + self.model.model.og_grid = None + self.model.model.transformed_grid = transformed_grid.unsqueeze(0).to(self.model.device) + + scale = 1.0 + C, f, H, W= 4, 8, 512, 512 + n_samples = 1 + ddim_steps = steps + ddim_eta = 1.0 + step = start_step + + with torch.no_grad(): + with autocast("cuda"): + with self.model.ema_scope(): + image_tensor = get_tensor(toTensor=False)(coarse_edit) + + clean_ref_tensor = get_tensor(toTensor=False)(ref_image) + clean_ref_tensor = clean_ref_tensor.unsqueeze(0) + + ref_tensor=get_tensor_dino(toTensor=False)(ref_image).unsqueeze(0) + + b_mask = mask_tensor.cpu() < 0.5 + + # inpainting + reference = un_norm(image_tensor) + reference = reference.squeeze() + ref_cv = torch.moveaxis(reference, 0, -1).cpu().numpy() + ref_cv = (ref_cv * 255).astype(np.uint8) + + cv_mask = b_mask.int().squeeze().cpu().numpy().astype(np.uint8) + kernel = np.ones((7,7)) + dilated_mask = cv2.dilate(cv_mask, kernel) + + dst = cv2.inpaint(ref_cv,dilated_mask,3,cv2.INPAINT_NS) + # dst = inpaint.inpaint_biharmonic(ref_cv, dilated_mask, channel_axis=-1) + dst_tensor = torch.tensor(dst).moveaxis(-1, 0) / 255.0 + image_tensor = (dst_tensor * 2.0) - 1.0 + image_tensor = image_tensor.unsqueeze(0) + + ref_tensor = ref_tensor + + inpaint_image = image_tensor#*mask_tensor + + test_model_kwargs={} + test_model_kwargs['inpaint_mask']=mask_tensor.to(device) + test_model_kwargs['inpaint_image']=inpaint_image.to(device) + clean_ref_tensor = clean_ref_tensor.to(device) + ref_tensor=ref_tensor.to(device) + uc = None + if scale != 1.0: + uc = self.model.learnable_vector + c = self.model.get_learned_conditioning(ref_tensor.to(torch.float16)) + c = self.model.proj_out(c) + + z_inpaint = self.model.encode_first_stage(test_model_kwargs['inpaint_image']) + z_inpaint = self.model.get_first_stage_encoding(z_inpaint).detach() + + + z_ref = self.model.encode_first_stage(clean_ref_tensor) + z_ref = self.model.get_first_stage_encoding(z_ref).detach() + + test_model_kwargs['inpaint_image']=z_inpaint + test_model_kwargs['inpaint_mask']=Resize([z_inpaint.shape[-2],z_inpaint.shape[-1]])(test_model_kwargs['inpaint_mask']) + + + shape = [C, H // f, W // f] + + samples_ddim, _ = sampler.sample(S=ddim_steps, + conditioning=c, + z_ref=z_ref, + batch_size=n_samples, + shape=shape, + verbose=False, + unconditional_guidance_scale=scale, + unconditional_conditioning=uc, + eta=ddim_eta, + x_T=start_code, + test_model_kwargs=test_model_kwargs, + x0=z_inpaint, + x0_step=step, + ddim_discretize='uniform', + drop_latent_guidance=1.0 + ) + + + x_samples_ddim = self.model.decode_first_stage(samples_ddim) + x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) + x_samples_ddim = x_samples_ddim.cpu().permute(0, 2, 3, 1).numpy() + + x_checked_image=x_samples_ddim + x_checked_image_torch = torch.from_numpy(x_checked_image).permute(0, 3, 1, 2) + + + return x_checked_image_torch +#%% + + +#%% +import time + + + +# %% +def file_exists(path): + """ Check if a file exists and is not a directory. """ + if not os.path.isfile(path): + raise argparse.ArgumentTypeError(f"{path} is not a valid file.") + return path + +def parse_arguments(): + """ Parses command-line arguments. """ + parser = argparse.ArgumentParser(description="Process images based on provided paths.") + parser.add_argument("--checkpoint", type=file_exists, required=True, help="Path to the MagicFixup checkpoint file.") + parser.add_argument("--reference", type=file_exists, default='examples/fox_drinking_og.png', help="Path to the reference original image.") + parser.add_argument("--edit", type=file_exists, default='examples/fox_drinking__edit__01.png', help="Path to the image edit. Make sure the alpha channel is set properly") + parser.add_argument("--output-dir", type=str, default='./outputs', help="Path to the folder where to save the outputs") + parser.add_argument("--samples", type=int, default=5, help="number of samples to output") + + return parser.parse_args() + + +def main(): + # Parse arguments + args = parse_arguments() + + # create magic fixup model + magic_fixup = MagicFixup(model_path=args.checkpoint) + output_dir = args.output_dir + + os.makedirs(output_dir, exist_ok=True) + + # run it here + + to_tensor = torchvision.transforms.ToTensor() + + + + ref_path = args.reference + coarse_edit_path = args.edit + mask_edit_path = coarse_edit_path + + edit_file_name = pathlib.Path(coarse_edit_path).stem + save_pattern = f'{output_dir}/{edit_file_name}__sample__*.png' + save_counter = len(glob.glob(save_pattern)) + + all_rgbs = [] + for i in range(args.samples): + with autocast("cuda"): + ref_image_t = to_tensor(Image.open(ref_path).convert('RGB').resize((512,512))).half().cuda() + coarse_edit_t = to_tensor(Image.open(coarse_edit_path).resize((512,512))).half().cuda() + # get mask from coarse + # mask_t = torch.ones_like(coarse_edit_t[-1][None, None,...]) + coarse_edit_mask_t = to_tensor(Image.open(mask_edit_path).resize((512,512))).half().cuda() + # get mask from coarse + mask_t = (coarse_edit_mask_t[-1][None, None,...]).half() # do center crop + coarse_edit_t_rgb = coarse_edit_t[:-1] + + out_rgb = magic_fixup.edit_image(ref_image_t, coarse_edit_t_rgb, mask_t, start_step=1.0, steps=50) + all_rgbs.append(out_rgb.squeeze().cpu().detach().float()) + + save_path = f'{output_dir}/{edit_file_name}__sample__{save_counter:03d}.png' + torchvision.utils.save_image(all_rgbs[i], save_path) + save_counter += 1 + + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/scripts/combine_model_params.py b/scripts/combine_model_params.py new file mode 100644 index 0000000000000000000000000000000000000000..ad691042e970e8f8c0a351fe587645e2fe3d5f6f --- /dev/null +++ b/scripts/combine_model_params.py @@ -0,0 +1,79 @@ +# Copyright 2024 Adobe. All rights reserved. +#%% +import matplotlib.pyplot as plt +import cv2 +import torch +import numpy as np +from omegaconf import OmegaConf +from PIL import Image +from tqdm import tqdm, trange +from imwatermark import WatermarkEncoder +from itertools import islice +import time +from pytorch_lightning import seed_everything +from torch import autocast +import torchvision +from ldm.util import instantiate_from_config +from ldm.models.diffusion.ddim import DDIMSampler +from torchvision.transforms import Resize +import argparse +import os +import pathlib +import glob +import tqdm + + +device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + +def load_model_from_config(config): + model = instantiate_from_config(config.model) + + model.cuda() + model.eval() + return model + + +def get_model(config_path, ckpt_path, pretrained_sd_path): + config = OmegaConf.load(f"{config_path}") + model = load_model_from_config(config) + model.load_state_dict(torch.load(pretrained_sd_path,map_location='cpu')['state_dict'],strict=False) + + + pl_sd = torch.load(ckpt_path, map_location="cpu") + wrapped_state_dict = pl_sd #self.lightning_module.trainer.model.state_dict() + new_sd = {k.replace("_forward_module.", ""): wrapped_state_dict[k] for k in wrapped_state_dict} + + m, u = model.load_state_dict(new_sd, strict=False) + if len(m) > 0: + print("missing keys:") + print(m) + if len(u) > 0: + print("unexpected keys:") + print(u) + + + model = model.to(device) + return model + +def file_exists(path): + """ Check if a file exists and is not a directory. """ + if not os.path.isfile(path): + raise argparse.ArgumentTypeError(f"{path} is not a valid file.") + return path + +def parse_arguments(): + """ Parses command-line arguments. """ + parser = argparse.ArgumentParser(description="Process images based on provided paths.") + parser.add_argument("--pretrained_sd", type=file_exists, required=True, help="Path to the SD1.4 pretrained checkpoint") + parser.add_argument("--learned_params", type=file_exists, required=True, help="Path to the MagicFixup learned parameters.") + parser.add_argument("--save_path", type=str, required=True, help="Path to save the full model state dict") + + return parser.parse_args() + +def main(): + args = parse_arguments() + model = get_model('configs/collage_mix_train.yaml',args.learned_params, args.pretrained_sd) + torch.save(model.state_dict(), args.save_path) + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/scripts/inference.py b/scripts/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..4863922268985290d0d05d75a84bf75c0a5c9fa6 --- /dev/null +++ b/scripts/inference.py @@ -0,0 +1,449 @@ +# Copyright 2024 Adobe. All rights reserved. +import argparse, os, sys, glob +# sys.path.append('.') +import cv2 +import torch +import numpy as np +from omegaconf import OmegaConf +from PIL import Image +from tqdm import tqdm, trange +from imwatermark import WatermarkEncoder +from itertools import islice +from einops import rearrange +from torchvision.utils import make_grid +import time +from pytorch_lightning import seed_everything +from torch import autocast +from contextlib import contextmanager, nullcontext +import torchvision +from ldm.util import instantiate_from_config +from ldm.models.diffusion.ddim import DDIMSampler +from ldm.models.diffusion.plms import PLMSSampler + +from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from transformers import AutoFeatureExtractor +import clip +from torchvision.transforms import Resize +import json +wm = "Paint-by-Example" +wm_encoder = WatermarkEncoder() +wm_encoder.set_watermark('bytes', wm.encode('utf-8')) +safety_model_id = "CompVis/stable-diffusion-safety-checker" +safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id) +safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id) + +def chunk(it, size): + it = iter(it) + return iter(lambda: tuple(islice(it, size)), ()) + +def get_tensor_clip(normalize=True, toTensor=True): + transform_list = [] + if toTensor: + transform_list += [torchvision.transforms.ToTensor()] + + if normalize: + transform_list += [torchvision.transforms.Normalize((0.48145466, 0.4578275, 0.40821073), + (0.26862954, 0.26130258, 0.27577711))] + return torchvision.transforms.Compose(transform_list) + +def numpy_to_pil(images): + """ + Convert a numpy image or a batch of images to a PIL image. + """ + if images.ndim == 3: + images = images[None, ...] + images = (images * 255).round().astype("uint8") + pil_images = [Image.fromarray(image) for image in images] + + return pil_images + + +def load_model_from_config(config, ckpt, verbose=False): + print(f"Loading model from {ckpt}") + pl_sd = torch.load(ckpt, map_location="cpu") + if "global_step" in pl_sd: + print(f"Global Step: {pl_sd['global_step']}") + sd = pl_sd["state_dict"] + model = instantiate_from_config(config.model) + # print('NOTE: NO CHECKPOINT IS LOADED') + m, u = model.load_state_dict(sd, strict=False) + if len(m) > 0 and verbose: + print("missing keys:") + print(m) + if len(u) > 0 and verbose: + print("unexpected keys:") + print(u) + + model.cuda() + model.eval() + return model + + +def put_watermark(img, wm_encoder=None): + if wm_encoder is not None: + img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR) + img = wm_encoder.encode(img, 'dwtDct') + img = Image.fromarray(img[:, :, ::-1]) + return img + + +def load_replacement(x): + try: + hwc = x.shape + y = Image.open("assets/rick.jpeg").convert("RGB").resize((hwc[1], hwc[0])) + y = (np.array(y)/255.0).astype(x.dtype) + assert y.shape == x.shape + return y + except Exception: + return x + + +def check_safety(x_image): + safety_checker_input = safety_feature_extractor(numpy_to_pil(x_image), return_tensors="pt") + x_checked_image, has_nsfw_concept = safety_checker(images=x_image, clip_input=safety_checker_input.pixel_values) + assert x_checked_image.shape[0] == len(has_nsfw_concept) + for i in range(len(has_nsfw_concept)): + if has_nsfw_concept[i]: + x_checked_image[i] = load_replacement(x_checked_image[i]) + return x_checked_image, has_nsfw_concept + +def get_tensor(normalize=True, toTensor=True): + transform_list = [] + if toTensor: + transform_list += [torchvision.transforms.ToTensor()] + + if normalize: + transform_list += [torchvision.transforms.Normalize((0.5, 0.5, 0.5), + (0.5, 0.5, 0.5))] + transform_list += [ + torchvision.transforms.Resize(512), + torchvision.transforms.CenterCrop(512) + ] + return torchvision.transforms.Compose(transform_list) + +def get_tensor_clip(normalize=True, toTensor=True): + transform_list = [] + if toTensor: + transform_list += [torchvision.transforms.ToTensor()] + + if normalize: + transform_list += [torchvision.transforms.Normalize((0.48145466, 0.4578275, 0.40821073), + (0.26862954, 0.26130258, 0.27577711))] + return torchvision.transforms.Compose(transform_list) + + +def main(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--outdir", + type=str, + nargs="?", + help="dir to write results to", + default="outputs/txt2img-samples" + ) + parser.add_argument( + "--skip_grid", + action='store_true', + help="do not save a grid, only individual samples. Helpful when evaluating lots of samples", + ) + parser.add_argument( + "--skip_save", + action='store_true', + help="do not save individual samples. For speed measurements.", + ) + parser.add_argument( + "--ddim_steps", + type=int, + default=50, + help="number of ddim sampling steps", + ) + parser.add_argument( + "--plms", + action='store_true', + help="use plms sampling", + ) + parser.add_argument( + "--fixed_code", + action='store_true', + help="if enabled, uses the same starting code across samples ", + ) + parser.add_argument( + "--ddim_eta", + type=float, + default=0.0, + help="ddim eta (eta=0.0 corresponds to deterministic sampling", + ) + parser.add_argument( + "--n_iter", + type=int, + default=2, + help="sample this often", + ) + parser.add_argument( + "--H", + type=int, + default=512, + help="image height, in pixel space", + ) + parser.add_argument( + "--W", + type=int, + default=512, + help="image width, in pixel space", + ) + parser.add_argument( + "--n_imgs", + type=int, + default=100, + help="image width, in pixel space", + ) + parser.add_argument( + "--C", + type=int, + default=4, + help="latent channels", + ) + parser.add_argument( + "--f", + type=int, + default=8, + help="downsampling factor", + ) + parser.add_argument( + "--n_samples", + type=int, + default=1, + help="how many samples to produce for each given reference image. A.k.a. batch size", + ) + parser.add_argument( + "--n_rows", + type=int, + default=0, + help="rows in the grid (default: n_samples)", + ) + parser.add_argument( + "--scale", + type=float, + default=1, + help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))", + ) + parser.add_argument( + "--config", + type=str, + default="", + help="path to config which constructs model", + ) + parser.add_argument( + "--ckpt", + type=str, + default="", + help="path to checkpoint of model", + ) + parser.add_argument( + "--seed", + type=int, + default=42, + help="the seed (for reproducible sampling)", + ) + parser.add_argument( + "--precision", + type=str, + help="evaluate at this precision", + choices=["full", "autocast"], + default="autocast" + ) + parser.add_argument( + "--image_path", + type=str, + help="evaluate at this precision", + default="" + ) + parser.add_argument( + "--mask_path", + type=str, + help="evaluate at this precision", + default="" + ) + parser.add_argument( + "--reference_path", + type=str, + help="evaluate at this precision", + default="" + ) + opt = parser.parse_args() + + + seed_everything(opt.seed) + + config = OmegaConf.load(f"{opt.config}") + model = load_model_from_config(config, f"{opt.ckpt}") + + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + model = model.to(device) + + if opt.plms: + sampler = PLMSSampler(model) + else: + sampler = DDIMSampler(model) + + os.makedirs(opt.outdir, exist_ok=True) + outpath = opt.outdir + + + batch_size = opt.n_samples + n_rows = opt.n_rows if opt.n_rows > 0 else batch_size + + sample_path = os.path.join(outpath, "source") + result_path = os.path.join(outpath, "results") + grid_path=os.path.join(outpath, "grid") + os.makedirs(sample_path, exist_ok=True) + os.makedirs(result_path, exist_ok=True) + os.makedirs(grid_path, exist_ok=True) + + + start_code = None + if opt.fixed_code: + start_code = torch.randn([opt.n_samples, opt.C, opt.H // opt.f, opt.W // opt.f], device=device) + + precision_scope = autocast if opt.precision=="autocast" else nullcontext + + # split_path = '' + # with open(split_path) as f: + # sample_paths = json.load(f) + + # np.random.seed(opt.seed) + # np.random.shuffle(sample_paths) + + # print(sample_paths[0]) + # raise ValueError + + with torch.no_grad(): + with precision_scope("cuda"): + for i in range(1): + with model.ema_scope(): + filename=os.path.basename(opt.image_path) + img_p = Image.open(opt.image_path).convert("RGB") + image_tensor = get_tensor()(img_p) + image_tensor = image_tensor.unsqueeze(0) + ref_p = Image.open(opt.reference_path).convert("RGB") + width, height = ref_p.size # Get dimensions + new_width = min(width, height) + new_height = new_width + + left = (width - new_width)/2 + top = (height - new_height)/2 + right = (width + new_width)/2 + bottom = (height + new_height)/2 + + # Crop the center of the image + ref_p = ref_p.crop((left, top, right, bottom)) + ref_p = ref_p.resize((224,224)) + ref_tensor=get_tensor_clip()(ref_p) + ref_tensor = ref_tensor.unsqueeze(0) + mask=Image.open(opt.mask_path).convert("L") + mask = mask.crop((left, top, right, bottom)) + mask = np.array(mask)[None,None] + mask = mask.astype(np.float32)/255.0 + mask[mask < 0.5] = 0 + mask[mask >= 0.5] = 1 + mask_tensor = torch.from_numpy(mask) + inpaint_image = image_tensor#*mask_tensor + # mask_tensor = torch.ones_like(inpaint_image) + # mask_tensor = mask_tensor[:, :1] # TODO PLEASE COMMENT OUT SOON + print('inpaint image size', inpaint_image.shape) + test_model_kwargs={} + test_model_kwargs['inpaint_mask']=mask_tensor.to(device) + test_model_kwargs['inpaint_image']=inpaint_image.to(device) + ref_tensor=ref_tensor.to(device) + uc = None + if opt.scale != 1.0: + uc = model.learnable_vector + c = model.get_learned_conditioning(ref_tensor.to(torch.float16)) + c = model.proj_out(c) + inpaint_mask=test_model_kwargs['inpaint_mask'] + z_inpaint = model.encode_first_stage(test_model_kwargs['inpaint_image']) + z_inpaint = model.get_first_stage_encoding(z_inpaint).detach() + test_model_kwargs['inpaint_image']=z_inpaint + test_model_kwargs['inpaint_mask']=Resize([z_inpaint.shape[-2],z_inpaint.shape[-1]])(test_model_kwargs['inpaint_mask']) + + shape = [opt.C, opt.H // opt.f, opt.W // opt.f] + + + # compute context here + + # ref_latent = model.encode_first_stage(img maybe) + # contexts = context_unet.compute_context() + + samples_ddim, _ = sampler.sample(S=opt.ddim_steps, + conditioning=c, + batch_size=opt.n_samples, + shape=shape, + verbose=False, + unconditional_guidance_scale=opt.scale, + unconditional_conditioning=uc, + eta=opt.ddim_eta, + x_T=start_code, + test_model_kwargs=test_model_kwargs) + + x_samples_ddim = model.decode_first_stage(samples_ddim) + x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) + x_samples_ddim = x_samples_ddim.cpu().permute(0, 2, 3, 1).numpy() + + # x_checked_image, has_nsfw_concept = check_safety(x_samples_ddim) + x_checked_image=x_samples_ddim + x_checked_image_torch = torch.from_numpy(x_checked_image).permute(0, 3, 1, 2) + + def un_norm(x): + return (x+1.0)/2.0 + def un_norm_clip(x): + x[0,:,:] = x[0,:,:] * 0.26862954 + 0.48145466 + x[1,:,:] = x[1,:,:] * 0.26130258 + 0.4578275 + x[2,:,:] = x[2,:,:] * 0.27577711 + 0.40821073 + return x + + if not opt.skip_save: + for i,x_sample in enumerate(x_checked_image_torch): + + + all_img=[] + all_img.append(un_norm(image_tensor[i]).cpu()) + all_img.append(un_norm(inpaint_image[i]).cpu()) + ref_img=ref_tensor + ref_img=Resize([opt.H, opt.W])(ref_img) + all_img.append(un_norm_clip(ref_img[i]).cpu()) + all_img.append(x_sample) + grid = torch.stack(all_img, 0) + grid = make_grid(grid) + grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy() + img = Image.fromarray(grid.astype(np.uint8)) + img = put_watermark(img, wm_encoder) + img.save(os.path.join(grid_path, 'grid-'+filename[:-4]+'_'+str(opt.seed)+f'_{i}.png')) + + + + x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c') + img = Image.fromarray(x_sample.astype(np.uint8)) + img = put_watermark(img, wm_encoder) + img.save(os.path.join(result_path, filename[:-4]+'_'+str(opt.seed)+f"_{i}.png")) + + mask_save=255.*rearrange(un_norm(inpaint_mask[i]).cpu(), 'c h w -> h w c').numpy() + mask_save= cv2.cvtColor(mask_save,cv2.COLOR_GRAY2RGB) + mask_save = Image.fromarray(mask_save.astype(np.uint8)) + mask_save.save(os.path.join(sample_path, filename[:-4]+'_'+str(opt.seed)+f"_mask_{i}.png")) + GT_img=255.*rearrange(all_img[0], 'c h w -> h w c').numpy() + GT_img = Image.fromarray(GT_img.astype(np.uint8)) + GT_img.save(os.path.join(sample_path, filename[:-4]+'_'+str(opt.seed)+f"_GT_{i}.png")) + inpaint_img=255.*rearrange(all_img[1], 'c h w -> h w c').numpy() + inpaint_img = Image.fromarray(inpaint_img.astype(np.uint8)) + inpaint_img.save(os.path.join(sample_path, filename[:-4]+'_'+str(opt.seed)+f"_inpaint_{i}.png")) + ref_img=255.*rearrange(all_img[2], 'c h w -> h w c').numpy() + ref_img = Image.fromarray(ref_img.astype(np.uint8)) + ref_img.save(os.path.join(sample_path, filename[:-4]+'_'+str(opt.seed)+f"_ref_{i}.png")) + + print(f"Your samples are ready and waiting for you here: \n{outpath} \n" + f" \nEnjoy.") + + +if __name__ == "__main__": + main() diff --git a/scripts/modify_checkpoints.py b/scripts/modify_checkpoints.py new file mode 100644 index 0000000000000000000000000000000000000000..aa4c9eed65cbe8cdc89f2b62553c45b1344e07b3 --- /dev/null +++ b/scripts/modify_checkpoints.py @@ -0,0 +1,8 @@ +# Copyright 2024 Adobe. All rights reserved. +import torch +pretrained_model_path='pretrained_models/sd-v1-4.ckpt' +ckpt_file=torch.load(pretrained_model_path,map_location='cpu') +zero_data=torch.zeros(320,5,3,3) +new_weight=torch.cat((ckpt_file['state_dict']['model.diffusion_model.input_blocks.0.0.weight'],zero_data),dim=1) +ckpt_file['state_dict']['model.diffusion_model.input_blocks.0.0.weight']=new_weight +torch.save(ckpt_file,"pretrained_models/sd-v1-4-modified-9channel.ckpt") \ No newline at end of file diff --git a/setup.py b/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..acb122156723518f4e3156a6664b15252158e14c --- /dev/null +++ b/setup.py @@ -0,0 +1,13 @@ +from setuptools import setup, find_packages + +setup( + name='Magic-FU', + version='0.0.1', + description='', + packages=find_packages(), + install_requires=[ + 'torch', + 'numpy', + 'tqdm', + ], +) diff --git a/train.sh b/train.sh new file mode 100644 index 0000000000000000000000000000000000000000..8bd454f16fdaff94511c4808170201e19ba6eca0 --- /dev/null +++ b/train.sh @@ -0,0 +1,6 @@ +python -u main.py \ +--logdir models/Paint-by-Example \ +--pretrained_model pretrained_models/sd-v1-4-modified-9channel.ckpt \ +--base configs/collage_mix_train.yaml \ +--scale_lr False \ +--name collage_mix_magic_fixup