mskrt committed on
Commit
f828a58
verified
1 Parent(s): 70fbd5b

Delete pipeline.py

Files changed (1)
  1. pipeline.py +0 -376
pipeline.py DELETED
@@ -1,376 +0,0 @@
- import random
- from typing import Callable, Dict, Optional
-
- import numpy as np
- import torch
- from tqdm import tqdm
- from diffusers import DiffusionPipeline
- from diffusers.configuration_utils import ConfigMixin
-
- def get_scaled_coeffs():
-     # a, b such that beta(t) = (a + b*t)**2 runs from beta_min at t=0
-     # to beta_max at t=1.
-     beta_min = 0.85
-     beta_max = 12.0
-     return beta_min**0.5, beta_max**0.5 - beta_min**0.5
-
- def beta(t):
-     # Instantaneous noise rate of the VP schedule.
-     a, b = get_scaled_coeffs()
-     return (a + t * b) ** 2
-
- def int_beta(t):
-     # Closed-form integral of beta(s) over [0, t].
-     a, b = get_scaled_coeffs()
-     return ((a + b * t) ** 3 - a ** 3) / (3 * b)
-
- def sigma(t):
-     # Noise level in the scaled parameterization: sqrt(exp(int_beta) - 1).
-     return torch.expm1(int_beta(t)) ** 0.5
-
- def sigma_orig(t):
-     # Standard VP noise std, sqrt(1 - exp(-int_beta)); unused below.
-     return (-torch.expm1(-int_beta(t))) ** 0.5
-
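A reading of these helpers (my notation, not from the file): with $a = \sqrt{\beta_{\min}}$ and $b = \sqrt{\beta_{\max}} - \sqrt{\beta_{\min}}$,

$$\int_0^t \beta(s)\,ds = \frac{(a+bt)^3 - a^3}{3b}, \qquad \sigma(t) = \sqrt{e^{\int_0^t \beta(s)\,ds} - 1},$$

so dividing latents by $\sqrt{\sigma(t)^2 + 1}$, as `_forward` below does before each UNet call, maps them back to the variance-preserving scale the model was trained on.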
- class SuperDiffSDXLPipeline(DiffusionPipeline, ConfigMixin):
-     """Pipeline that samples SDXL under a superposition of two prompts."""
-
-     def __init__(
-         self,
-         unet: Callable,
-         vae: Callable,
-         text_encoder: Callable,
-         text_encoder_2: Callable,
-         tokenizer: Callable,
-         tokenizer_2: Callable,
-     ) -> None:
-         """__init__.
-
-         Parameters
-         ----------
-         unet : Callable
-             SDXL UNet denoiser
-         vae : Callable
-             autoencoder used to decode latents
-         text_encoder : Callable
-             first CLIP text encoder
-         text_encoder_2 : Callable
-             second CLIP text encoder (also provides pooled embeddings)
-         tokenizer : Callable
-             tokenizer paired with text_encoder
-         tokenizer_2 : Callable
-             tokenizer paired with text_encoder_2
-
-         Returns
-         -------
-         None
-
-         """
-         super().__init__()
-         device = "cuda" if torch.cuda.is_available() else "cpu"
-
-         vae.to(device)
-         unet.to(device)
-         text_encoder.to(device)
-         text_encoder_2.to(device)
-
-         self.register_modules(
-             unet=unet,
-             vae=vae,
-             text_encoder=text_encoder,
-             text_encoder_2=text_encoder_2,
-             tokenizer=tokenizer,
-             tokenizer_2=tokenizer_2,
-         )
-
-     def prepare_prompt_input(self, prompt_o, prompt_b, batch_size, height, width):
-         """Encode both prompts, build an all-zero negative embedding, and
-         assemble the SDXL micro-conditioning tensors."""
-         text_input = self.tokenizer([prompt_o] * batch_size, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt")
-         text_input_2 = self.tokenizer_2([prompt_o] * batch_size, padding="max_length", max_length=self.tokenizer_2.model_max_length, truncation=True, return_tensors="pt")
-         with torch.no_grad():
-             text_embeddings = self.text_encoder(text_input.input_ids.to(self.device), output_hidden_states=True)
-             text_embeddings_2 = self.text_encoder_2(text_input_2.input_ids.to(self.device), output_hidden_states=True)
-         prompt_embeds_o = torch.concat((text_embeddings.hidden_states[-2], text_embeddings_2.hidden_states[-2]), dim=-1)
-         pooled_prompt_embeds_o = text_embeddings_2[0]
-         negative_prompt_embeds = torch.zeros_like(prompt_embeds_o)
-         negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds_o)
-
-         text_input = self.tokenizer([prompt_b] * batch_size, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt")
-         text_input_2 = self.tokenizer_2([prompt_b] * batch_size, padding="max_length", max_length=self.tokenizer_2.model_max_length, truncation=True, return_tensors="pt")
-         with torch.no_grad():
-             text_embeddings = self.text_encoder(text_input.input_ids.to(self.device), output_hidden_states=True)
-             text_embeddings_2 = self.text_encoder_2(text_input_2.input_ids.to(self.device), output_hidden_states=True)
-         prompt_embeds_b = torch.concat((text_embeddings.hidden_states[-2], text_embeddings_2.hidden_states[-2]), dim=-1)
-         pooled_prompt_embeds_b = text_embeddings_2[0]
-
-         # SDXL time ids: (original_h, original_w, crop_top, crop_left, target_h, target_w).
-         add_time_ids_o = torch.tensor([(height, width, 0, 0, height, width)])
-         add_time_ids_b = torch.tensor([(height, width, 0, 0, height, width)])
-         negative_add_time_ids = torch.tensor([(height, width, 0, 0, height, width)])
-
-         # Stack as [unconditional, prompt_o, prompt_b] along the batch axis.
-         prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds_o, prompt_embeds_b], dim=0)
-         add_text_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds_o, pooled_prompt_embeds_b], dim=0)
-         add_time_ids = torch.cat([negative_add_time_ids, add_time_ids_o, add_time_ids_b], dim=0)
-
-         prompt_embeds = prompt_embeds.to(self.device)
-         add_text_embeds = add_text_embeds.to(self.device)
-         # repeat_interleave keeps the [uncond, o, b] blocks aligned when batch_size > 1.
-         add_time_ids = add_time_ids.to(self.device).repeat_interleave(batch_size, dim=0)
-         added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
-         return prompt_embeds, added_cond_kwargs
-
-     @torch.no_grad()
-     def get_batch(self, latents: torch.Tensor, nrow: int, ncol: int) -> torch.Tensor:
-         """get_batch.
-
-         Parameters
-         ----------
-         latents : torch.Tensor
-             latents to decode
-         nrow : int
-             grid rows (unused)
-         ncol : int
-             grid columns (unused)
-
-         Returns
-         -------
-         torch.Tensor
-             uint8 images of shape (B, H, W, 3)
-
-         """
-         image = self.vae.decode(
-             latents / self.vae.config.scaling_factor, return_dict=False
-         )[0]
-         image = (image / 2 + 0.5).clamp(0, 1).squeeze()
-         if len(image.shape) < 4:
-             image = image.unsqueeze(0)
-         image = (image.permute(0, 2, 3, 1) * 255).to(torch.uint8)
-         return image
-
-     @torch.no_grad()
-     def get_text_embedding(self, prompt: str) -> torch.Tensor:
-         """get_text_embedding.
-
-         Parameters
-         ----------
-         prompt : str
-             prompt (or list of prompts) to encode
-
-         Returns
-         -------
-         torch.Tensor
-             last hidden state of the first text encoder
-
-         """
-         text_input = self.tokenizer(
-             prompt,
-             padding="max_length",
-             max_length=self.tokenizer.model_max_length,
-             truncation=True,
-             return_tensors="pt",
-         )
-         return self.text_encoder(text_input.input_ids.to(self.device))[0]
-
-     @torch.no_grad()
-     def get_vel(self, t: float, sigma: float, latents: torch.Tensor, embeddings: torch.Tensor) -> torch.Tensor:
-         """Predict the denoising velocity at time t (unused helper).
-
-         Parameters
-         ----------
-         t : float
-             diffusion time
-         sigma : float
-             noise level at t
-         latents : torch.Tensor
-             current latents
-         embeddings : torch.Tensor
-             text embeddings to condition on
-         """
-         def v(_x, _e):
-             return self.unet(
-                 _x / ((sigma**2 + 1) ** 0.5), t, encoder_hidden_states=_e
-             ).sample
-
-         embeds = torch.cat(embeddings)
-         latent_input = latents
-         vel = v(latent_input, embeds)
-         return vel
-
-     def preprocess(
-         self,
-         prompt_1: str,
-         prompt_2: str,
-         seed: Optional[int] = None,
-         num_inference_steps: int = 1000,
-         batch_size: int = 1,
-         lift: float = 0.0,
-         height: int = 512,
-         width: int = 512,
-         guidance_scale: float = 7.5,
-     ) -> Dict:
-         """Encode the prompts and draw the initial latent noise.
-
-         Parameters
-         ----------
-         prompt_1 : str
-             first prompt to superpose
-         prompt_2 : str
-             second prompt to superpose
-         seed : Optional[int]
-             RNG seed; a random seed is drawn when None
-         num_inference_steps : int
-             number of sampling steps
-         batch_size : int
-             number of images to generate
-         lift : float
-             lift parameter (stored, but unused by this pipeline)
-         height : int
-             image height in pixels
-         width : int
-             image width in pixels
-         guidance_scale : float
-             classifier-free guidance weight
-
-         Returns
-         -------
-         Dict
-             initial latents, prompt embeddings, and SDXL conditioning kwargs
-
-         """
-         self.batch_size = batch_size
-         self.num_inference_steps = num_inference_steps
-         self.guidance_scale = guidance_scale
-         self.lift = lift
-         self.seed = seed
-         if self.seed is None:
-             self.seed = random.randint(0, 2**32 - 1)
-
-         # Seeded generator for the initial latent noise.
-         generator = torch.Generator(device=self.device).manual_seed(self.seed)
-         latents = torch.randn(
-             (batch_size, self.unet.config.in_channels, height // 8, width // 8),
-             generator=generator,
-             dtype=self.dtype,
-             device=self.device,
-         )
-         prompt_embeds, added_cond_kwargs = self.prepare_prompt_input(
-             prompt_1, prompt_2, batch_size, height, width
-         )
-         return {
-             "latents": latents,
-             "prompt_embeds": prompt_embeds,
-             "added_cond_kwargs": added_cond_kwargs,
-         }
-
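A quick shape check may help when reading `_forward` below. This is my annotation, not part of the file; it assumes a hypothetical `pipe` built from the standard SDXL base components (see the sketch after the diff) and the 1024x1024 defaults of `__call__`:

inputs = pipe.preprocess("a lion", "a tiger", seed=0, height=1024, width=1024)
inputs["latents"].shape                           # torch.Size([1, 4, 128, 128])
inputs["prompt_embeds"].shape                     # torch.Size([3, 77, 2048]), 768 + 1280 per token
inputs["added_cond_kwargs"]["text_embeds"].shape  # torch.Size([3, 1280])
inputs["added_cond_kwargs"]["time_ids"].shape     # torch.Size([3, 6])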
-     def _forward(self, model_inputs: Dict) -> torch.Tensor:
-         """Run the superposed reverse diffusion and return the final latents.
-
-         Parameters
-         ----------
-         model_inputs : Dict
-             output of preprocess
-
-         Returns
-         -------
-         torch.Tensor
-             denoised latents
-
-         """
-         latents = model_inputs["latents"]
-         prompt_embeds = model_inputs["prompt_embeds"]
-         added_cond_kwargs = model_inputs["added_cond_kwargs"]
-
-         t = torch.tensor(1.0)
-         dt = 1.0 / self.num_inference_steps
-         train_number_steps = 1000  # timestep range the UNet was trained with
-         latents = latents * (sigma(t) ** 2 + 1) ** 0.5
-         with torch.no_grad():
-             for i in tqdm(range(self.num_inference_steps)):
-                 # Batch of three: unconditional, prompt_1 (o), prompt_2 (b).
-                 latent_model_input = torch.cat([latents] * 3)
-                 sigma_t = sigma(t)
-                 dsigma = sigma(t - dt) - sigma_t
-                 latent_model_input /= (sigma_t**2 + 1) ** 0.5
-                 noise_pred = self.unet(
-                     latent_model_input,
-                     t * train_number_steps,
-                     encoder_hidden_states=prompt_embeds,
-                     added_cond_kwargs=added_cond_kwargs,
-                     return_dict=False,
-                 )[0]
-
-                 noise_pred_uncond, noise_pred_text_o, noise_pred_text_b = noise_pred.chunk(3)
-
-                 # Stochastic term of the reverse SDE step.
-                 noise = torch.sqrt(2 * torch.abs(dsigma) * sigma_t) * torch.randn_like(latents)
-
-                 # Guided update under prompt_b alone, then the mixing weight kappa
-                 # that superposes prompt_o with it (restated in the note after this method).
-                 dx_ind = 2 * dsigma * (noise_pred_uncond + self.guidance_scale * (noise_pred_text_b - noise_pred_uncond)) + noise
-                 kappa = (torch.abs(dsigma) * (noise_pred_text_b - noise_pred_text_o) * (noise_pred_text_b + noise_pred_text_o)).sum((1, 2, 3)) - (dx_ind * (noise_pred_text_o - noise_pred_text_b)).sum((1, 2, 3))
-                 kappa /= 2 * dsigma * self.guidance_scale * ((noise_pred_text_o - noise_pred_text_b) ** 2).sum((1, 2, 3))
-                 noise_pred = noise_pred_uncond + self.guidance_scale * ((noise_pred_text_b - noise_pred_uncond) + kappa[:, None, None, None] * (noise_pred_text_o - noise_pred_text_b))
-
-                 latents += 2 * dsigma * noise_pred + noise
-                 t -= dt
-         return latents
-
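The update above, transcribed into math (my notation, not from the file; $g$ is the guidance scale, $\epsilon_\varnothing, \epsilon_o, \epsilon_b$ the three noise predictions, $\eta$ the injected noise, and sums run over all non-batch dimensions):

$$dx_{\text{ind}} = 2\,d\sigma\,\bigl(\epsilon_\varnothing + g(\epsilon_b - \epsilon_\varnothing)\bigr) + \eta, \qquad \kappa = \frac{|d\sigma|\sum(\epsilon_b-\epsilon_o)(\epsilon_b+\epsilon_o) - \sum dx_{\text{ind}}\,(\epsilon_o-\epsilon_b)}{2\,d\sigma\,g\sum(\epsilon_o-\epsilon_b)^2},$$

and the step then applies $\epsilon = \epsilon_\varnothing + g\bigl((\epsilon_b-\epsilon_\varnothing) + \kappa(\epsilon_o-\epsilon_b)\bigr)$: at $\kappa=0$ this is pure prompt_b guidance, at $\kappa=1$ pure prompt_o guidance, so $\kappa$ interpolates between the two prompts per sample at every step.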
-     def postprocess(self, latents: torch.Tensor) -> np.ndarray:
-         """Decode final latents into uint8 images.
-
-         Parameters
-         ----------
-         latents : torch.Tensor
-             latents returned by _forward
-
-         Returns
-         -------
-         np.ndarray
-             images of shape (B, H, W, 3), dtype uint8
-
-         """
-         latents = latents / self.vae.config.scaling_factor
-         latents = latents.to(torch.float32)
-         with torch.no_grad():
-             image = self.vae.decode(latents, return_dict=False)[0]
-
-         image = (image / 2 + 0.5).clamp(0, 1)
-         image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
-         images = (image * 255).round().astype("uint8")
-         return images
-
-     def __call__(
-         self,
-         prompt_1: str,
-         prompt_2: str,
-         seed: Optional[int] = None,
-         num_inference_steps: int = 1000,
-         batch_size: int = 1,
-         lift: float = 0.0,
-         height: int = 1024,
-         width: int = 1024,
-         guidance_scale: float = 7.5,
-     ) -> np.ndarray:
-         """Generate images under a superposition of prompt_1 and prompt_2.
-
-         Parameters
-         ----------
-         prompt_1 : str
-             first prompt to superpose
-         prompt_2 : str
-             second prompt to superpose
-         seed : Optional[int]
-             RNG seed; a random seed is drawn when None
-         num_inference_steps : int
-             number of sampling steps
-         batch_size : int
-             number of images to generate
-         lift : float
-             lift parameter (unused)
-         height : int
-             image height in pixels
-         width : int
-             image width in pixels
-         guidance_scale : float
-             classifier-free guidance weight
-
-         Returns
-         -------
-         np.ndarray
-             generated images of shape (B, H, W, 3), dtype uint8
-
-         """
-         # Preprocess inputs
-         model_inputs = self.preprocess(
-             prompt_1,
-             prompt_2,
-             seed,
-             num_inference_steps,
-             batch_size,
-             lift,
-             height,
-             width,
-             guidance_scale,
-         )
-
-         # Forward pass through the pipeline
-         latents = self._forward(model_inputs)
-
-         # Postprocess to generate the final output
-         images = self.postprocess(latents)
-         return images
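
For context, a minimal usage sketch (mine, not part of the deleted file). It assumes the standard SDXL base checkpoint layout on the Hub and a CUDA device; the repo id, dtype, and prompts are illustrative:

import torch
from diffusers import AutoencoderKL, UNet2DConditionModel
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer

repo = "stabilityai/stable-diffusion-xl-base-1.0"  # assumed checkpoint
unet = UNet2DConditionModel.from_pretrained(repo, subfolder="unet", torch_dtype=torch.float16)
vae = AutoencoderKL.from_pretrained(repo, subfolder="vae", torch_dtype=torch.float16)
text_encoder = CLIPTextModel.from_pretrained(repo, subfolder="text_encoder", torch_dtype=torch.float16)
text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(repo, subfolder="text_encoder_2", torch_dtype=torch.float16)
tokenizer = CLIPTokenizer.from_pretrained(repo, subfolder="tokenizer")
tokenizer_2 = CLIPTokenizer.from_pretrained(repo, subfolder="tokenizer_2")

pipe = SuperDiffSDXLPipeline(unet, vae, text_encoder, text_encoder_2, tokenizer, tokenizer_2)
images = pipe("a photo of a lion", "a photo of a tiger", seed=0, num_inference_steps=200)  # (1, 1024, 1024, 3) uint8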