mkthoma committed on
Commit b05e9ad · 1 Parent(s): b0b658e

app update

Files changed (1)
  1. app.py +271 -0
app.py ADDED
@@ -0,0 +1,271 @@
from base64 import b64encode

import numpy as np
import torch
from diffusers import AutoencoderKL, LMSDiscreteScheduler, UNet2DConditionModel

from matplotlib import pyplot as plt
from pathlib import Path
from PIL import Image
from torch import autocast
from torchvision import transforms as tfms
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer, logging
import os
import cv2
import torchvision.transforms as T

torch.manual_seed(1)
logging.set_verbosity_error()

torch_device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the autoencoder
vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder='vae')

# Load the tokenizer and text encoder to tokenize and encode the text
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

# UNet model for denoising the latents
unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder='unet')

# Noise scheduler
scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)

# Move the models to the selected device (GPU if available, otherwise CPU)
vae = vae.to(torch_device)
text_encoder = text_encoder.to(torch_device)
unet = unet.to(torch_device)

# Generation settings, read as module-level globals by generate_with_embs below
height = 512                # default height of Stable Diffusion
width = 512                 # default width of Stable Diffusion
num_inference_steps = 10    # number of denoising steps
guidance_scale = 7.5        # scale for classifier-free guidance

# Prep scheduler
def set_timesteps(scheduler, num_inference_steps):
    scheduler.set_timesteps(num_inference_steps)
    scheduler.timesteps = scheduler.timesteps.to(torch.float32)  # minor fix to ensure MPS compatibility, fixed in diffusers PR 3925

def get_output_embeds(input_embeddings):
    # CLIP's text model uses a causal mask, so we prepare it here.
    # Note: _build_causal_attention_mask is a private helper of CLIPTextTransformer that is
    # present in older transformers releases; newer releases may not provide it.
    bsz, seq_len = input_embeddings.shape[:2]
    causal_attention_mask = text_encoder.text_model._build_causal_attention_mask(bsz, seq_len, dtype=input_embeddings.dtype)

    # Getting the output embeddings involves calling the model with output_hidden_states=True
    # so that it doesn't just return the pooled final predictions:
    encoder_outputs = text_encoder.text_model.encoder(
        inputs_embeds=input_embeddings,
        attention_mask=None,  # We aren't using an attention mask, so that can be None
        causal_attention_mask=causal_attention_mask.to(torch_device),
        output_attentions=None,
        output_hidden_states=True,  # We want the output embeddings, not just the final output
        return_dict=None,
    )

    # We're interested in the output hidden state only
    output = encoder_outputs[0]

    # There is a final layer norm we need to pass these through
    output = text_encoder.text_model.final_layer_norm(output)

    # And now they're ready!
    return output

style_files = ['stable_diffusion/learned_embeddings/arcane-style-jv.bin', 'stable_diffusion/learned_embeddings/birb-style.bin',
               'stable_diffusion/learned_embeddings/dr-strange.bin', 'stable_diffusion/learned_embeddings/midjourney-style.bin',
               'stable_diffusion/learned_embeddings/oil_style.bin']

def get_style_embeddings(style_file):
    # Each learned-embedding .bin holds a dict mapping the placeholder token to its embedding tensor
    style_embed = torch.load(style_file)
    style_name = list(style_embed.keys())[0]
    return style_embed[style_name]
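
# Illustrative check, not executed by the app: a loaded style embedding is expected to be a
# single 768-dim vector (the CLIP ViT-L/14 text embedding size used by SD v1.4), e.g.
#   get_style_embeddings(style_files[1]).shape  ->  torch.Size([768])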

def vibrance_loss(image):
    # Standard deviation of each colour channel, computed over height and width
    std_dev = torch.std(image, dim=(2, 3))
    # Mean standard deviation across channels and the batch
    mean_std_dev = torch.mean(std_dev)
    # The scale factor controls the strength of the vibrance regularization
    scale_factor = 100.0
    # Negative, so that minimising the loss increases colour variation (vibrance)
    loss = -scale_factor * mean_std_dev
    return loss
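
# Sign-convention sanity check (illustrative only, not part of the pipeline): a more colourful
# batch has a larger per-channel std and therefore a more negative loss, which the guidance
# step below minimises.
#   vibrance_loss(torch.zeros(1, 3, 64, 64))  ->  0.0 (a flat image has no channel variation)
#   vibrance_loss(torch.rand(1, 3, 64, 64))   ->  a negative scalar tensor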

def pil_to_latent(input_im):
    # Single image -> single latent in a batch (so shape 1 x 4 x 64 x 64)
    with torch.no_grad():
        latent = vae.encode(tfms.ToTensor()(input_im).unsqueeze(0).to(torch_device) * 2 - 1)  # note the scaling to (-1, 1)
    return 0.18215 * latent.latent_dist.sample()  # 0.18215 is the SD v1 latent scaling factor

def latents_to_pil(latents):
    # Batch of latents -> list of PIL images
    latents = (1 / 0.18215) * latents
    with torch.no_grad():
        image = vae.decode(latents).sample
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
    images = (image * 255).round().astype("uint8")
    pil_images = [Image.fromarray(image) for image in images]
    return pil_images
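
# Illustrative round trip through the VAE (assumes some local 512x512 RGB image; not called
# anywhere in the app):
#   img = Image.open("example.jpg").convert("RGB").resize((512, 512))
#   recon = latents_to_pil(pil_to_latent(img))[0]  # encode to latents, then decode back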

def additional_guidance(latents, scheduler, noise_pred, t, sigma, custom_loss_fn):
    #### ADDITIONAL GUIDANCE ###
    # Requires grad on the latents
    latents = latents.detach().requires_grad_()

    # Get the predicted x0:
    latents_x0 = latents - sigma * noise_pred
    # latents_x0 = scheduler.step(noise_pred, t, latents).pred_original_sample

    # Decode to image space
    denoised_images = vae.decode((1 / 0.18215) * latents_x0).sample / 2 + 0.5  # range (0, 1)

    # Calculate the loss on the decoded images
    loss = custom_loss_fn(denoised_images)

    # Gradient of the loss with respect to the latents
    cond_grad = torch.autograd.grad(loss, latents, allow_unused=False)[0]

    # Nudge the latents down the gradient, scaled by sigma**2
    latents = latents.detach() - cond_grad * sigma**2
    return latents, loss

def generate_with_embs(text_embeddings, max_length, random_seed, loss_fn=None):
    generator = torch.manual_seed(random_seed)  # Seed generator to create the initial latent noise
    batch_size = 1

    uncond_input = tokenizer(
        [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
    )
    with torch.no_grad():
        uncond_embeddings = text_encoder(uncond_input.input_ids.to(torch_device))[0]
    text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

    # Prep scheduler
    set_timesteps(scheduler, num_inference_steps)

    # Prep latents
    latents = torch.randn(
        (batch_size, unet.config.in_channels, height // 8, width // 8),
        generator=generator,
    )
    latents = latents.to(torch_device)
    latents = latents * scheduler.init_noise_sigma

    # Sampling loop
    for i, t in tqdm(enumerate(scheduler.timesteps), total=len(scheduler.timesteps)):
        # Expand the latents so the unconditional and conditional passes run in a single forward pass
        latent_model_input = torch.cat([latents] * 2)
        sigma = scheduler.sigmas[i]
        latent_model_input = scheduler.scale_model_input(latent_model_input, t)

        # Predict the noise residual
        with torch.no_grad():
            noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"]

        # Perform classifier-free guidance
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

        # Apply the additional loss-based guidance on every other step
        if loss_fn is not None:
            if i % 2 == 0:
                latents, custom_loss = additional_guidance(latents, scheduler, noise_pred, t, sigma, loss_fn)

        # Compute the previous noisy sample x_t -> x_t-1
        latents = scheduler.step(noise_pred, t, latents).prev_sample

    return latents_to_pil(latents)[0]
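
# Minimal sketch of the EOS-position helper used by generate_images below, assuming each word
# of the prompt maps to a single CLIP token (position 0 is the BOS token).
def get_EOS_pos_in_prompt(prompt):
    return len(prompt.split()) + 1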

def generate_images(prompt, style_num=None, random_seed=41, custom_loss_fn=None):
    eos_pos = get_EOS_pos_in_prompt(prompt)

    style_token_embedding = None
    if style_num is not None:  # "is not None" so that style 0 is not skipped
        style_token_embedding = get_style_embeddings(style_files[style_num])

    # Tokenize
    text_input = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
    max_length = text_input.input_ids.shape[-1]
    input_ids = text_input.input_ids.to(torch_device)

    # Get token embeddings
    token_emb_layer = text_encoder.text_model.embeddings.token_embedding
    token_embeddings = token_emb_layer(input_ids)

    # Replace the token embedding at the EOS position with the learned style embedding
    if style_token_embedding is not None:
        token_embeddings[-1, eos_pos, :] = style_token_embedding

    # Combine with position embeddings
    pos_emb_layer = text_encoder.text_model.embeddings.position_embedding
    position_ids = text_encoder.text_model.embeddings.position_ids[:, :77]
    position_embeddings = pos_emb_layer(position_ids)
    input_embeddings = token_embeddings + position_embeddings

    # Feed through the text encoder to get the final output embeddings
    modified_output_embeddings = get_output_embeds(input_embeddings)

    # And generate an image with this:
    generated_image = generate_with_embs(modified_output_embeddings, max_length, random_seed, custom_loss_fn)
    return generated_image

def display_images_in_rows(images_with_titles, titles):
    num_images = len(images_with_titles)
    rows = 5  # Always display 5 rows (one per style)
    columns = 1 if num_images == 5 else 2  # Use 1 image column if there are 5 images, otherwise 2
    fig, axes = plt.subplots(rows, columns + 1, figsize=(15, 5 * rows))  # Extra column on the left for the row titles

    for r in range(rows):
        # Put the row title in the leftmost column, centred on the row
        axes[r, 0].text(0.5, 0.5, titles[r], ha='center', va='center')
        axes[r, 0].axis('off')

        # Label the image columns "Without Loss" and "With Loss" (when both are present)
        if columns == 2:
            axes[r, 1].set_title("Without Loss", pad=10)
            axes[r, 2].set_title("With Loss", pad=10)

        for c in range(1, columns + 1):
            index = r * columns + c - 1
            if index < num_images:
                image, _ = images_with_titles[index]
                axes[r, c].imshow(image)
                axes[r, c].axis('off')

    plt.show()

def image_generator(prompt="dog", loss_function=None):

    images_without_loss = []
    images_with_loss = []

    seed_values = [8, 16, 50, 80, 128]
    num_styles = len(style_files)
    # The generation settings (height, width, num_inference_steps, guidance_scale) are the
    # module-level values defined above and are read inside generate_with_embs.

    for i in range(num_styles):
        this_generated_img_1 = generate_images(prompt, style_num=i, random_seed=seed_values[i], custom_loss_fn=None)
        images_without_loss.append(this_generated_img_1)
        if loss_function:
            this_generated_img_2 = generate_images(prompt, style_num=i, random_seed=seed_values[i], custom_loss_fn=loss_function)
            images_with_loss.append(this_generated_img_2)

    generated_sd_images = []
    titles = ["Arcane Style", "Birb Style", "Dr Strange Style", "Midjourney Style", "Oil Style"]

    for i in range(len(titles)):
        generated_sd_images.append((images_without_loss[i], titles[i]))
        if images_with_loss:
            generated_sd_images.append((images_with_loss[i], titles[i]))

    return display_images_in_rows(generated_sd_images, titles)
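
# Example invocation (illustrative; the prompt string is arbitrary). Wrapped in a __main__
# guard so importing this module does not trigger generation.
if __name__ == "__main__":
    image_generator("a dog wearing a hat", loss_function=vibrance_loss)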