RishabA committed on
Commit ddc10aa · verified · 1 Parent(s): 73319a4

Upload 4 files

Files changed (4)
  1. Conditioned_CelebA_Latent_Diffusion.ipynb +0 -0
  2. app.py +91 -0
  3. model.py +1823 -0
  4. requirements.txt +7 -0
Conditioned_CelebA_Latent_Diffusion.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
app.py ADDED
@@ -0,0 +1,91 @@
import torch
import gradio as gr
from model import UNet, VQVAE, sample_ddpm_inference
from huggingface_hub import hf_hub_download
import json

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

config_path = hf_hub_download(
    repo_id="RishabA/celeba-cond-ddpm", filename="config.json"
)
with open(config_path, "r") as f:
    config = json.load(f)

# Download checkpoint files. Adjust file paths if needed.
ldm_ckpt_path = hf_hub_download(
    repo_id="RishabA/celeba-cond-ddpm", filename="celebhq/ddpm_ckpt_class_cond.pth"
)
vae_ckpt_path = hf_hub_download(
    repo_id="RishabA/celeba-cond-ddpm", filename="celebhq/vqvae_autoencoder_ckpt.pth"
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
unet = UNet(config["autoencoder_params"]["z_channels"], config["ldm_params"]).to(device)
vae = VQVAE(
    config["dataset_params"]["image_channels"], config["autoencoder_params"]
).to(device)

# Load the pretrained weights
unet_state = torch.load(ldm_ckpt_path, map_location=device)
unet.load_state_dict(unet_state["model_state_dict"])

vae_state = torch.load(vae_ckpt_path, map_location=device)
vae.load_state_dict(vae_state["model_state_dict"])

unet.eval()
vae.eval()

print("Model and checkpoints loaded successfully!")
print(unet)
print(vae)


def generate_image(text_prompt, mask_upload):
    """
    text_prompt: A text prompt provided by the user.
    mask_upload: Either a PIL mask image (uploaded) or None.
    """
    return sample_ddpm_inference(unet, vae, text_prompt, mask_upload, device)


css_str = """
body {
    background-color: #f7f7f7;
}
.title {
    font-size: 48px;
    text-align: center;
    margin-top: 20px;
}
.description {
    font-size: 20px;
    text-align: center;
    margin-bottom: 40px;
}
"""

with gr.Blocks(css=css_str) as demo:
    gr.Markdown("<div class='title'>Conditioned Latent Diffusion with CelebA</div>")
    gr.Markdown(
        "<div class='description'>Enter a text prompt and (optionally) upload a mask image for conditioning; the model will generate an image accordingly.</div>"
    )
    with gr.Row():
        text_input = gr.Textbox(
            label="Text Prompt",
            lines=2,
            placeholder="E.g., 'He is a man with brown hair.'",
        )
        mask_input = gr.Image(type="pil", label="Optional Mask for Conditioning")
    generate_button = gr.Button("Generate Image")
    output_image = gr.Image(label="Generated Image", type="pil")

    generate_button.click(
        fn=generate_image,
        inputs=[text_input, mask_input],
        outputs=output_image,
    )

if __name__ == "__main__":
    demo.launch(share=True)
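
The same checkpoints can also be exercised without the Gradio UI. A minimal sketch (not part of the uploaded files), reusing the unet, vae, and device objects set up in app.py above and assuming sample_ddpm_inference returns a PIL image (its return type is not shown in this diff):

# Hypothetical direct call, bypassing the Gradio interface.
prompt = "She is a woman with blond hair."
mask = None  # or a PIL image with the CelebAMask-HQ mask layout
result = sample_ddpm_inference(unet, vae, prompt, mask, device)
result.save("sample.png")  # assumes a PIL.Image return value
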
model.py ADDED
@@ -0,0 +1,1823 @@
1
+ import math
2
+ import os
3
+ import random
4
+ import glob
5
+ import pickle
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ import torchvision.transforms as transforms
10
+ from torch.optim import Adam
11
+ from torchvision.utils import make_grid
12
+ from PIL import Image
13
+ from transformers import (
14
+ DistilBertModel,
15
+ DistilBertTokenizer,
16
+ CLIPTokenizer,
17
+ CLIPTextModel,
18
+ )
19
+
20
+ dataset_params = {
21
+ "image_path": "data/CelebAMask-HQ",
22
+ "image_channels": 3,
23
+ "image_size": 256,
24
+ "name": "celebhq",
25
+ }
26
+
27
+ diffusion_params = {
28
+ "num_timesteps": 1000,
29
+ "beta_start": 0.00085,
30
+ "beta_end": 0.012,
31
+ }
32
+
33
+ ldm_params = {
34
+ "down_channels": [256, 384, 512, 768],
35
+ "mid_channels": [768, 512],
36
+ "down_sample": [True, True, True],
37
+ "attn_down": [True, True, True], # Attention in the DownBlock and UpBlock of VQ-VAE
38
+ "time_emb_dim": 512,
39
+ "norm_channels": 32,
40
+ "num_heads": 16,
41
+ "conv_out_channels": 128,
42
+ "num_down_layers": 2,
43
+ "num_mid_layers": 2,
44
+ "num_up_layers": 2,
45
+ "condition_config": {
46
+ "condition_types": ["text", "image"],
47
+ "text_condition_config": {
48
+ "text_embed_model": "clip",
49
+ "train_text_embed_model": False,
50
+ "text_embed_dim": 512, # Each token should map to text_embed_dim sized vector
51
+ "cond_drop_prob": 0.1, # Probability of dropping conditioning during training to allow the model to generate images without conditioning as well
52
+ },
53
+ "image_condition_config": {
54
+ "image_condition_input_channels": 18, # CelebA has 18 classes excluding background
55
+ "image_condition_output_channels": 3,
56
+ "image_condition_h": 512, # Mask height
57
+ "image_condition_w": 512, # Mask width
58
+ "cond_drop_prob": 0.1, # Probability of dropping conditioning during training to allow the model to generate images without conditioning as well
59
+ },
60
+ },
61
+ }
62
+
63
+ autoencoder_params = {
64
+ "z_channels": 4,
65
+ "codebook_size": 8192,
66
+ "down_channels": [64, 128, 256, 256],
67
+ "mid_channels": [256, 256],
68
+ "down_sample": [True, True, True],
69
+ "attn_down": [
70
+ False,
71
+ False,
72
+ False,
73
+ ], # No attention in the DownBlock and UpBlock of VQ-VAE
74
+ "norm_channels": 32,
75
+ "num_heads": 4,
76
+ "num_down_layers": 2,
77
+ "num_mid_layers": 2,
78
+ "num_up_layers": 2,
79
+ }
80
+
81
+ train_params = {
82
+ "seed": 1111,
83
+ "task_name": "celebhq", # Folder to save models and images to
84
+ "ldm_batch_size": 16,
85
+ "autoencoder_batch_size": 4,
86
+ "disc_start": 15000,
87
+ "disc_weight": 0.5,
88
+ "codebook_weight": 1,
89
+ "commitment_beta": 0.2,
90
+ "perceptual_weight": 1,
91
+ "kl_weight": 0.000005,
92
+ "ldm_epochs": 100,
93
+ "autoencoder_epochs": 20,
94
+ "num_samples": 1,
95
+ "num_grid_rows": 1,
96
+ "ldm_lr": 0.000005,
97
+ "autoencoder_lr": 0.00001,
98
+ "autoencoder_acc_steps": 4,
99
+ "autoencoder_img_save_steps": 64,
100
+ "save_latents": True,
101
+ "cf_guidance_scale": 1.0,
102
+ "vqvae_latent_dir_name": "vqvae_latents",
103
+ "ldm_ckpt_name": "ddpm_ckpt_class_cond.pth",
104
+ "vqvae_autoencoder_ckpt_name": "vqvae_autoencoder_ckpt.pth",
105
+ }
106
+
107
+
108
+ def get_config_value(config, key, default_value):
109
+ return config[key] if key in config else default_value
110
+
111
+
112
+ def spatial_average(in_tens, keepdim=True):
113
+ return in_tens.mean([2, 3], keepdim=keepdim)
114
+
115
+
116
+ class LinearNoiseScheduler:
117
+ def __init__(self, num_timesteps, beta_start, beta_end):
118
+ self.num_timesteps = num_timesteps
119
+ self.beta_start = beta_start
120
+ self.beta_end = beta_end
121
+ self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_timesteps) ** 2
122
+ self.alphas = 1.0 - self.betas
123
+ self.alpha_cum_prod = torch.cumprod(self.alphas, dim=0)
124
+ self.sqrt_alpha_cum_prod = torch.sqrt(self.alpha_cum_prod)
125
+ self.sqrt_one_minus_alpha_cum_prod = torch.sqrt(1 - self.alpha_cum_prod)
126
+
127
+ def add_noise(self, original, noise, t):
128
+ # original: (batch_size, c, h, w), t: tensor of timesteps (batch_size,)
129
+ batch_size = original.shape[0]
130
+ sqrt_alpha_cum_prod = self.sqrt_alpha_cum_prod.to(original.device)[t].view(
131
+ batch_size, 1, 1, 1
132
+ )
133
+ sqrt_one_minus_alpha_cum_prod = self.sqrt_one_minus_alpha_cum_prod.to(
134
+ original.device
135
+ )[t].view(batch_size, 1, 1, 1)
136
+ return sqrt_alpha_cum_prod * original + sqrt_one_minus_alpha_cum_prod * noise
137
+
138
+ def sample_prev_timestep(self, xt, noise_pred, t):
139
+ batch_size = xt.shape[0]
140
+ alpha_cum_prod_t = self.alpha_cum_prod.to(xt.device)[t].view(
141
+ batch_size, 1, 1, 1
142
+ )
143
+ sqrt_one_minus_alpha_cum_prod_t = self.sqrt_one_minus_alpha_cum_prod.to(
144
+ xt.device
145
+ )[t].view(batch_size, 1, 1, 1)
146
+ x0 = (xt - sqrt_one_minus_alpha_cum_prod_t * noise_pred) / torch.sqrt(
147
+ alpha_cum_prod_t
148
+ )
149
+ x0 = torch.clamp(x0, -1.0, 1.0)
150
+ betas_t = self.betas.to(xt.device)[t].view(batch_size, 1, 1, 1)
151
+ mean = (
152
+ xt - betas_t / sqrt_one_minus_alpha_cum_prod_t * noise_pred
153
+ ) / torch.sqrt(self.alphas.to(xt.device)[t].view(batch_size, 1, 1, 1))
154
+ if t[0] == 0:
155
+ return mean, x0
156
+ else:
157
+ prev_alpha_cum_prod = self.alpha_cum_prod.to(xt.device)[
158
+ (t - 1).clamp(min=0)
159
+ ].view(batch_size, 1, 1, 1)
160
+ variance = (1 - prev_alpha_cum_prod) / (1 - alpha_cum_prod_t) * betas_t
161
+ sigma = variance.sqrt()
162
+ z = torch.randn_like(xt)
163
+ return mean + sigma * z, x0
164
+
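
As a quick illustration of how the scheduler above performs the forward (noising) process, here is a minimal sketch (not part of the uploaded file), using only diffusion_params and LinearNoiseScheduler from this module; the latent shape is illustrative:

# Hypothetical usage: noise a batch of latents at random timesteps.
scheduler = LinearNoiseScheduler(
    diffusion_params["num_timesteps"],
    diffusion_params["beta_start"],
    diffusion_params["beta_end"],
)
x0 = torch.randn(4, 4, 32, 32)  # e.g. VQ-VAE latents: z_channels=4, 256 / 2^3 = 32
noise = torch.randn_like(x0)
t = torch.randint(0, diffusion_params["num_timesteps"], (4,))
xt = scheduler.add_noise(x0, noise, t)  # same shape as x0
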
165
+
166
+ def get_tokenizer_and_model(model_type, device, eval_mode=True):
167
+ assert model_type in (
168
+ "bert",
169
+ "clip",
170
+ ), "Text model can only be one of 'clip' or 'bert'"
171
+ if model_type == "bert":
172
+ text_tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
173
+ text_model = DistilBertModel.from_pretrained("distilbert-base-uncased").to(
174
+ device
175
+ )
176
+ else:
177
+ text_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch16")
178
+ text_model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch16").to(
179
+ device
180
+ )
181
+ if eval_mode:
182
+ text_model.eval()
183
+ return text_tokenizer, text_model
184
+
185
+
186
+ def get_text_representation(text, text_tokenizer, text_model, device, max_length=77):
187
+ token_output = text_tokenizer(
188
+ text,
189
+ truncation=True,
190
+ padding="max_length",
191
+ return_attention_mask=True,
192
+ max_length=max_length,
193
+ )
194
+ tokens_tensor = torch.tensor(token_output["input_ids"]).to(device)
195
+ mask_tensor = torch.tensor(token_output["attention_mask"]).to(device)
196
+ text_embed = text_model(tokens_tensor, attention_mask=mask_tensor).last_hidden_state
197
+ return text_embed
198
+
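
To show what the text-conditioning helpers above produce, a minimal sketch (not part of the uploaded file); it downloads the CLIP checkpoint from the Hugging Face Hub on first use:

# Hypothetical usage: 77 CLIP tokens, each mapped to a 512-dim embedding
# (matching text_embed_dim in ldm_params above).
cpu = torch.device("cpu")
tokenizer, text_model = get_tokenizer_and_model("clip", cpu)
with torch.no_grad():
    emb = get_text_representation(
        ["He is a man with brown hair."], tokenizer, text_model, cpu
    )
print(emb.shape)  # torch.Size([1, 77, 512])
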
199
+
200
+ def get_time_embedding(time_steps, temb_dim):
201
+ """
202
+ Convert time steps tensor into an embedding using the sinusoidal time embedding formula
203
+ time_steps: 1D tensor of length batch size
204
+ temb_dim: Dimension of the embedding
205
+ """
206
+ assert temb_dim % 2 == 0, "time embedding dimension must be divisible by 2"
207
+
208
+ # factor = 10000^(2i/d_model)
209
+ factor = 10000 ** (
210
+ (
211
+ torch.arange(
212
+ start=0,
213
+ end=temb_dim // 2,
214
+ dtype=torch.float32,
215
+ device=time_steps.device,
216
+ )
217
+ / (temb_dim // 2)
218
+ )
219
+ )
220
+
221
+ t_emb = time_steps.unsqueeze(dim=-1).repeat(1, temb_dim // 2) / factor
222
+ t_emb = torch.cat([torch.sin(t_emb), torch.cos(t_emb)], dim=-1)
223
+
224
+ return t_emb # (batch_size, temb_dim)
225
+
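
A small shape check for the helper above (illustrative, not part of the uploaded file):

# Hypothetical usage: one sinusoidal embedding row per timestep.
t = torch.tensor([0, 10, 999])
emb = get_time_embedding(t, 512)
print(emb.shape)  # torch.Size([3, 512])
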
226
+
227
+ class DownBlock(nn.Module):
228
+ """
229
+ Down conv block with attention.
230
+ 1. Resnet block with time embedding
231
+ 2. Attention block
232
+ 3. Downsample
233
+
234
+ in_channels: Number of channels in the input feature map.
235
+ out_channels: Number of channels produced by this block.
236
+ t_emb_dim: Dimension of the time embedding. Only use for UNet for Diffusion. In an AutoEncoder, set it to None.
237
+ down_sample: Whether to apply downsampling at the end.
238
+ num_heads: Number of attention heads (used if attention is enabled).
239
+ num_layers: How many sub-blocks to apply in sequence.
240
+ attn: Whether to apply self-attention
241
+ norm_channels: Number of groups for GroupNorm.
242
+ cross_attn: Whether to apply cross-attention.
243
+ context_dim: If performing cross-attention, provide a context_dim for extra conditioning context.
244
+ """
245
+
246
+ def __init__(
247
+ self,
248
+ in_channels,
249
+ out_channels,
250
+ t_emb_dim,
251
+ down_sample,
252
+ num_heads,
253
+ num_layers,
254
+ attn,
255
+ norm_channels,
256
+ cross_attn=False,
257
+ context_dim=None,
258
+ ):
259
+ super().__init__()
260
+
261
+ self.num_layers = num_layers
262
+ self.down_sample = down_sample
263
+ self.attn = attn
264
+ self.context_dim = context_dim
265
+ self.cross_attn = cross_attn
266
+ self.t_emb_dim = t_emb_dim
267
+
268
+ self.resnet_conv_first = nn.ModuleList(
269
+ [
270
+ nn.Sequential(
271
+ nn.GroupNorm(
272
+ norm_channels, in_channels if i == 0 else out_channels
273
+ ), # Normalizes over channels. For the first sub-block, the in_channels=in_channels, else out_channels
274
+ nn.SiLU(),
275
+ nn.Conv2d(
276
+ in_channels=(in_channels if i == 0 else out_channels),
277
+ out_channels=out_channels,
278
+ kernel_size=3,
279
+ stride=1,
280
+ padding=1,
281
+ ), # (batch_size, c, h, w) -> (batch_size, out_channels, h, w)
282
+ )
283
+ for i in range(num_layers)
284
+ ]
285
+ )
286
+
287
+ # Only add the time embedding for diffusion and not AutoEncoder
288
+ if self.t_emb_dim is not None:
289
+ self.t_emb_layers = nn.ModuleList(
290
+ [
291
+ nn.Sequential(
292
+ nn.SiLU(),
293
+ nn.Linear(
294
+ in_features=self.t_emb_dim, out_features=out_channels
295
+ ), # (batch_size, t_emb_dim) -> (batch_size, out_channels)
296
+ )
297
+ for i in range(num_layers)
298
+ ]
299
+ )
300
+
301
+ self.resnet_conv_second = nn.ModuleList(
302
+ [
303
+ nn.Sequential(
304
+ nn.GroupNorm(norm_channels, out_channels),
305
+ nn.SiLU(),
306
+ nn.Conv2d(
307
+ in_channels=out_channels,
308
+ out_channels=out_channels,
309
+ kernel_size=3,
310
+ stride=1,
311
+ padding=1,
312
+ ), # (batch_size, out_channels, h, w) -> (batch_size, out_channels, h, w)
313
+ )
314
+ for i in range(num_layers)
315
+ ]
316
+ )
317
+
318
+ self.residual_input_conv = nn.ModuleList(
319
+ [
320
+ nn.Conv2d(
321
+ in_channels=(in_channels if i == 0 else out_channels),
322
+ out_channels=out_channels,
323
+ kernel_size=1,
324
+ stride=1,
325
+ padding=0,
326
+ ) # (batch_size, in_channels, h, w) -> (batch_size, out_channels, h, w)
327
+ for i in range(num_layers)
328
+ ]
329
+ )
330
+
331
+ if self.attn:
332
+ self.attention_norms = nn.ModuleList(
333
+ [nn.GroupNorm(norm_channels, out_channels) for i in range(num_layers)]
334
+ )
335
+
336
+ self.attentions = nn.ModuleList(
337
+ [
338
+ nn.MultiheadAttention(
339
+ embed_dim=out_channels, num_heads=num_heads, batch_first=True
340
+ )
341
+ for i in range(num_layers)
342
+ ]
343
+ )
344
+
345
+ # Cross attention for text conditioning
346
+ if self.cross_attn:
347
+ assert (
348
+ context_dim is not None
349
+ ), "Context Dimension must be passed for cross attention"
350
+
351
+ self.cross_attention_norms = nn.ModuleList(
352
+ [nn.GroupNorm(norm_channels, out_channels) for i in range(num_layers)]
353
+ )
354
+
355
+ self.cross_attentions = nn.ModuleList(
356
+ [
357
+ nn.MultiheadAttention(
358
+ embed_dim=out_channels, num_heads=num_heads, batch_first=True
359
+ )
360
+ for i in range(num_layers)
361
+ ]
362
+ )
363
+
364
+ self.context_proj = nn.ModuleList(
365
+ [
366
+ nn.Linear(in_features=context_dim, out_features=out_channels)
367
+ for i in range(num_layers)
368
+ ]
369
+ )
370
+
371
+ # Down sample by a factor of 2
372
+ self.down_sample_conv = (
373
+ nn.Conv2d(
374
+ in_channels=out_channels,
375
+ out_channels=out_channels,
376
+ kernel_size=4,
377
+ stride=2,
378
+ padding=1,
379
+ )
380
+ if self.down_sample
381
+ else nn.Identity()
382
+ ) # (batch_size, out_channels, h / 2, w / 2)
383
+
384
+ def forward(self, x, t_emb=None, context=None):
385
+ out = x
386
+ for i in range(self.num_layers):
387
+ # Resnet block of UNET
388
+ resnet_input = out # (batch_size, c, h, w)
389
+
390
+ out = self.resnet_conv_first[i](out) # (batch_size, out_channels, h, w)
391
+
392
+ # Only add the time embedding for diffusion and not AutoEncoder
393
+ if self.t_emb_dim is not None:
394
+ # Add the embeddings for timesteps - (batch_size, t_emb_dim) -> (batch_size, out_channels, 1, 1)
395
+ out = out + self.t_emb_layers[i](t_emb).unsqueeze(dim=-1).unsqueeze(
396
+ dim=-1
397
+ ) # (batch_size, out_channels, h, w)
398
+
399
+ out = self.resnet_conv_second[i](
400
+ out
401
+ ) # (batch_size, out_channels, h, w) -> (batch_size, out_channels, h, w)
402
+
403
+ # Residual Connection
404
+ out = out + self.residual_input_conv[i](
405
+ resnet_input
406
+ ) # (batch_size, out_channels, h, w)
407
+
408
+ # Only do for Diffusion and not for AutoEncoder
409
+ if self.attn:
410
+ # Attention block of UNET
411
+ batch_size, channels, h, w = (
412
+ out.shape
413
+ ) # (batch_size, out_channels, h, w)
414
+
415
+ in_attn = out.reshape(
416
+ batch_size, channels, h * w
417
+ ) # (batch_size, out_channels, h * w)
418
+ in_attn = self.attention_norms[i](in_attn)
419
+ in_attn = in_attn.transpose(1, 2) # (batch_size, h * w, out_channels)
420
+
421
+ # Self-Attention
422
+ out_attn, attn_weights = self.attentions[i](in_attn, in_attn, in_attn)
423
+ out_attn = out_attn.transpose(1, 2).reshape(
424
+ batch_size, channels, h, w
425
+ ) # (batch_size, out_channels h, w)
426
+
427
+ # Skip connection
428
+ out = out + out_attn # (batch_size, out_channels h, w)
429
+
430
+ if self.cross_attn:
431
+ assert (
432
+ context is not None
433
+ ), "context cannot be None if cross attention layers are used"
434
+
435
+ batch_size, channels, h, w = (
436
+ out.shape
437
+ ) # (batch_size, out_channels, h, w)
438
+
439
+ in_attn = out.reshape(
440
+ batch_size, channels, h * w
441
+ ) # (batch_size, out_channels, h * w)
442
+ in_attn = self.cross_attention_norms[i](in_attn)
443
+ in_attn = in_attn.transpose(1, 2) # (batch_size, h * w, out_channels)
444
+
445
+ assert (
446
+ context.shape[0] == x.shape[0]
447
+ and context.shape[-1] == self.context_dim
448
+ ) # Make sure the batch_size and context_dim match with the model's parameters
449
+ context_proj = self.context_proj[i](
450
+ context
451
+ ) # (batch_size, seq_len, context_dim) -> (batch_size, seq_len, out_channels)
452
+
453
+ # Cross-Attention
454
+ out_attn, attn_weights = self.cross_attentions[i](
455
+ in_attn, context_proj, context_proj
456
+ ) # (batch_size, h * w, out_channels)
457
+ out_attn = out_attn.transpose(1, 2).reshape(
458
+ batch_size, channels, h, w
459
+ ) # (batch_size, out_channels, h, w)
460
+
461
+ # Skip Connection
462
+ out = out + out_attn # (batch_size, out_channels, h, w)
463
+
464
+ # Downsampling
465
+ out = self.down_sample_conv(out) # (batch_size, out_channels, h / 2, w / 2)
466
+ return out
467
+
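
To make the shape contract of the block above concrete, a minimal sketch (illustrative, not part of the uploaded file), with channel sizes taken from ldm_params:

# Hypothetical usage: DownBlock halves the spatial size when down_sample=True.
block = DownBlock(
    in_channels=256,
    out_channels=384,
    t_emb_dim=512,
    down_sample=True,
    num_heads=16,
    num_layers=2,
    attn=True,
    norm_channels=32,
)
x = torch.randn(2, 256, 32, 32)
t_emb = get_time_embedding(torch.tensor([10, 500]), 512)
print(block(x, t_emb).shape)  # torch.Size([2, 384, 16, 16])
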
468
+
469
+ class MidBlock(nn.Module):
470
+ """
471
+ Mid conv block with attention.
472
+ 1. Resnet block with time embedding
473
+ 2. Attention block
474
+ 3. Resnet block with time embedding
475
+
476
+ in_channels: Number of channels in the input feature map.
477
+ out_channels: Number of channels produced by this block.
478
+ t_emb_dim: Dimension of the time embedding. Only use for UNet for Diffusion. In an AutoEncoder, set it to None.
479
+ num_heads: Number of attention heads (used if attention is enabled).
480
+ num_layers: How many sub-blocks to apply in sequence.
481
+ norm_channels: Number of groups for GroupNorm.
482
+ cross_attn: Whether to apply cross-attention.
483
+ context_dim: If performing cross-attention, provide a context_dim for extra conditioning context.
484
+ """
485
+
486
+ def __init__(
487
+ self,
488
+ in_channels,
489
+ out_channels,
490
+ t_emb_dim,
491
+ num_heads,
492
+ num_layers,
493
+ norm_channels,
494
+ cross_attn=None,
495
+ context_dim=None,
496
+ ):
497
+ super().__init__()
498
+
499
+ self.num_layers = num_layers
500
+ self.t_emb_dim = t_emb_dim
501
+ self.context_dim = context_dim
502
+ self.cross_attn = cross_attn
503
+
504
+ self.resnet_conv_first = nn.ModuleList(
505
+ [
506
+ nn.Sequential(
507
+ nn.GroupNorm(
508
+ norm_channels, in_channels if i == 0 else out_channels
509
+ ), # Normalizes over channels. For the first sub-block, the in_channels=in_channels, else out_channels
510
+ nn.SiLU(),
511
+ nn.Conv2d(
512
+ in_channels=(in_channels if i == 0 else out_channels),
513
+ out_channels=out_channels,
514
+ kernel_size=3,
515
+ stride=1,
516
+ padding=1,
517
+ ), # (batch_size, c, h, w) -> (batch_size, out_channels, h, w)
518
+ )
519
+ for i in range(num_layers + 1)
520
+ ]
521
+ )
522
+
523
+ # Only add the time embedding for diffusion and not AutoEncoder
524
+ if self.t_emb_dim is not None:
525
+ self.t_emb_layers = nn.ModuleList(
526
+ [
527
+ nn.Sequential(
528
+ nn.SiLU(),
529
+ nn.Linear(
530
+ in_features=self.t_emb_dim, out_features=out_channels
531
+ ), # (batch_size, t_emb_dim) -> (batch_size, out_channels)
532
+ )
533
+ for i in range(num_layers + 1)
534
+ ]
535
+ )
536
+
537
+ self.resnet_conv_second = nn.ModuleList(
538
+ [
539
+ nn.Sequential(
540
+ nn.GroupNorm(norm_channels, out_channels),
541
+ nn.SiLU(),
542
+ nn.Conv2d(
543
+ in_channels=out_channels,
544
+ out_channels=out_channels,
545
+ kernel_size=3,
546
+ stride=1,
547
+ padding=1,
548
+ ), # (batch_size, out_channels, h, w) -> (batch_size, out_channels, h, w)
549
+ )
550
+ for i in range(num_layers + 1)
551
+ ]
552
+ )
553
+
554
+ self.residual_input_conv = nn.ModuleList(
555
+ [
556
+ nn.Conv2d(
557
+ in_channels=(in_channels if i == 0 else out_channels),
558
+ out_channels=out_channels,
559
+ kernel_size=1,
560
+ stride=1,
561
+ padding=0,
562
+ ) # (batch_size, in_channels, h, w) -> (batch_size, out_channels, h, w)
563
+ for i in range(num_layers + 1)
564
+ ]
565
+ )
566
+
567
+ self.attention_norms = nn.ModuleList(
568
+ [nn.GroupNorm(norm_channels, out_channels) for i in range(num_layers)]
569
+ )
570
+
571
+ self.attentions = nn.ModuleList(
572
+ [
573
+ nn.MultiheadAttention(
574
+ embed_dim=out_channels, num_heads=num_heads, batch_first=True
575
+ )
576
+ for i in range(num_layers)
577
+ ]
578
+ )
579
+
580
+ # Cross attention for text conditioning
581
+ if self.cross_attn:
582
+ assert (
583
+ context_dim is not None
584
+ ), "Context Dimension must be passed for cross attention"
585
+
586
+ self.cross_attention_norms = nn.ModuleList(
587
+ [nn.GroupNorm(norm_channels, out_channels) for i in range(num_layers)]
588
+ )
589
+
590
+ self.cross_attentions = nn.ModuleList(
591
+ [
592
+ nn.MultiheadAttention(
593
+ embed_dim=out_channels, num_heads=num_heads, batch_first=True
594
+ )
595
+ for i in range(num_layers)
596
+ ]
597
+ )
598
+
599
+ self.context_proj = nn.ModuleList(
600
+ [
601
+ nn.Linear(in_features=context_dim, out_features=out_channels)
602
+ for i in range(num_layers)
603
+ ]
604
+ )
605
+
606
+ def forward(self, x, t_emb=None, context=None):
607
+ out = x
608
+
609
+ # First ResNet block
610
+ resnet_input = out # (batch_size, c, h, w)
611
+ out = self.resnet_conv_first[0](out) # (batch_size, out_channels, h, w)
612
+
613
+ # Only add the time embedding for diffusion and not AutoEncoder
614
+ if self.t_emb_dim is not None:
615
+ # Add the embeddings for timesteps - (batch_size, t_emb_dim) -> (batch_size, out_channels, 1, 1)
616
+ out = out + self.t_emb_layers[0](t_emb).unsqueeze(dim=-1).unsqueeze(
617
+ dim=-1
618
+ ) # (batch_size, out_channels, h, w)
619
+
620
+ out = self.resnet_conv_second[0](
621
+ out
622
+ ) # (batch_size, out_channels, h, w) -> (batch_size, out_channels, h, w)
623
+
624
+ # Residual Connection
625
+ out = out + self.residual_input_conv[0](
626
+ resnet_input
627
+ ) # (batch_size, out_channels, h, w)
628
+
629
+ for i in range(self.num_layers):
630
+ # Attention Block
631
+ batch_size, channels, h, w = out.shape # (batch_size, out_channels, h, w)
632
+
633
+ # Do for both Diffusion and AutoEncoder
634
+ in_attn = out.reshape(
635
+ batch_size, channels, h * w
636
+ ) # (batch_size, out_channels, h * w)
637
+ in_attn = self.attention_norms[i](in_attn)
638
+ in_attn = in_attn.transpose(1, 2) # (batch_size, h * w, out_channels)
639
+
640
+ # Self-Attention
641
+ out_attn, attn_weights = self.attentions[i](in_attn, in_attn, in_attn)
642
+ out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
643
+
644
+ # Skip connection
645
+ out = out + out_attn # (batch_size, out_channels h, w)
646
+
647
+ if self.cross_attn:
648
+ assert (
649
+ context is not None
650
+ ), "context cannot be None if cross attention layers are used"
651
+ batch_size, channels, h, w = out.shape
652
+
653
+ in_attn = out.reshape(
654
+ batch_size, channels, h * w
655
+ ) # (batch_size, out_channels, h * w)
656
+ in_attn = self.cross_attention_norms[i](in_attn)
657
+ in_attn = in_attn.transpose(1, 2) # (batch_size, h * w, out_channels)
658
+
659
+ assert (
660
+ context.shape[0] == x.shape[0]
661
+ and context.shape[-1] == self.context_dim
662
+ ) # Make sure the batch_size and context_dim match with the model's parameters
663
+ context_proj = self.context_proj[i](
664
+ context
665
+ ) # (batch_size, seq_len, context_dim) -> (batch_size, seq_len, context_dim)
666
+
667
+ # Cross-Attention
668
+ out_attn, attn_weights = self.cross_attentions[i](
669
+ in_attn, context_proj, context_proj
670
+ )
671
+ out_attn = out_attn.transpose(1, 2).reshape(
672
+ batch_size, channels, h, w
673
+ ) # (batch_size, out_channels, h, w)
674
+
675
+ # Skip Connection
676
+ out = out + out_attn # (batch_size, out_channels h, w)
677
+
678
+ # Resnet Block
679
+ resnet_input = out
680
+ out = self.resnet_conv_first[i + 1](
681
+ out
682
+ ) # (batch_size, out_channels h, w) -> (batch_size, out_channels h, w)
683
+
684
+ # Only add the time embedding for diffusion and not AutoEncoder
685
+ if self.t_emb_dim is not None:
686
+ # Add the embeddings for timesteps - (batch_size, t_emb_dim) -> (batch_size, out_channels, 1, 1)
687
+ out = out + self.t_emb_layers[i + 1](t_emb).unsqueeze(dim=-1).unsqueeze(
688
+ dim=-1
689
+ ) # (batch_size, out_channels h, w)
690
+
691
+ out = self.resnet_conv_second[i + 1](
692
+ out
693
+ ) # (batch_size, out_channels h, w) -> (batch_size, out_channels h, w)
694
+
695
+ # Residual Connection
696
+ out = out + self.residual_input_conv[i + 1](
697
+ resnet_input
698
+ ) # (batch_size, out_channels, h, w)
699
+
700
+ return out
701
+
702
+
703
+ class UpBlock(nn.Module):
704
+ """
705
+ Up conv block with attention.
706
+ 1. Upsample
707
+ 2. Concatenate Down block output
708
+ 3. Resnet block with time embedding
709
+ 4. Attention Block
710
+
711
+ in_channels: Number of channels in the input feature map.
712
+ out_channels: Number of channels produced by this block.
713
+ t_emb_dim: Dimension of the time embedding. Only use for UNet for Diffusion. In an AutoEncoder, set it to None.
714
+ up_sample: Whether to apply upsampling at the end.
715
+ num_heads: Number of attention heads (used if attention is enabled).
716
+ num_layers: How many sub-blocks to apply in sequence.
717
+ attn: Whether to apply self-attention
718
+ norm_channels: Number of groups for GroupNorm.
719
+ """
720
+
721
+ def __init__(
722
+ self,
723
+ in_channels,
724
+ out_channels,
725
+ t_emb_dim,
726
+ up_sample,
727
+ num_heads,
728
+ num_layers,
729
+ attn,
730
+ norm_channels,
731
+ ):
732
+ super().__init__()
733
+
734
+ self.num_layers = num_layers
735
+ self.up_sample = up_sample
736
+ self.t_emb_dim = t_emb_dim
737
+ self.attn = attn
738
+
739
+ # Upsample by a factor of 2
740
+ self.up_sample_conv = (
741
+ nn.ConvTranspose2d(
742
+ in_channels=in_channels,
743
+ out_channels=in_channels,
744
+ kernel_size=4,
745
+ stride=2,
746
+ padding=1,
747
+ )
748
+ if self.up_sample
749
+ else nn.Identity()
750
+ ) # (batch_size, c, h * 2, w * 2)
751
+
752
+ self.resnet_conv_first = nn.ModuleList(
753
+ [
754
+ nn.Sequential(
755
+ nn.GroupNorm(
756
+ norm_channels, in_channels if i == 0 else out_channels
757
+ ), # Normalizes over channels. For the first sub-block, the in_channels=in_channels, else out_channels
758
+ nn.SiLU(),
759
+ nn.Conv2d(
760
+ in_channels=(in_channels if i == 0 else out_channels),
761
+ out_channels=out_channels,
762
+ kernel_size=3,
763
+ stride=1,
764
+ padding=1,
765
+ ), # (batch_size, c, h, w) -> (batch_size, out_channels, h, w)
766
+ )
767
+ for i in range(num_layers)
768
+ ]
769
+ )
770
+
771
+ # Only add the time embedding for diffusion and not AutoEncoder
772
+ if self.t_emb_dim is not None:
773
+ self.t_emb_layers = nn.ModuleList(
774
+ [
775
+ nn.Sequential(
776
+ nn.SiLU(),
777
+ nn.Linear(
778
+ in_features=self.t_emb_dim, out_features=out_channels
779
+ ), # (batch_size, t_emb_dim) -> (batch_size, out_channels)
780
+ )
781
+ for i in range(num_layers)
782
+ ]
783
+ )
784
+
785
+ self.resnet_conv_second = nn.ModuleList(
786
+ [
787
+ nn.Sequential(
788
+ nn.GroupNorm(norm_channels, out_channels),
789
+ nn.SiLU(),
790
+ nn.Conv2d(
791
+ in_channels=out_channels,
792
+ out_channels=out_channels,
793
+ kernel_size=3,
794
+ stride=1,
795
+ padding=1,
796
+ ), # (batch_size, out_channels, h, w) -> (batch_size, out_channels, h, w)
797
+ )
798
+ for i in range(num_layers)
799
+ ]
800
+ )
801
+
802
+ self.residual_input_conv = nn.ModuleList(
803
+ [
804
+ nn.Conv2d(
805
+ in_channels=(in_channels if i == 0 else out_channels),
806
+ out_channels=out_channels,
807
+ kernel_size=1,
808
+ stride=1,
809
+ padding=0,
810
+ ) # (batch_size, in_channels, h, w) -> (batch_size, out_channels, h, w)
811
+ for i in range(num_layers)
812
+ ]
813
+ )
814
+
815
+ if self.attn:
816
+ self.attention_norms = nn.ModuleList(
817
+ [nn.GroupNorm(norm_channels, out_channels) for i in range(num_layers)]
818
+ )
819
+
820
+ self.attentions = nn.ModuleList(
821
+ [
822
+ nn.MultiheadAttention(
823
+ embed_dim=out_channels, num_heads=num_heads, batch_first=True
824
+ )
825
+ for i in range(num_layers)
826
+ ]
827
+ )
828
+
829
+ def forward(self, x, out_down=None, t_emb=None):
830
+ # x shape: (batch_size, c, h, w)
831
+
832
+ # Upsample
833
+ x = self.up_sample_conv(
834
+ x
835
+ ) # (batch_size, c, h, w) -> (batch_size, c, h * 2, w * 2)
836
+
837
+ # *Only do for diffusion
838
+ # Concatenate with the output of respective DownBlock
839
+ if out_down is not None:
840
+ x = torch.cat(
841
+ [x, out_down], dim=1
842
+ ) # (batch_size, c, h * 2, w * 2) -> (batch_size, c * 2, h * 2, w * 2)
843
+
844
+ out = x # (batch_size, c, h * 2, w * 2)
845
+
846
+ for i in range(self.num_layers):
847
+ # Resnet block
848
+ resnet_input = out
849
+ out = self.resnet_conv_first[i](
850
+ out
851
+ ) # (batch_size, in_channels, h * 2, w * 2) -> (batch_size, out_channels, h * 2, w * 2)
852
+
853
+ # Only add the time embedding for diffusion and not AutoEncoder
854
+ if self.t_emb_dim is not None:
855
+ # Add the embeddings for timesteps - (batch_size, t_emb_dim) -> (batch_size, out_channels, 1, 1)
856
+ out = out + self.t_emb_layers[i](t_emb).unsqueeze(dim=-1).unsqueeze(
857
+ dim=-1
858
+ ) # (batch_size, out_channels, h * 2, w * 2)
859
+
860
+ out = self.resnet_conv_second[i](
861
+ out
862
+ ) # (batch_size, out_channels, h * 2, w * 2) -> (batch_size, out_channels, h * 2, w * 2)
863
+
864
+ # Residual Connection
865
+ out = out + self.residual_input_conv[i](
866
+ resnet_input
867
+ ) # (batch_size, out_channels, h * 2, w * 2)
868
+
869
+ # Only do for Diffusion and not for AutoEncoder
870
+ if self.attn:
871
+ # Attention block of UNET
872
+ batch_size, channels, h, w = out.shape
873
+
874
+ in_attn = out.reshape(
875
+ batch_size, channels, h * w
876
+ ) # (batch_size, out_channels, h * w * 4)
877
+ in_attn = self.attention_norms[i](in_attn)
878
+ in_attn = in_attn.transpose(
879
+ 1, 2
880
+ ) # (batch_size, h * w * 4, out_channels)
881
+
882
+ # Self-Attention
883
+ out_attn, attn_weights = self.attentions[i](in_attn, in_attn, in_attn)
884
+ out_attn = out_attn.transpose(1, 2).reshape(
885
+ batch_size, channels, h, w
886
+ ) # (batch_size, out_channels h * 2, w * 2)
887
+
888
+ # Skip connection
889
+ out = out + out_attn # (batch_size, out_channels h * 2, w * 2)
890
+
891
+ return out # (batch_size, out_channels h * 2, w * 2)
892
+
893
+
894
+ class UpBlockUNet(nn.Module):
895
+ """
896
+ Up conv block with attention.
897
+ 1. Upsample
898
+ 2. Concatenate Down block output
899
+ 3. Resnet block with time embedding
900
+ 4. Attention Block
901
+
902
+ in_channels: Number of channels in the input feature map. (It is passed in multiplied by 2 for concatenation with DownBlock output)
903
+ out_channels: Number of channels produced by this block.
904
+ t_emb_dim: Dimension of the time embedding. Only use for UNet for Diffusion. In an AutoEncoder, set it to None.
905
+ up_sample: Whether to apply upsampling at the end.
906
+ num_heads: Number of attention heads (used if attention is enabled).
907
+ num_layers: How many sub-blocks to apply in sequence.
908
+ norm_channels: Number of groups for GroupNorm.
909
+ cross_attn: Whether to apply cross-attention.
910
+ context_dim: If performing cross-attention, provide a context_dim for extra conditioning context.
911
+ """
912
+
913
+ def __init__(
914
+ self,
915
+ in_channels,
916
+ out_channels,
917
+ t_emb_dim,
918
+ up_sample,
919
+ num_heads,
920
+ num_layers,
921
+ norm_channels,
922
+ cross_attn=False,
923
+ context_dim=None,
924
+ ):
925
+ super().__init__()
926
+
927
+ self.num_layers = num_layers
928
+ self.up_sample = up_sample
929
+ self.t_emb_dim = t_emb_dim
930
+ self.cross_attn = cross_attn
931
+ self.context_dim = context_dim
932
+
933
+ self.up_sample_conv = (
934
+ nn.ConvTranspose2d(
935
+ in_channels=(in_channels // 2),
936
+ out_channels=(in_channels // 2),
937
+ kernel_size=4,
938
+ stride=2,
939
+ padding=1,
940
+ )
941
+ if self.up_sample
942
+ else nn.Identity()
943
+ ) # (batch_size, in_channels // 2, h * 2, w * 2)
944
+
945
+ self.resnet_conv_first = nn.ModuleList(
946
+ [
947
+ nn.Sequential(
948
+ nn.GroupNorm(
949
+ norm_channels, in_channels if i == 0 else out_channels
950
+ ), # Normalizes over channels. For the first sub-block, the in_channels=in_channels, else out_channels
951
+ nn.SiLU(),
952
+ nn.Conv2d(
953
+ in_channels=(in_channels if i == 0 else out_channels),
954
+ out_channels=out_channels,
955
+ kernel_size=3,
956
+ stride=1,
957
+ padding=1,
958
+ ), # (batch_size, in_channels, h * 2, w. * 2) -> (batch_size, out_channels, h * 2, w * 2) - Starts at in_channels and not in_channels // 2 because of concatenation
959
+ )
960
+ for i in range(num_layers)
961
+ ]
962
+ )
963
+
964
+ # Only add the time embedding if needed for UNET in diffusion
965
+ # Do not add the time embedding in the AutoEncoder
966
+ if self.t_emb_dim is not None:
967
+ self.t_emb_layers = nn.ModuleList(
968
+ [
969
+ nn.Sequential(
970
+ nn.SiLU(),
971
+ nn.Linear(
972
+ in_features=self.t_emb_dim, out_features=out_channels
973
+ ), # (batch_size, t_emb_dim) -> (batch_size, out_channels)
974
+ )
975
+ for i in range(num_layers)
976
+ ]
977
+ )
978
+
979
+ self.resnet_conv_second = nn.ModuleList(
980
+ [
981
+ nn.Sequential(
982
+ nn.GroupNorm(norm_channels, out_channels),
983
+ nn.SiLU(),
984
+ nn.Conv2d(
985
+ in_channels=out_channels,
986
+ out_channels=out_channels,
987
+ kernel_size=3,
988
+ stride=1,
989
+ padding=1,
990
+ ), # (batch_size, out_channels, h * 2, w * 2) -> (batch_size, out_channels, h * 2, w * 2)
991
+ )
992
+ for i in range(num_layers)
993
+ ]
994
+ )
995
+
996
+ self.residual_input_conv = nn.ModuleList(
997
+ [
998
+ nn.Conv2d(
999
+ in_channels=(in_channels if i == 0 else out_channels),
1000
+ out_channels=out_channels,
1001
+ kernel_size=1,
1002
+ stride=1,
1003
+ padding=0,
1004
+ )
1005
+ for i in range(
1006
+ num_layers
1007
+ ) # (batch_size, in_channels, h * 2, w * 2) -> (batch_size, out_channels, h * 2, w * 2)
1008
+ ]
1009
+ )
1010
+
1011
+ self.attention_norms = nn.ModuleList(
1012
+ [nn.GroupNorm(norm_channels, out_channels) for i in range(num_layers)]
1013
+ )
1014
+
1015
+ self.attentions = nn.ModuleList(
1016
+ [
1017
+ nn.MultiheadAttention(
1018
+ embed_dim=out_channels, num_heads=num_heads, batch_first=True
1019
+ )
1020
+ for i in range(num_layers)
1021
+ ]
1022
+ )
1023
+
1024
+ # Cross attention for text conditioning
1025
+ if self.cross_attn:
1026
+ assert (
1027
+ context_dim is not None
1028
+ ), "Context Dimension must be passed for cross attention"
1029
+
1030
+ self.cross_attention_norms = nn.ModuleList(
1031
+ [nn.GroupNorm(norm_channels, out_channels) for i in range(num_layers)]
1032
+ )
1033
+
1034
+ self.cross_attentions = nn.ModuleList(
1035
+ [
1036
+ nn.MultiheadAttention(
1037
+ embed_dim=out_channels, num_heads=num_heads, batch_first=True
1038
+ )
1039
+ for i in range(num_layers)
1040
+ ]
1041
+ )
1042
+
1043
+ self.context_proj = nn.ModuleList(
1044
+ [
1045
+ nn.Linear(in_features=context_dim, out_features=out_channels)
1046
+ for i in range(num_layers)
1047
+ ]
1048
+ )
1049
+
1050
+ def forward(self, x, out_down=None, t_emb=None, context=None):
1051
+ # x shape: (batch_size, in_channels // 2, h, w)
1052
+
1053
+ # Upsample
1054
+ x = self.up_sample_conv(
1055
+ x
1056
+ ) # (batch_size, in_channels // 2, h, w) -> (batch_size, in_channels // 2, h * 2, w * 2)
1057
+
1058
+ # Concatenate with the output of respective DownBlock
1059
+ if out_down is not None:
1060
+ x = torch.cat(
1061
+ [x, out_down], dim=1
1062
+ ) # (batch_size, in_channels // 2, h * 2, w * 2) -> (batch_size, in_channels, h * 2, w * 2)
1063
+
1064
+ out = x # (batch_size, in_channels, h * 2, w * 2)
1065
+ for i in range(self.num_layers):
1066
+ # Resnet block
1067
+ resnet_input = out
1068
+
1069
+ out = self.resnet_conv_first[i](
1070
+ out
1071
+ ) # (batch_size, in_channels, h * 2, w * 2) -> (batch_size, out_channels, h * 2, w * 2)
1072
+
1073
+ if self.t_emb_dim is not None:
1074
+ # Add the embeddings for timesteps - (batch_size, t_emb_dim) -> (batch_size, out_channels, 1, 1)
1075
+ out = out + self.t_emb_layers[i](t_emb).unsqueeze(dim=-1).unsqueeze(
1076
+ dim=-1
1077
+ ) # (batch_size, out_channels, h * 2, w * 2)
1078
+
1079
+ out = self.resnet_conv_second[i](
1080
+ out
1081
+ ) # (batch_size, out_channels, h * 2, w * 2) -> (batch_size, out_channels, h * 2, w * 2)
1082
+
1083
+ # Residual Connection
1084
+ out = out + self.residual_input_conv[i](
1085
+ resnet_input
1086
+ ) # (batch_size, out_channels, h * 2, w * 2)
1087
+
1088
+ # Attention block of UNET
1089
+ batch_size, channels, h, w = (
1090
+ out.shape
1091
+ ) # (batch_size, out_channels, h * 2, w * 2)
1092
+
1093
+ in_attn = out.reshape(
1094
+ batch_size, channels, h * w
1095
+ ) # (batch_size, out_channels, h * w * 4)
1096
+ in_attn = self.attention_norms[i](in_attn)
1097
+ in_attn = in_attn.transpose(1, 2) # (batch_size, h * w * 4, out_channels)
1098
+
1099
+ # Self-Attention
1100
+ out_attn, attn_weights = self.attentions[i](in_attn, in_attn, in_attn)
1101
+ out_attn = out_attn.transpose(1, 2).reshape(
1102
+ batch_size, channels, h, w
1103
+ ) # (batch_size, out_channels h * 2, w * 2)
1104
+
1105
+ # Skip connection
1106
+ out = out + out_attn # (batch_size, out_channels h * 2, w * 2)
1107
+
1108
+ if self.cross_attn:
1109
+ assert (
1110
+ context is not None
1111
+ ), "context cannot be None if cross attention layers are used"
1112
+ batch_size, channels, h, w = out.shape
1113
+
1114
+ in_attn = out.reshape(
1115
+ batch_size, channels, h * w
1116
+ ) # (batch_size, out_channels, h * w * 4)
1117
+ in_attn = self.cross_attention_norms[i](in_attn)
1118
+ in_attn = in_attn.transpose(
1119
+ 1, 2
1120
+ ) # (batch_size, h * w * 4, out_channels)
1121
+
1122
+ assert (
1123
+ len(context.shape) == 3
1124
+ ), "Context shape does not match batch_size, _, context_dim"
1125
+
1126
+ assert (
1127
+ context.shape[0] == x.shape[0]
1128
+ and context.shape[-1] == self.context_dim
1129
+ ), "Context shape does not match batch_size, _, context_dim" # Make sure the batch_size and context_dim match with the model's parameters
1130
+ context_proj = self.context_proj[i](
1131
+ context
1132
+ ) # (batch_size, seq_len, context_dim) -> (batch_size, seq_len, context_dim)
1133
+
1134
+ # Cross-Attention
1135
+ out_attn, attn_weights = self.cross_attentions[i](
1136
+ in_attn, context_proj, context_proj
1137
+ )
1138
+ out_attn = out_attn.transpose(1, 2).reshape(
1139
+ batch_size, channels, h, w
1140
+ ) # (batch_size, out_channels, h * 2, w * 2)
1141
+
1142
+ # Skip Connection
1143
+ out = out + out_attn # (batch_size, out_channels h * 2, w * 2)
1144
+
1145
+ return out # (batch_size, out_channels h * 2, w * 2)
1146
+
1147
+
1148
+ class VQVAE(nn.Module):
1149
+ def __init__(self, image_channels, model_config):
1150
+ super().__init__()
1151
+
1152
+ self.down_channels = model_config["down_channels"]
1153
+ self.mid_channels = model_config["mid_channels"]
1154
+ self.down_sample = model_config["down_sample"]
1155
+ self.num_down_layers = model_config["num_down_layers"]
1156
+ self.num_mid_layers = model_config["num_mid_layers"]
1157
+ self.num_up_layers = model_config["num_up_layers"]
1158
+
1159
+ # To disable attention in Downblock of Encoder and Upblock of Decoder
1160
+ self.attns = model_config["attn_down"]
1161
+
1162
+ # Latent Dimension
1163
+ self.z_channels = model_config[
1164
+ "z_channels"
1165
+ ] # number of channels in the latent representation
1166
+ self.codebook_size = model_config[
1167
+ "codebook_size"
1168
+ ] # number of discrete code vectors available
1169
+ self.norm_channels = model_config["norm_channels"]
1170
+ self.num_heads = model_config["num_heads"]
1171
+
1172
+ assert self.mid_channels[0] == self.down_channels[-1]
1173
+ assert self.mid_channels[-1] == self.down_channels[-1]
1174
+ assert len(self.down_sample) == len(self.down_channels) - 1
1175
+ assert len(self.attns) == len(self.down_channels) - 1
1176
+
1177
+ # Wherever we downsample in the encoder, use upsampling in the decoder at the corresponding location
1178
+ self.up_sample = list(reversed(self.down_sample))
1179
+
1180
+ # Encoder
1181
+ self.encoder_conv_in = nn.Conv2d(
1182
+ in_channels=image_channels,
1183
+ out_channels=self.down_channels[0],
1184
+ kernel_size=3,
1185
+ stride=1,
1186
+ padding=1,
1187
+ ) # (batch_size, 3, h, w) -> (batch_size, c, h, w)
1188
+
1189
+ # Downblock + Midblock
1190
+ self.encoder_layers = nn.ModuleList([])
1191
+ for i in range(len(self.down_channels) - 1):
1192
+ self.encoder_layers.append(
1193
+ DownBlock(
1194
+ in_channels=self.down_channels[i],
1195
+ out_channels=self.down_channels[i + 1],
1196
+ t_emb_dim=None,
1197
+ down_sample=self.down_sample[i],
1198
+ num_heads=self.num_heads,
1199
+ num_layers=self.num_down_layers,
1200
+ attn=self.attns[i],
1201
+ norm_channels=self.norm_channels,
1202
+ )
1203
+ )
1204
+
1205
+ self.encoder_mids = nn.ModuleList([])
1206
+ for i in range(len(self.mid_channels) - 1):
1207
+ self.encoder_mids.append(
1208
+ MidBlock(
1209
+ in_channels=self.mid_channels[i],
1210
+ out_channels=self.mid_channels[i + 1],
1211
+ t_emb_dim=None,
1212
+ num_heads=self.num_heads,
1213
+ num_layers=self.num_mid_layers,
1214
+ norm_channels=self.norm_channels,
1215
+ )
1216
+ )
1217
+
1218
+ self.encoder_norm_out = nn.GroupNorm(self.norm_channels, self.down_channels[-1])
1219
+
1220
+ self.encoder_conv_out = nn.Conv2d(
1221
+ in_channels=self.down_channels[-1],
1222
+ out_channels=self.z_channels,
1223
+ kernel_size=3,
1224
+ stride=1,
1225
+ padding=1,
1226
+ ) # (batch_size, z_channels, h', w')
1227
+
1228
+ # Pre Quantization Convolution
1229
+ self.pre_quant_conv = nn.Conv2d(
1230
+ in_channels=self.z_channels,
1231
+ out_channels=self.z_channels,
1232
+ kernel_size=1,
1233
+ stride=1,
1234
+ padding=0,
1235
+ ) # (batch_size, z_channels, h', w')
1236
+
1237
+ # Codebook Vectors
1238
+ self.embedding = nn.Embedding(
1239
+ self.codebook_size, self.z_channels
1240
+ ) # (codebook_size, z_channels)
1241
+
1242
+ # Decoder
1243
+
1244
+ # Post Quantization Convolution
1245
+ self.post_quant_conv = nn.Conv2d(
1246
+ in_channels=self.z_channels,
1247
+ out_channels=self.z_channels,
1248
+ kernel_size=1,
1249
+ stride=1,
1250
+ padding=0,
1251
+ ) # (batch_size, z_channels, h', w')
1252
+
1253
+ self.decoder_conv_in = nn.Conv2d(
1254
+ in_channels=self.z_channels,
1255
+ out_channels=self.mid_channels[-1],
1256
+ kernel_size=3,
1257
+ stride=1,
1258
+ padding=1,
1259
+ ) # (batch_size, c, h', w')
1260
+
1261
+ # Midblock + Upblock
1262
+ self.decoder_mids = nn.ModuleList([])
1263
+ for i in reversed(range(1, len(self.mid_channels))):
1264
+ self.decoder_mids.append(
1265
+ MidBlock(
1266
+ in_channels=self.mid_channels[i],
1267
+ out_channels=self.mid_channels[i - 1],
1268
+ t_emb_dim=None,
1269
+ num_heads=self.num_heads,
1270
+ num_layers=self.num_mid_layers,
1271
+ norm_channels=self.norm_channels,
1272
+ )
1273
+ )
1274
+
1275
+ self.decoder_layers = nn.ModuleList([])
1276
+ for i in reversed(range(1, len(self.down_channels))):
1277
+ self.decoder_layers.append(
1278
+ UpBlock(
1279
+ in_channels=self.down_channels[i],
1280
+ out_channels=self.down_channels[i - 1],
1281
+ t_emb_dim=None,
1282
+ up_sample=self.down_sample[i - 1],
1283
+ num_heads=self.num_heads,
1284
+ num_layers=self.num_up_layers,
1285
+ attn=self.attns[i - 1],
1286
+ norm_channels=self.norm_channels,
1287
+ )
1288
+ )
1289
+
1290
+ self.decoder_norm_out = nn.GroupNorm(self.norm_channels, self.down_channels[0])
1291
+
1292
+ self.decoder_conv_out = nn.Conv2d(
1293
+ in_channels=self.down_channels[0],
1294
+ out_channels=image_channels,
1295
+ kernel_size=3,
1296
+ stride=1,
1297
+ padding=1,
1298
+ ) # (batch_size, c, h, w)
1299
+
1300
+ def quantize(self, x):
1301
+ batch_size, c, h, w = x.shape # (batch_size, z_channels, h, w)
1302
+
1303
+ x = x.permute(
1304
+ 0, 2, 3, 1
1305
+ ) # (batch_size, z_channels, h, w) -> (batch_size, h, w, z_channels)
1306
+ x = x.reshape(
1307
+ batch_size, -1, c
1308
+ ) # (batch_size, h, w, z_channels) -> (batch_size, h * w, z_channels)
1309
+
1310
+ # Find the nearest codebook vector with distance between (batch_size, h * w, z_channels) and (batch_size, code_book_size, z_channels) -> (batch_size, h * w, code_book_size)
1311
+ dist = torch.cdist(
1312
+ x, self.embedding.weight.unsqueeze(dim=0).repeat((batch_size, 1, 1))
1313
+ ) # cdist calculates the batched p-norm distance
1314
+
1315
+ # (batch_size, h * w) Get the index of the closest codebook vector
1316
+ min_encoding_indices = torch.argmin(dist, dim=-1)
1317
+
1318
+ # Replace the encoder output with the nearest codebook
1319
+ quant_out = torch.index_select(
1320
+ self.embedding.weight, 0, min_encoding_indices.view(-1)
1321
+ ) # (batch_size, h * w, z_channels)
1322
+
1323
+ x = x.reshape((-1, c)) # (batch_size * h * w, z_channels)
1324
+
1325
+ # Commitment and Codebook Loss using MSE
1326
+ commitment_loss = torch.mean((quant_out.detach() - x) ** 2)
1327
+ codebook_loss = torch.mean((quant_out - x.detach()) ** 2)
1328
+
1329
+ quantize_losses = {
1330
+ "codebook_loss": codebook_loss,
1331
+ "commitment_loss": commitment_loss,
1332
+ }
1333
+
1334
+ # Straight through estimation
1335
+ quant_out = x + (quant_out - x).detach()
1336
+
1337
+ quant_out = quant_out.reshape(batch_size, h, w, c).permute(
1338
+ 0, 3, 1, 2
1339
+ ) # (batch_size, z_channels, h, w)
1340
+ min_encoding_indices = min_encoding_indices.reshape(
1341
+ (-1, h, w)
1342
+ ) # (batch_size, h, w)
1343
+
1344
+ return quant_out, quantize_losses, min_encoding_indices
1345
+
1346
+ def encode(self, x):
1347
+ out = self.encoder_conv_in(x) # (batch_size, self.down_channels[0], h, w)
1348
+
1349
+ # (batch_size, self.down_channels[0], h, w) -> (batch_size, self.down_channels[-1], h', w')
1350
+ for idx, down in enumerate(self.encoder_layers):
1351
+ out = down(out)
1352
+
1353
+ # (batch_size, self.down_channels[-1], h', w') -> (batch_size, self.mid_channels[-1], h', w')
1354
+ for mid in self.encoder_mids:
1355
+ out = mid(out)
1356
+
1357
+ out = self.encoder_norm_out(out)
1358
+ out = F.silu(out)
1359
+
1360
+ out = self.encoder_conv_out(
1361
+ out
1362
+ ) # (batch_size, self.mid_channels[-1], h', w') -> (batch_size, self.z_channels, h', w')
1363
+ out = self.pre_quant_conv(
1364
+ out
1365
+ ) # (batch_size, self.z_channels, h', w') -> (batch_size, self.z_channels, h', w')
1366
+
1367
+ out, quant_losses, min_encoding_indices = self.quantize(
1368
+ out
1369
+ ) # (batch_size, self.z_channels, h', w'), (codebook_loss, commitment_loss), (batch_size, h, w)
1370
+ return out, quant_losses
1371
+
1372
+ def decode(self, z):
1373
+ out = z
1374
+ out = self.post_quant_conv(
1375
+ out
1376
+ ) # (batch_size, self.z_channels, h', w') -> (batch_size, self.z_channels, h', w')
1377
+ out = self.decoder_conv_in(
1378
+ out
1379
+ ) # (batch_size, self.z_channels, h', w') -> (batch_size, self.mid_channels[-1], h', w')
1380
+
1381
+ # (batch_size, self.mid_channels[-1], h', w') -> (batch_size, self.down_channels[-1], h', w')
1382
+ for mid in self.decoder_mids:
1383
+ out = mid(out)
1384
+
1385
+ # (batch_size, self.down_channels[-1], h', w') -> (batch_size, self.down_channels[0], h, w)
1386
+ for idx, up in enumerate(self.decoder_layers):
1387
+ out = up(out)
1388
+
1389
+ out = self.decoder_norm_out(out)
1390
+ out = F.silu(out)
1391
+
1392
+ out = self.decoder_conv_out(
1393
+ out
1394
+ ) # (batch_size, self.down_channels[0], h, w) -> (batch_size, c, h, w)
1395
+ return out
1396
+
1397
+ def forward(self, x):
1398
+ # x shape: (batch_size, c, h, w)
1399
+
1400
+ z, quant_losses = self.encode(
1401
+ x
1402
+ ) # (batch_size, self.z_channels, h', w'), (codebook_loss, commitment_loss)
1403
+ out = self.decode(z) # (batch_size, c, h, w)
1404
+
1405
+ return out, z, quant_losses
1406
+
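
A minimal round-trip sketch for the autoencoder above (illustrative, not part of the uploaded file), using the autoencoder_params and dataset_params defined at the top of this module:

# Hypothetical usage: three downsampling stages map 256x256 images to 32x32 latents.
vae = VQVAE(dataset_params["image_channels"], autoencoder_params)
img = torch.randn(1, 3, 256, 256)
recon, z, losses = vae(img)
print(z.shape)      # torch.Size([1, 4, 32, 32]) - z_channels = 4
print(recon.shape)  # torch.Size([1, 3, 256, 256])
print(sorted(losses))  # ['codebook_loss', 'commitment_loss']
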
1407
+
1408
+ def validate_image_conditional_input(cond_input, x):
1409
+ assert (
1410
+ "image" in cond_input
1411
+ ), "Model initialized with image conditioning but cond_input has no image information"
1412
+ assert (
1413
+ cond_input["image"].shape[0] == x.shape[0]
1414
+ ), "Batch size mismatch of image condition and input"
1415
+ assert (
1416
+ cond_input["image"].shape[2] % x.shape[2] == 0
1417
+ ), "Height/Width of image condition must be divisible by latent input"
1418
+
1419
+
1420
+ def validate_class_conditional_input(cond_input, x, num_classes):
1421
+ assert (
1422
+ "class" in cond_input
1423
+ ), "Model initialized with class conditioning but cond_input has no class information"
1424
+ assert cond_input["class"].shape == (
1425
+ x.shape[0],
1426
+ num_classes,
1427
+ ), "Shape of class condition input must match (Batch Size, )"
1428
+
1429
+
1430
+ def get_config_value(config, key, default_value):
1431
+ return config[key] if key in config else default_value
1432
+
1433
+
1434
+ class UNet(nn.Module):
1435
+ """
1436
+ Unet model comprising
1437
+ Down blocks, Mid blocks and Up blocks
1438
+ """
1439
+
1440
+ def __init__(self, image_channels, model_config):
1441
+ super().__init__()
1442
+
1443
+ self.down_channels = model_config["down_channels"]
1444
+ self.mid_channels = model_config["mid_channels"]
1445
+ self.t_emb_dim = model_config["time_emb_dim"]
1446
+ self.down_sample = model_config["down_sample"]
1447
+ self.num_down_layers = model_config["num_down_layers"]
1448
+ self.num_mid_layers = model_config["num_mid_layers"]
1449
+ self.num_up_layers = model_config["num_up_layers"]
1450
+ self.attns = model_config["attn_down"]
1451
+ self.norm_channels = model_config["norm_channels"]
1452
+ self.num_heads = model_config["num_heads"]
1453
+ self.conv_out_channels = model_config["conv_out_channels"]
1454
+
1455
+ assert self.mid_channels[0] == self.down_channels[-1]
1456
+ assert self.mid_channels[-1] == self.down_channels[-2]
1457
+ assert len(self.down_sample) == len(self.down_channels) - 1
1458
+ assert len(self.attns) == len(self.down_channels) - 1
1459
+
1460
+ # Class, Mask, and Text Conditioning Config
1461
+ self.class_cond = False
1462
+ self.text_cond = False
1463
+ self.image_cond = False
1464
+ self.text_embed_dim = None
1465
+ self.condition_config = get_config_value(
1466
+ model_config, "condition_config", None
1467
+ ) # Get the dictionary containing conditional information
1468
+
1469
+ if self.condition_config is not None:
1470
+ assert (
1471
+ "condition_types" in self.condition_config
1472
+ ), "Condition Type not provided in model config"
1473
+ condition_types = self.condition_config["condition_types"]
1474
+
1475
+ # For class, text, and image, get necessary parameters
1476
+ if "class" in condition_types:
1477
+ self.class_cond = True
1478
+ self.num_classes = self.condition_config["class_condition_config"][
1479
+ "num_classes"
1480
+ ]
1481
+
1482
+ if "text" in condition_types:
1483
+ self.text_cond = True
1484
+ self.text_embed_dim = self.condition_config["text_condition_config"][
1485
+ "text_embed_dim"
1486
+ ]
1487
+
1488
+ if "image" in condition_types:
1489
+ self.image_cond = True
1490
+ self.image_cond_input_channels = self.condition_config[
1491
+ "image_condition_config"
1492
+ ]["image_condition_input_channels"]
1493
+ self.image_cond_output_channels = self.condition_config[
1494
+ "image_condition_config"
1495
+ ]["image_condition_output_channels"]
1496
+
1497
+ if self.class_cond:
1498
+ # For class conditioning, do not add the class embedding information for unconditional generation
1499
+ self.class_emb = nn.Embedding(
1500
+ self.num_classes, self.t_emb_dim
1501
+ ) # (num_classes, t_emb_dim)
1502
+
1503
+ if self.image_cond:
1504
+ # Map the conditioning mask to image_cond_output_channels channels and concatenate it with the input along the channel dimension
1505
+ self.cond_conv_in = nn.Conv2d(
1506
+ in_channels=self.image_cond_input_channels,
1507
+ out_channels=self.image_cond_output_channels,
1508
+ kernel_size=1,
1509
+ stride=1,
1510
+ padding=0,
1511
+ bias=False,
1512
+ )
1513
+
1514
+ self.conv_in_concat = nn.Conv2d(
1515
+ in_channels=(image_channels + self.image_cond_output_channels),
1516
+ out_channels=self.down_channels[0],
1517
+ kernel_size=3,
1518
+ stride=1,
1519
+ padding=1,
1520
+ )
1521
+ else:
1522
+ self.conv_in = nn.Conv2d(
1523
+ in_channels=image_channels,
1524
+ out_channels=self.down_channels[0],
1525
+ kernel_size=3,
1526
+ stride=1,
1527
+ padding=1,
1528
+ ) # (batch_size, image_channels, h, w) -> (batch_size, self.down_channels[0], h, w)
1529
+
1530
+ self.cond = self.text_cond or self.image_cond or self.class_cond
1531
+
1532
+ # Initial projection from sinusoidal time embedding
1533
+ self.t_proj = nn.Sequential(
1534
+ nn.Linear(in_features=self.t_emb_dim, out_features=self.t_emb_dim),
1535
+ nn.SiLU(),
1536
+ nn.Linear(in_features=self.t_emb_dim, out_features=self.t_emb_dim),
1537
+ ) # (batch_size, t_emb_dim)
1538
+
1539
+ self.up_sample = list(reversed(self.down_sample))
1540
+
1541
+ self.downs = nn.ModuleList([])
1542
+ for i in range(len(self.down_channels) - 1):
1543
+ # Cross attention and Context Dim are only used for text conditioning
1544
+ self.downs.append(
1545
+ DownBlock(
1546
+ in_channels=self.down_channels[i],
1547
+ out_channels=self.down_channels[i + 1],
1548
+ t_emb_dim=self.t_emb_dim,
1549
+ down_sample=self.down_sample[i],
1550
+ num_heads=self.num_heads,
1551
+ num_layers=self.num_down_layers,
1552
+ attn=self.attns[i],
1553
+ norm_channels=self.norm_channels,
1554
+ cross_attn=self.text_cond,
1555
+ context_dim=self.text_embed_dim,
1556
+ )
1557
+ )
1558
+
1559
+ self.mids = nn.ModuleList([])
1560
+ for i in range(len(self.mid_channels) - 1):
1561
+ # Cross attention and Context Dim are only used for text conditioning
1562
+ self.mids.append(
1563
+ MidBlock(
1564
+ in_channels=self.mid_channels[i],
1565
+ out_channels=self.mid_channels[i + 1],
1566
+ t_emb_dim=self.t_emb_dim,
1567
+ num_heads=self.num_heads,
1568
+ num_layers=self.num_mid_layers,
1569
+ norm_channels=self.norm_channels,
1570
+ cross_attn=self.text_cond,
1571
+ context_dim=self.text_embed_dim,
1572
+ )
1573
+ )
1574
+
1575
+ self.ups = nn.ModuleList([])
1576
+ for i in reversed(range(len(self.down_channels) - 1)):
1577
+ # Cross attention and Context Dim are only used for text conditioning
1578
+ self.ups.append(
1579
+ UpBlockUNet(
1580
+ in_channels=(self.down_channels[i] * 2),
1581
+ out_channels=(
1582
+ self.down_channels[i - 1] if i != 0 else self.conv_out_channels
1583
+ ),
1584
+ t_emb_dim=self.t_emb_dim,
1585
+ up_sample=self.down_sample[i],
1586
+ num_heads=self.num_heads,
1587
+ num_layers=self.num_up_layers,
1588
+ norm_channels=self.norm_channels,
1589
+ cross_attn=self.text_cond,
1590
+ context_dim=self.text_embed_dim,
1591
+ )
1592
+ )
1593
+
1594
+ self.norm_out = nn.GroupNorm(self.norm_channels, self.conv_out_channels)
1595
+
1596
+ self.conv_out = nn.Conv2d(
1597
+ in_channels=self.conv_out_channels,
1598
+ out_channels=image_channels,
1599
+ kernel_size=3,
1600
+ stride=1,
1601
+ padding=1,
1602
+ ) # (batch_size, conv_out_channels, h, w) -> (batch_size, image_channels, h, w)
1603
+
1604
+ def forward(self, x, t, cond_input=None):
1605
+ # x shape: (batch_size, c, h, w)
1606
+ # cond_input is a dict of conditioning inputs (keys such as "class", "text", "image")
1607
+ # For class conditioning, it will be a one-hot vector of shape (batch_size, num_classes)
1608
+
1609
+ if self.cond:
1610
+ assert (
1611
+ cond_input is not None
1612
+ ), "Model initialized with conditioning so cond_input cannot be None"
1613
+
1614
+ if self.image_cond:
1615
+ # Mask Conditioning
1616
+ validate_image_conditional_input(cond_input, x)
1617
+ image_cond = cond_input["image"]
1618
+ image_cond = F.interpolate(image_cond, size=x.shape[-2:])
1619
+ image_cond = self.cond_conv_in(image_cond)
1620
+ assert image_cond.shape[-2:] == x.shape[-2:]
1621
+
1622
+ x = torch.cat(
1623
+ [x, image_cond], dim=1
1624
+ ) # (batch_size, image_channels + image_cond_output_channels, h, w)
1625
+ out = self.conv_in_concat(x) # (batch_size, down_channels[0], h, w)
1626
+ else:
1627
+ out = self.conv_in(x) # (batch_size, down_channels[0], h, w)
1628
+
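+ # Shape sketch (illustrative sizes) for the mask-conditioned branch above: a (B, 18, 256, 256)
+ # conditioning mask is resized to the latent's spatial size, projected to
+ # image_cond_output_channels by the 1x1 cond_conv_in, concatenated with the
+ # (B, image_channels, h, w) latent, and mapped to down_channels[0] channels by conv_in_concat.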
1629
+ t_emb = get_time_embedding(
1630
+ torch.as_tensor(t).long(), self.t_emb_dim
1631
+ ) # (batch_size, t_emb_dim)
1632
+ t_emb = self.t_proj(t_emb) # (batch_size, t_emb_dim)
1633
+
1634
+ # Class Conditioning
1635
+ if self.class_cond:
1636
+ validate_class_conditional_input(cond_input, x, self.num_classes)
1637
+
1638
+ # Multiply the one-hot class vectors by the embedding weight matrix to get a class embedding for each image in the batch
1639
+ class_embed = torch.matmul(
1640
+ cond_input["class"].float(), self.class_emb.weight
1641
+ ) # (batch_size, t_emb_dim)
1642
+ t_emb += class_embed # Add the class embedding to the time embedding
1643
+
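+ # Worked example (illustrative): with num_classes = 3 and a one-hot row [0, 1, 0], the matmul
+ # above selects row 1 of class_emb.weight, i.e. it equals class_emb(torch.tensor([1])); an
+ # all-zero row contributes nothing, which is how unconditional samples drop the class signal.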
1644
+ context_hidden_states = None
1645
+
1646
+ # Only use context hidden states in cross-attention for text conditioning
1647
+ if self.text_cond:
1648
+ assert (
1649
+ "text" in cond_input
1650
+ ), "Model initialized with text conditioning but cond_input has no text information"
1651
+ context_hidden_states = cond_input["text"]
1652
+
1653
+ down_outs = []
1654
+ for idx, down in enumerate(self.downs):
1655
+ down_outs.append(out)
1656
+ out = down(
1657
+ out, t_emb, context_hidden_states
1658
+ ) # Use context_hidden_states for cross-attention
1659
+ # out = (batch_size, c4, h / 4, w / 4)
1660
+
1661
+ for mid in self.mids:
1662
+ out = mid(out, t_emb, context_hidden_states)
1663
+ # out = (batch_size, c3, h / 4, w / 4)
1664
+
1665
+ for up in self.ups:
1666
+ down_out = down_outs.pop()
1667
+ out = up(out, down_out, t_emb, context_hidden_states)
1668
+ # out = (batch_size, self.conv_out_channels, h, w)
1669
+
1670
+ out = F.silu(self.norm_out(out))
1671
+ out = self.conv_out(
1672
+ out
1673
+ ) # (batch_size, self.conv_out_channels, h, w) -> (batch_size, image_channels, h, w)
1674
+
1675
+ return out # (batch_size, image_channels, h, w)
1676
+
1677
+
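+ # Minimal usage sketch (illustrative values only, not this repo's actual config). The keys
+ # below are the required ones read by UNet.__init__ (condition_config is optional and omitted
+ # here), and the channel lists satisfy its asserts
+ # (mid_channels[0] == down_channels[-1], mid_channels[-1] == down_channels[-2]):
+ #
+ # example_ldm_params = {
+ #     "down_channels": [64, 128, 256, 256],
+ #     "mid_channels": [256, 256],
+ #     "time_emb_dim": 256,
+ #     "down_sample": [True, True, False],
+ #     "num_down_layers": 1,
+ #     "num_mid_layers": 1,
+ #     "num_up_layers": 1,
+ #     "attn_down": [True, True, True],
+ #     "norm_channels": 32,
+ #     "num_heads": 8,
+ #     "conv_out_channels": 128,
+ # }
+ # unet = UNet(image_channels=4, model_config=example_ldm_params)
+ # noise_pred = unet(torch.randn(2, 4, 32, 32), torch.randint(0, 1000, (2,)))  # same shape as the input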
1678
+ def sample_ddpm_inference(
1679
+ unet,
1680
+ vae,
1681
+ text_prompt,
1682
+ mask_image_pil=None,
1683
+ guidance_scale=1.0,
1684
+ device=torch.device("cpu"),
1685
+ ):
1686
+ """
1687
+ Given a text prompt and (optionally) an image condition (as a PIL image),
1688
+ sample from the diffusion model and yield the generated image as a PIL image.
1689
+ """
1690
+ # Create noise scheduler
1691
+ scheduler = LinearNoiseScheduler(
1692
+ num_timesteps=diffusion_params["num_timesteps"],
1693
+ beta_start=diffusion_params["beta_start"],
1694
+ beta_end=diffusion_params["beta_end"],
1695
+ )
1696
+ # Get conditioning config from ldm_params
1697
+ condition_config = ldm_params.get("condition_config", None)
1698
+ condition_types = (
1699
+ condition_config.get("condition_types", [])
1700
+ if condition_config is not None
1701
+ else []
1702
+ )
1703
+
1704
+ # Load text tokenizer/model for conditioning
1705
+ text_model_type = condition_config["text_condition_config"]["text_embed_model"]
1706
+ text_tokenizer, text_model = get_tokenizer_and_model(text_model_type, device=device)
1707
+
1708
+ # Get empty text representation for classifier-free guidance
1709
+ empty_text_embed = get_text_representation([""], text_tokenizer, text_model, device)
1710
+
1711
+ # Get text representation of the input prompt
1712
+ text_prompt_embed = get_text_representation(
1713
+ [text_prompt], text_tokenizer, text_model, device
1714
+ )
1715
+
1716
+ # Prepare image conditioning:
1717
+ # If the user uploaded a mask image (should be a PIL image), convert it; otherwise, use zeros.
1718
+ if "image" in condition_types:
1719
+ if mask_image_pil is not None:
1720
+ mask_transform = transforms.Compose(
1721
+ [
1722
+ transforms.Resize(
1723
+ (
1724
+ ldm_params["condition_config"]["image_condition_config"][
1725
+ "image_condition_h"
1726
+ ],
1727
+ ldm_params["condition_config"]["image_condition_config"][
1728
+ "image_condition_w"
1729
+ ],
1730
+ )
1731
+ ),
1732
+ transforms.ToTensor(),
1733
+ ]
1734
+ )
1735
+ mask_tensor = (
1736
+ mask_transform(mask_image_pil).unsqueeze(0).to(device)
1737
+ ) # (1, channels, H, W)
1738
+ else:
1739
+ # Create a zero mask with the required number of channels (e.g. 18)
1740
+ ic = ldm_params["condition_config"]["image_condition_config"][
1741
+ "image_condition_input_channels"
1742
+ ]
1743
+ H = ldm_params["condition_config"]["image_condition_config"][
1744
+ "image_condition_h"
1745
+ ]
1746
+ W = ldm_params["condition_config"]["image_condition_config"][
1747
+ "image_condition_w"
1748
+ ]
1749
+ mask_tensor = torch.zeros((1, ic, H, W), device=device)
1750
+ else:
1751
+ mask_tensor = None
1752
+
1753
+ # Build conditioning dictionaries for classifier-free guidance:
1754
+ # For the unconditional input, use the empty-prompt embedding and an all-zero mask.
1755
+ uncond_input = {}
1756
+ cond_input = {}
1757
+ if "text" in condition_types:
1758
+ uncond_input["text"] = empty_text_embed
1759
+ cond_input["text"] = text_prompt_embed
1760
+ if "image" in condition_types:
1761
+ # Use an all-zero mask for the unconditional pass and the provided mask for the conditional pass.
1762
+ uncond_input["image"] = torch.zeros_like(mask_tensor)
1763
+ cond_input["image"] = mask_tensor
1764
+
1765
+ # The diffusion UNet is passed in already pretrained; the commented-out code below shows how it could be loaded from a local checkpoint instead.
1766
+ # unet = UNet(
1767
+ # image_channels=autoencoder_params["z_channels"], model_config=ldm_params
1768
+ # ).to(device)
1769
+ # ldm_checkpoint_path = os.path.join(
1770
+ # train_params["task_name"], train_params["ldm_ckpt_name"]
1771
+ # )
1772
+ # if os.path.exists(ldm_checkpoint_path):
1773
+ # checkpoint = torch.load(ldm_checkpoint_path, map_location=device)
1774
+ # unet.load_state_dict(checkpoint["model_state_dict"])
1775
+ # unet.eval()
1776
+
1777
+ # The VQVAE is likewise passed in pretrained; the commented-out loading code is kept for reference.
1778
+ # vae = VQVAE(
1779
+ # image_channels=dataset_params["image_channels"], model_config=autoencoder_params
1780
+ # ).to(device)
1781
+ # vae_checkpoint_path = os.path.join(
1782
+ # train_params["task_name"], train_params["vqvae_autoencoder_ckpt_name"]
1783
+ # )
1784
+ # if os.path.exists(vae_checkpoint_path):
1785
+ # checkpoint = torch.load(vae_checkpoint_path, map_location=device)
1786
+ # vae.load_state_dict(checkpoint["model_state_dict"])
1787
+ # vae.eval()
1788
+
1789
+ # Determine latent shape from VQVAE: (batch, z_channels, H_lat, W_lat)
1790
+ # For example, if image_size is 256 and there are 3 downsamplings, H_lat = 256 // 8 = 32.
1791
+ latent_size = dataset_params["image_size"] // (
1792
+ 2 ** sum(autoencoder_params["down_sample"])
1793
+ )
1794
+ batch = train_params["num_samples"]
1795
+ z_channels = autoencoder_params["z_channels"]
1796
+
1797
+ # Sample initial latent noise
1798
+ xt = torch.randn((batch, z_channels, latent_size, latent_size), device=device)
1799
+
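+ # Classifier-free guidance sketch for the loop below: eps = eps_uncond + s * (eps_cond - eps_uncond).
+ # With guidance_scale s = 1 the code falls back to the purely conditional prediction; larger s
+ # pushes samples more strongly toward the text/mask conditioning.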
1800
+ # Sampling loop (reverse diffusion)
1801
+ T = diffusion_params["num_timesteps"]
1802
+ for i in reversed(range(T)):
1803
+ t = torch.full((batch,), i, dtype=torch.long, device=device)
1804
+ # Get conditional noise prediction
1805
+ noise_pred_cond = unet(xt, t, cond_input)
1806
+ if guidance_scale > 1:
1807
+ noise_pred_uncond = unet(xt, t, uncond_input)
1808
+ noise_pred = noise_pred_uncond + guidance_scale * (
1809
+ noise_pred_cond - noise_pred_uncond
1810
+ )
1811
+ else:
1812
+ noise_pred = noise_pred_cond
1813
+ xt, _ = scheduler.sample_prev_timestep(xt, noise_pred, t)
1814
+
1815
+ with torch.no_grad():
1816
+ generated = vae.decode(xt)
1817
+
1818
+ generated = torch.clamp(generated, -1, 1)
1819
+ generated = (generated + 1) / 2 # scale to [0,1]
1820
+ grid = make_grid(generated, nrow=1)
1821
+ pil_img = transforms.ToPILImage()(grid.cpu())
1822
+
1823
+ yield pil_img
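+ # Usage sketch (assumes a pretrained unet and vae are already in hand): the function is a
+ # generator, so consume it, e.g.
+ # image = next(sample_ddpm_inference(unet, vae, "a smiling woman with blonde hair", None, 1.0, device))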
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ transformers
4
+ gradio
5
+ spacy
6
+ datasets
7
+ Pillow
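+ # Install with: pip install -r requirements.txt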