mskrt committed
Commit
3858858
verified
1 Parent(s): d5ebdc6

Create pipeline.py

Files changed (1)
  1. pipeline.py +386 -0
pipeline.py ADDED
@@ -0,0 +1,386 @@
import random
from typing import Callable, Dict, List, Optional

import numpy as np
import torch
from tqdm import tqdm
from diffusers import DiffusionPipeline
from diffusers.configuration_utils import ConfigMixin


def get_scaled_coeffs():
    beta_min = 0.85
    beta_max = 12.0
    return beta_min**0.5, beta_max**0.5 - beta_min**0.5


def beta(t):
    a, b = get_scaled_coeffs()
    return (a + t * b) ** 2


def int_beta(t):
    a, b = get_scaled_coeffs()
    return ((a + b * t) ** 3 - a**3) / (3 * b)


def sigma(t):
    return torch.expm1(int_beta(t)) ** 0.5


def sigma_orig(t):
    return (-torch.expm1(-int_beta(t))) ** 0.5
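
# Schedule notes (derived from the helpers above): with a = sqrt(beta_min) and
# b = sqrt(beta_max) - sqrt(beta_min),
#     beta(t)       = (a + b*t)**2
#     int_beta(t)   = integral_0^t beta(s) ds = ((a + b*t)**3 - a**3) / (3*b)
#     sigma(t)      = sqrt(exp(int_beta(t)) - 1)     (used by _forward)
#     sigma_orig(t) = sqrt(1 - exp(-int_beta(t)))    (VP-style marginal std, unused here)
# so sigma(t)**2 + 1 == exp(int_beta(t)). The constants 0.85 and 12.0 appear to be the
# continuous-time form of the usual "scaled_linear" Stable Diffusion schedule
# (beta_start=0.00085, beta_end=0.012 over 1000 training steps).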


class SuperDiffSDXLPipeline(DiffusionPipeline, ConfigMixin):
    """SuperDiffSDXLPipeline."""

    def __init__(
        self,
        unet: Callable,
        vae: Callable,
        text_encoder: Callable,
        text_encoder_2: Callable,
        tokenizer: Callable,
        tokenizer_2: Callable,
    ) -> None:
        """__init__.

        Parameters
        ----------
        unet : Callable
            SDXL UNet (noise prediction model).
        vae : Callable
            VAE used to decode latents into images.
        text_encoder : Callable
            First CLIP text encoder.
        text_encoder_2 : Callable
            Second CLIP text encoder (with projection).
        tokenizer : Callable
            Tokenizer paired with text_encoder.
        tokenizer_2 : Callable
            Tokenizer paired with text_encoder_2.

        Returns
        -------
        None

        """
        super().__init__()
        device = "cuda" if torch.cuda.is_available() else "cpu"
        print("device", device)

        vae.to(device)
        unet.to(device)
        text_encoder.to(device)
        text_encoder_2.to(device)
        # dtype = torch.float16
        # vae = AutoencoderKL.from_pretrained(model_path, subfolder="vae").to(device)
        # tokenizer = CLIPTokenizer.from_pretrained(model_path, subfolder="tokenizer")
        # tokenizer_2 = CLIPTokenizer.from_pretrained(model_path, subfolder="tokenizer_2")
        # text_encoder = CLIPTextModel.from_pretrained(model_path, subfolder="text_encoder").to(device, dtype=dtype)
        # text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(model_path, subfolder="text_encoder_2").to(device, dtype=dtype)
        # unet = UNet2DConditionModel.from_pretrained(model_path, subfolder="unet").to(device, dtype=dtype)
        # vae.eval()
        # unet.eval()

        self.register_modules(
            unet=unet,
            vae=vae,
            text_encoder=text_encoder,
            text_encoder_2=text_encoder_2,
            tokenizer=tokenizer,
            tokenizer_2=tokenizer_2,
        )
        print("device", self.device)

    def prepare_prompt_input(self, prompt_o, prompt_b, batch_size, height, width):
        print("self.device", self.device)
        # Encode the first prompt with both SDXL text encoders. The prompts are
        # wrapped in lists (as in the commented-out helpers in preprocess) so that
        # batch_size > 1 yields a proper batch rather than a repeated string.
        text_input = self.tokenizer([prompt_o] * batch_size, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt")
        text_input_2 = self.tokenizer_2([prompt_o] * batch_size, padding="max_length", max_length=self.tokenizer_2.model_max_length, truncation=True, return_tensors="pt")
        with torch.no_grad():
            text_embeddings = self.text_encoder(text_input.input_ids.to(self.device), output_hidden_states=True)
            text_embeddings_2 = self.text_encoder_2(text_input_2.input_ids.to(self.device), output_hidden_states=True)
        prompt_embeds_o = torch.concat((text_embeddings.hidden_states[-2], text_embeddings_2.hidden_states[-2]), dim=-1)
        pooled_prompt_embeds_o = text_embeddings_2[0]
        negative_prompt_embeds = torch.zeros_like(prompt_embeds_o)
        negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds_o)

        # Encode the second prompt the same way.
        text_input = self.tokenizer([prompt_b] * batch_size, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt")
        text_input_2 = self.tokenizer_2([prompt_b] * batch_size, padding="max_length", max_length=self.tokenizer_2.model_max_length, truncation=True, return_tensors="pt")
        with torch.no_grad():
            text_embeddings = self.text_encoder(text_input.input_ids.to(self.device), output_hidden_states=True)
            text_embeddings_2 = self.text_encoder_2(text_input_2.input_ids.to(self.device), output_hidden_states=True)
        prompt_embeds_b = torch.concat((text_embeddings.hidden_states[-2], text_embeddings_2.hidden_states[-2]), dim=-1)
        pooled_prompt_embeds_b = text_embeddings_2[0]
        # SDXL micro-conditioning: (original_size, crop_top_left, target_size).
        add_time_ids_o = torch.tensor([(height, width, 0, 0, height, width)])
        add_time_ids_b = torch.tensor([(height, width, 0, 0, height, width)])
        negative_add_time_ids = torch.tensor([(height, width, 0, 0, height, width)])
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds_o, prompt_embeds_b], dim=0)
        add_text_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds_o, pooled_prompt_embeds_b], dim=0)
        add_time_ids = torch.cat([negative_add_time_ids, add_time_ids_o, add_time_ids_b], dim=0)

        prompt_embeds = prompt_embeds.to(self.device)
        add_text_embeds = add_text_embeds.to(self.device)
        add_time_ids = add_time_ids.to(self.device).repeat(batch_size, 1)
        added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
        return prompt_embeds, added_cond_kwargs
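
    # The concatenation order above is (unconditional, prompt_1, prompt_2); the
    # sampling loop in _forward relies on this order when it splits the UNet
    # output with noise_pred.chunk(3).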

    @torch.no_grad()
    def get_batch(self, latents: torch.Tensor, nrow: int, ncol: int) -> torch.Tensor:
        """get_batch.

        Parameters
        ----------
        latents : torch.Tensor
            Latents to decode.
        nrow : int
            Unused.
        ncol : int
            Unused.

        Returns
        -------
        torch.Tensor
            Decoded images as uint8 tensors in (B, H, W, C) layout.

        """
        image = self.vae.decode(
            latents / self.vae.config.scaling_factor, return_dict=False
        )[0]
        image = (image / 2 + 0.5).clamp(0, 1).squeeze()
        if len(image.shape) < 4:
            image = image.unsqueeze(0)
        image = (image.permute(0, 2, 3, 1) * 255).to(torch.uint8)
        return image

    @torch.no_grad()
    def get_text_embedding(self, prompt: str) -> torch.Tensor:
        """get_text_embedding.

        Parameters
        ----------
        prompt : str
            Prompt (or list of prompts) to encode with the first text encoder.

        Returns
        -------
        torch.Tensor
            Last hidden states from text_encoder.

        """
        text_input = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        return self.text_encoder(text_input.input_ids.to(self.device))[0]

    @torch.no_grad()
    def get_vel(self, t: float, sigma: float, latents: torch.Tensor, embeddings: List[torch.Tensor]):
        """get_vel.

        Parameters
        ----------
        t : float
            Diffusion time.
        sigma : float
            Noise level at time t.
        latents : torch.Tensor
            Current latents.
        embeddings : List[torch.Tensor]
            Text embeddings to condition on.
        """
        # Note: this helper is not called by __call__ and does not pass the SDXL
        # added_cond_kwargs; it conditions the registered UNet on text embeddings only.
        def v(_x, _e):
            return self.unet(
                _x / ((sigma**2 + 1) ** 0.5), t, encoder_hidden_states=_e
            ).sample

        embeds = torch.cat(embeddings)
        latent_input = latents
        vel = v(latent_input, embeds)
        return vel

    def preprocess(
        self,
        prompt_1: str,
        prompt_2: str,
        seed: Optional[int] = None,
        num_inference_steps: int = 1000,
        batch_size: int = 1,
        lift: float = 0.0,
        height: int = 512,
        width: int = 512,
        guidance_scale: float = 7.5,
    ) -> Dict:
        """preprocess.

        Parameters
        ----------
        prompt_1 : str
            First prompt.
        prompt_2 : str
            Second prompt.
        seed : Optional[int]
            Random seed; drawn at random when None.
        num_inference_steps : int
            Number of sampling steps.
        batch_size : int
            Number of images to generate.
        lift : float
            lift (stored but not used by the sampler).
        height : int
            Image height in pixels.
        width : int
            Image width in pixels.
        guidance_scale : float
            Classifier-free guidance scale.

        Returns
        -------
        Dict
            Model inputs consumed by _forward.

        """
        # Store the sampling configuration on the pipeline.
        self.batch_size = batch_size
        self.num_inference_steps = num_inference_steps
        self.guidance_scale = guidance_scale
        self.lift = lift
        self.seed = seed
        if self.seed is None:
            self.seed = random.randint(0, 2**32 - 1)

        # obj_prompt = [prompt_1]
        # bg_prompt = [prompt_2]
        # obj_embeddings = self.get_text_embedding(obj_prompt * batch_size)
        # bg_embeddings = self.get_text_embedding(bg_prompt * batch_size)

        # uncond_embeddings = self.get_text_embedding([""] * batch_size)

        # Seeded generator used to create the initial latent noise.
        generator = torch.Generator(device=self.device).manual_seed(self.seed)
        latents = torch.randn(
            (batch_size, self.unet.config.in_channels, height // 8, width // 8),
            generator=generator,
            dtype=self.dtype,
            device=self.device,
        )
        prompt_embeds, added_cond_kwargs = self.prepare_prompt_input(prompt_1, prompt_2, batch_size, height, width)
        # latents = torch.randn(
        #     (batch_size, self.model.config.in_channels, height // 8, width // 8),
        #     generator=generator,
        #     device=self.device,
        # )

        # latents_og = latents.clone().detach()
        # latents_uncond_og = latents.clone().detach()

        # self.scheduler.set_timesteps(num_inference_steps)
        # latents = latents * self.scheduler.init_noise_sigma

        # latents_uncond = latents.clone().detach()
        return {
            "latents": latents,
            "prompt_embeds": prompt_embeds,
            "added_cond_kwargs": added_cond_kwargs,
        }
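
    # preprocess returns unit-variance Gaussian latents; _forward below scales them
    # once by (sigma(1)**2 + 1)**0.5 and then divides each UNet input by
    # (sigma_t**2 + 1)**0.5 at the current noise level, so the network sees inputs
    # at roughly unit variance while the latents themselves carry the sigma(t) scale.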

    def _forward(self, model_inputs: Dict) -> torch.Tensor:
        """_forward.

        Parameters
        ----------
        model_inputs : Dict
            Model inputs produced by preprocess.

        Returns
        -------
        torch.Tensor
            Final latents after the sampling loop.

        """
        latents = model_inputs["latents"]
        prompt_embeds = model_inputs["prompt_embeds"]
        added_cond_kwargs = model_inputs["added_cond_kwargs"]

        t = torch.tensor(1.0)
        dt = 1.0 / self.num_inference_steps
        train_number_steps = 1000
        latents = latents * (sigma(t)**2 + 1)**0.5
        with torch.no_grad():
            for i in tqdm(range(self.num_inference_steps)):
                latent_model_input = torch.cat([latents] * 3)
                sigma_t = sigma(t)
                dsigma = sigma(t - dt) - sigma_t
                latent_model_input /= (sigma_t**2 + 1)**0.5
                with torch.no_grad():
                    noise_pred = self.unet(latent_model_input, t * train_number_steps, encoder_hidden_states=prompt_embeds, added_cond_kwargs=added_cond_kwargs, return_dict=False)[0]

                noise_pred_uncond, noise_pred_text_o, noise_pred_text_b = noise_pred.chunk(3)

                noise = torch.sqrt(2 * torch.abs(dsigma) * sigma_t) * torch.randn_like(latents)

                # Update that prompt_2 alone would take (classifier-free guidance step).
                dx_ind = 2 * dsigma * (noise_pred_uncond + self.guidance_scale * (noise_pred_text_b - noise_pred_uncond)) + noise
                # kappa selects the interpolation between the two prompts' noise
                # predictions: kappa = 0 recovers prompt_2-only guidance, kappa = 1
                # recovers prompt_1-only guidance.
                kappa = (torch.abs(dsigma) * (noise_pred_text_b - noise_pred_text_o) * (noise_pred_text_b + noise_pred_text_o)).sum((1, 2, 3)) - (dx_ind * (noise_pred_text_o - noise_pred_text_b)).sum((1, 2, 3))
                kappa /= 2 * dsigma * self.guidance_scale * ((noise_pred_text_o - noise_pred_text_b)**2).sum((1, 2, 3))
                noise_pred = noise_pred_uncond + self.guidance_scale * ((noise_pred_text_b - noise_pred_uncond) + kappa[:, None, None, None] * (noise_pred_text_o - noise_pred_text_b))

                latents += 2 * dsigma * noise_pred + noise
                t -= dt
        return latents
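
    # Each iteration above takes an SDE-style step in sigma: a drift term
    # 2 * dsigma * noise_pred plus injected Gaussian noise scaled by
    # sqrt(2 * |dsigma| * sigma_t). kappa is not a fixed hyperparameter; it is solved
    # per sample from dx_ind (the step prompt_2 alone would take) and blends the two
    # prompts' predictions, which is presumably the "superposition" rule the
    # pipeline's name refers to.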

    def postprocess(self, latents: torch.Tensor) -> np.ndarray:
        """postprocess.

        Parameters
        ----------
        latents : torch.Tensor
            Final latents from _forward.

        Returns
        -------
        np.ndarray
            Decoded images as uint8 arrays in (B, H, W, C) layout.

        """
        latents = latents / self.vae.config.scaling_factor
        latents = latents.to(torch.float32)
        with torch.no_grad():
            image = self.vae.decode(latents, return_dict=False)[0]

        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
        images = (image * 255).round().astype("uint8")
        return images

    def __call__(
        self,
        prompt_1: str,
        prompt_2: str,
        seed: Optional[int] = None,
        num_inference_steps: int = 1000,
        batch_size: int = 1,
        lift: float = 0.0,
        height: int = 512,
        width: int = 512,
        guidance_scale: float = 7.5,
    ) -> np.ndarray:
        """__call__.

        Parameters
        ----------
        prompt_1 : str
            First prompt.
        prompt_2 : str
            Second prompt.
        seed : Optional[int]
            Random seed; drawn at random when None.
        num_inference_steps : int
            Number of sampling steps.
        batch_size : int
            Number of images to generate.
        lift : float
            lift (stored but not used by the sampler).
        height : int
            Image height in pixels.
        width : int
            Image width in pixels.
        guidance_scale : float
            Classifier-free guidance scale.

        Returns
        -------
        np.ndarray
            Generated images as uint8 arrays in (B, H, W, C) layout.

        """
        # Preprocess inputs
        model_inputs = self.preprocess(
            prompt_1,
            prompt_2,
            seed,
            num_inference_steps,
            batch_size,
            lift,
            height,
            width,
            guidance_scale,
        )

        # Forward pass through the pipeline
        latents = self._forward(model_inputs)

        # Postprocess to generate the final output
        images = self.postprocess(latents)
        return images
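

# ------------------------------------------------------------------------------
# Usage sketch (not part of pipeline.py). It mirrors the from_pretrained calls left
# commented out in __init__; the checkpoint id is an assumption, and any repository
# with the standard SDXL subfolder layout should work the same way.
#
# from diffusers import AutoencoderKL, UNet2DConditionModel
# from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
#
# model_path = "stabilityai/stable-diffusion-xl-base-1.0"  # assumed checkpoint
# vae = AutoencoderKL.from_pretrained(model_path, subfolder="vae")
# tokenizer = CLIPTokenizer.from_pretrained(model_path, subfolder="tokenizer")
# tokenizer_2 = CLIPTokenizer.from_pretrained(model_path, subfolder="tokenizer_2")
# text_encoder = CLIPTextModel.from_pretrained(model_path, subfolder="text_encoder")
# text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(model_path, subfolder="text_encoder_2")
# unet = UNet2DConditionModel.from_pretrained(model_path, subfolder="unet")
#
# pipe = SuperDiffSDXLPipeline(unet, vae, text_encoder, text_encoder_2, tokenizer, tokenizer_2)
# images = pipe("a photo of a lion", "a photo of a tiger", seed=0, num_inference_steps=200)
# # images is a uint8 numpy array of shape (batch_size, height, width, 3)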