mskrt commited on
Commit
dd1fefc
verified
1 Parent(s): 045bb1e

Upload pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline.py +145 -70
pipeline.py CHANGED
@@ -1,5 +1,5 @@
1
  import random
2
- from typing import Callable, Dict, List, Optional
3
 
4
  import torch
5
  from diffusers import DiffusionPipeline
@@ -11,11 +11,10 @@ from tqdm import tqdm
11
 
12
 
13
  def get_scaled_coeffs():
14
- """get_scaled_coeffs.
15
- """
16
  beta_min = 0.85
17
  beta_max = 12.0
18
- return beta_min**0.5, beta_max**0.5-beta_min**0.5
19
 
20
 
21
  def beta(t):
@@ -27,7 +26,7 @@ def beta(t):
27
  t
28
  """
29
  a, b = get_scaled_coeffs()
30
- return (a+t*b)**2
31
 
32
 
33
  def int_beta(t):
@@ -39,7 +38,7 @@ def int_beta(t):
39
  t
40
  """
41
  a, b = get_scaled_coeffs()
42
- return ((a+b*t)**3-a**3)/(3*b)
43
 
44
 
45
  def sigma(t):
@@ -50,7 +49,7 @@ def sigma(t):
50
  t :
51
  t
52
  """
53
- return torch.expm1(int_beta(t))**0.5
54
 
55
 
56
  def sigma_orig(t):
@@ -61,13 +60,21 @@ def sigma_orig(t):
61
  t :
62
  t
63
  """
64
- return (-torch.expm1(-int_beta(t)))**0.5
65
 
66
 
67
  class SuperDiffSDXLPipeline(DiffusionPipeline, ConfigMixin):
68
  """SuperDiffSDXLPipeline."""
69
 
70
- def __init__(self, unet: Callable, vae: Callable, text_encoder: Callable, text_encoder_2: Callable, tokenizer: Callable, tokenizer_2: Callable) -> None:
 
 
 
 
 
 
 
 
71
  """__init__.
72
 
73
  Parameters
@@ -99,13 +106,14 @@ class SuperDiffSDXLPipeline(DiffusionPipeline, ConfigMixin):
99
  text_encoder.to(device)
100
  text_encoder_2.to(device)
101
 
102
- self.register_modules(unet=unet,
103
- vae=vae,
104
- text_encoder=text_encoder,
105
- text_encoder_2=text_encoder_2,
106
- tokenizer=tokenizer,
107
- tokenizer_2=tokenizer_2,
108
- )
 
109
 
110
  def prepare_prompt_input(self, prompt_o, prompt_b, batch_size, height, width):
111
  """prepare_prompt_input.
@@ -123,44 +131,82 @@ class SuperDiffSDXLPipeline(DiffusionPipeline, ConfigMixin):
123
  width :
124
  width
125
  """
126
- text_input = self.tokenizer(prompt_o * batch_size, padding="max_length",
127
- max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt")
128
- text_input_2 = self.tokenizer_2(prompt_o * batch_size, padding="max_length",
129
- max_length=self.tokenizer_2.model_max_length, truncation=True, return_tensors="pt")
 
 
 
 
 
 
 
 
 
 
130
  with torch.no_grad():
131
  text_embeddings = self.text_encoder(
132
- text_input.input_ids.to(self.device), output_hidden_states=True)
 
133
  text_embeddings_2 = self.text_encoder_2(
134
- text_input_2.input_ids.to(self.device), output_hidden_states=True)
 
135
  prompt_embeds_o = torch.concat(
136
- (text_embeddings.hidden_states[-2], text_embeddings_2.hidden_states[-2]), dim=-1)
 
 
 
137
  pooled_prompt_embeds_o = text_embeddings_2[0]
138
  negative_prompt_embeds = torch.zeros_like(prompt_embeds_o)
139
  negative_pooled_prompt_embeds = torch.zeros_like(
140
  pooled_prompt_embeds_o)
141
 
142
- text_input = self.tokenizer(prompt_b * batch_size, padding="max_length",
143
- max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt")
144
- text_input_2 = self.tokenizer_2(prompt_b * batch_size, padding="max_length",
145
- max_length=self.tokenizer_2.model_max_length, truncation=True, return_tensors="pt")
 
 
 
 
 
 
 
 
 
 
146
  with torch.no_grad():
147
  text_embeddings = self.text_encoder(
148
- text_input.input_ids.to(self.device), output_hidden_states=True)
 
149
  text_embeddings_2 = self.text_encoder_2(
150
- text_input_2.input_ids.to(self.device), output_hidden_states=True)
 
151
  prompt_embeds_b = torch.concat(
152
- (text_embeddings.hidden_states[-2], text_embeddings_2.hidden_states[-2]), dim=-1)
 
 
 
153
  pooled_prompt_embeds_b = text_embeddings_2[0]
154
  add_time_ids_o = torch.tensor([(height, width, 0, 0, height, width)])
155
  add_time_ids_b = torch.tensor([(height, width, 0, 0, height, width)])
156
  negative_add_time_ids = torch.tensor(
157
  [(height, width, 0, 0, height, width)])
158
  prompt_embeds = torch.cat(
159
- [negative_prompt_embeds, prompt_embeds_o, prompt_embeds_b], dim=0)
 
160
  add_text_embeds = torch.cat(
161
- [negative_pooled_prompt_embeds, pooled_prompt_embeds_o, pooled_prompt_embeds_b], dim=0)
 
 
 
 
 
 
162
  add_time_ids = torch.cat(
163
- [negative_add_time_ids, add_time_ids_o, add_time_ids_b], dim=0)
 
164
 
165
  prompt_embeds = prompt_embeds.to(self.device)
166
  add_text_embeds = add_text_embeds.to(self.device)
@@ -234,16 +280,8 @@ class SuperDiffSDXLPipeline(DiffusionPipeline, ConfigMixin):
234
  embeddings : Callable
235
  embeddings
236
  """
237
- def v(_x, _e): return self.model(
238
- """v.
239
 
240
- Parameters
241
- ----------
242
- _x :
243
- _x
244
- _e :
245
- _e
246
- """
247
  """v.
248
 
249
  Parameters
@@ -253,8 +291,10 @@ class SuperDiffSDXLPipeline(DiffusionPipeline, ConfigMixin):
253
  _e :
254
  _e
255
  """
256
- _x / ((sigma**2 + 1) ** 0.5), t, encoder_hidden_states=_e
257
- ).sample
 
 
258
  embeds = torch.cat(embeddings)
259
  latent_input = latents
260
  vel = v(latent_input, embeds)
@@ -309,10 +349,15 @@ class SuperDiffSDXLPipeline(DiffusionPipeline, ConfigMixin):
309
  self.seed
310
  ) # Seed generator to create the initial latent noise
311
 
312
- latents = torch.randn((batch_size, self.unet.in_channels, height // 8, width // 8),
313
- generator=self.generator, dtype=self.dtype, device=self.device,)
 
 
 
 
314
  prompt_embeds, added_cond_kwargs = self.prepare_prompt_input(
315
- prompt_1, prompt_2, batch_size, height, width)
 
316
 
317
  return {
318
  "latents": latents,
@@ -338,38 +383,68 @@ class SuperDiffSDXLPipeline(DiffusionPipeline, ConfigMixin):
338
  added_cond_kwargs = model_inputs["added_cond_kwargs"]
339
 
340
  t = torch.tensor(1.0)
341
- dt = 1.0/self.num_inference_steps
342
  train_number_steps = 1000
343
- latents = latents * (sigma(t)**2+1)**0.5
344
  with torch.no_grad():
345
  for i in tqdm(range(self.num_inference_steps)):
346
  latent_model_input = torch.cat([latents] * 3)
347
  sigma_t = sigma(t)
348
- dsigma = sigma(t-dt) - sigma_t
349
- latent_model_input /= (sigma_t**2+1)**0.5
350
  with torch.no_grad():
351
- noise_pred = self.unet(latent_model_input, t*train_number_steps, encoder_hidden_states=prompt_embeds,
352
- added_cond_kwargs=added_cond_kwargs, return_dict=False)[0]
353
-
354
- noise_pred_uncond, noise_pred_text_o, noise_pred_text_b = noise_pred.chunk(
355
- 3)
 
 
 
 
 
 
 
 
356
 
357
  # noise = torch.sqrt(2*torch.abs(dsigma)*sigma_t)*torch.randn_like(latents)
358
- noise = torch.sqrt(2*torch.abs(dsigma)*sigma_t)*torch.empty_like(
359
- latents, device=self.device).normal_(generator=self.generator)
360
-
361
- dx_ind = 2*dsigma*(noise_pred_uncond + self.guidance_scale *
362
- (noise_pred_text_b - noise_pred_uncond)) + noise
363
- kappa = (torch.abs(dsigma)*(noise_pred_text_b-noise_pred_text_o)*(noise_pred_text_b+noise_pred_text_o)
364
- ).sum((1, 2, 3))-(dx_ind*((noise_pred_text_o-noise_pred_text_b))).sum((1, 2, 3))
365
- kappa /= 2*dsigma*self.guidance_scale * \
366
- ((noise_pred_text_o-noise_pred_text_b)**2).sum((1, 2, 3))
367
- noise_pred = noise_pred_uncond + self.guidance_scale * \
368
- ((noise_pred_text_b - noise_pred_uncond) +
369
- kappa[:, None, None, None]*(noise_pred_text_o-noise_pred_text_b))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
370
 
371
  if i < self.num_inference_steps - 1:
372
- latents += 2*dsigma * noise_pred + noise
373
  else:
374
  latents += dsigma * noise_pred
375
 
@@ -389,7 +464,7 @@ class SuperDiffSDXLPipeline(DiffusionPipeline, ConfigMixin):
389
  Callable
390
 
391
  """
392
- latents = latents/self.vae.config.scaling_factor
393
  latents = latents.to(torch.float32)
394
  with torch.no_grad():
395
  image = self.vae.decode(latents, return_dict=False)[0]
 
1
  import random
2
+ from typing import Callable, Dict
3
 
4
  import torch
5
  from diffusers import DiffusionPipeline
 
11
 
12
 
13
  def get_scaled_coeffs():
14
+ """get_scaled_coeffs."""
 
15
  beta_min = 0.85
16
  beta_max = 12.0
17
+ return beta_min**0.5, beta_max**0.5 - beta_min**0.5
18
 
19
 
20
  def beta(t):
 
26
  t
27
  """
28
  a, b = get_scaled_coeffs()
29
+ return (a + t * b) ** 2
30
 
31
 
32
  def int_beta(t):
 
38
  t
39
  """
40
  a, b = get_scaled_coeffs()
41
+ return ((a + b * t) ** 3 - a**3) / (3 * b)
42
 
43
 
44
  def sigma(t):
 
49
  t :
50
  t
51
  """
52
+ return torch.expm1(int_beta(t)) ** 0.5
53
 
54
 
55
  def sigma_orig(t):
 
60
  t :
61
  t
62
  """
63
+ return (-torch.expm1(-int_beta(t))) ** 0.5
64
 
65
 
66
  class SuperDiffSDXLPipeline(DiffusionPipeline, ConfigMixin):
67
  """SuperDiffSDXLPipeline."""
68
 
69
+ def __init__(
70
+ self,
71
+ unet: Callable,
72
+ vae: Callable,
73
+ text_encoder: Callable,
74
+ text_encoder_2: Callable,
75
+ tokenizer: Callable,
76
+ tokenizer_2: Callable,
77
+ ) -> None:
78
  """__init__.
79
 
80
  Parameters
 
106
  text_encoder.to(device)
107
  text_encoder_2.to(device)
108
 
109
+ self.register_modules(
110
+ unet=unet,
111
+ vae=vae,
112
+ text_encoder=text_encoder,
113
+ text_encoder_2=text_encoder_2,
114
+ tokenizer=tokenizer,
115
+ tokenizer_2=tokenizer_2,
116
+ )
117
 
118
  def prepare_prompt_input(self, prompt_o, prompt_b, batch_size, height, width):
119
  """prepare_prompt_input.
 
131
  width :
132
  width
133
  """
134
+ text_input = self.tokenizer(
135
+ prompt_o * batch_size,
136
+ padding="max_length",
137
+ max_length=self.tokenizer.model_max_length,
138
+ truncation=True,
139
+ return_tensors="pt",
140
+ )
141
+ text_input_2 = self.tokenizer_2(
142
+ prompt_o * batch_size,
143
+ padding="max_length",
144
+ max_length=self.tokenizer_2.model_max_length,
145
+ truncation=True,
146
+ return_tensors="pt",
147
+ )
148
  with torch.no_grad():
149
  text_embeddings = self.text_encoder(
150
+ text_input.input_ids.to(self.device), output_hidden_states=True
151
+ )
152
  text_embeddings_2 = self.text_encoder_2(
153
+ text_input_2.input_ids.to(self.device), output_hidden_states=True
154
+ )
155
  prompt_embeds_o = torch.concat(
156
+ (text_embeddings.hidden_states[-2],
157
+ text_embeddings_2.hidden_states[-2]),
158
+ dim=-1,
159
+ )
160
  pooled_prompt_embeds_o = text_embeddings_2[0]
161
  negative_prompt_embeds = torch.zeros_like(prompt_embeds_o)
162
  negative_pooled_prompt_embeds = torch.zeros_like(
163
  pooled_prompt_embeds_o)
164
 
165
+ text_input = self.tokenizer(
166
+ prompt_b * batch_size,
167
+ padding="max_length",
168
+ max_length=self.tokenizer.model_max_length,
169
+ truncation=True,
170
+ return_tensors="pt",
171
+ )
172
+ text_input_2 = self.tokenizer_2(
173
+ prompt_b * batch_size,
174
+ padding="max_length",
175
+ max_length=self.tokenizer_2.model_max_length,
176
+ truncation=True,
177
+ return_tensors="pt",
178
+ )
179
  with torch.no_grad():
180
  text_embeddings = self.text_encoder(
181
+ text_input.input_ids.to(self.device), output_hidden_states=True
182
+ )
183
  text_embeddings_2 = self.text_encoder_2(
184
+ text_input_2.input_ids.to(self.device), output_hidden_states=True
185
+ )
186
  prompt_embeds_b = torch.concat(
187
+ (text_embeddings.hidden_states[-2],
188
+ text_embeddings_2.hidden_states[-2]),
189
+ dim=-1,
190
+ )
191
  pooled_prompt_embeds_b = text_embeddings_2[0]
192
  add_time_ids_o = torch.tensor([(height, width, 0, 0, height, width)])
193
  add_time_ids_b = torch.tensor([(height, width, 0, 0, height, width)])
194
  negative_add_time_ids = torch.tensor(
195
  [(height, width, 0, 0, height, width)])
196
  prompt_embeds = torch.cat(
197
+ [negative_prompt_embeds, prompt_embeds_o, prompt_embeds_b], dim=0
198
+ )
199
  add_text_embeds = torch.cat(
200
+ [
201
+ negative_pooled_prompt_embeds,
202
+ pooled_prompt_embeds_o,
203
+ pooled_prompt_embeds_b,
204
+ ],
205
+ dim=0,
206
+ )
207
  add_time_ids = torch.cat(
208
+ [negative_add_time_ids, add_time_ids_o, add_time_ids_b], dim=0
209
+ )
210
 
211
  prompt_embeds = prompt_embeds.to(self.device)
212
  add_text_embeds = add_text_embeds.to(self.device)
 
280
  embeddings : Callable
281
  embeddings
282
  """
 
 
283
 
284
+ def v(_x, _e):
 
 
 
 
 
 
285
  """v.
286
 
287
  Parameters
 
291
  _e :
292
  _e
293
  """
294
+ return self.model(
295
+ _x / ((sigma**2 + 1) ** 0.5), t, encoder_hidden_states=_e
296
+ ).sample
297
+
298
  embeds = torch.cat(embeddings)
299
  latent_input = latents
300
  vel = v(latent_input, embeds)
 
349
  self.seed
350
  ) # Seed generator to create the initial latent noise
351
 
352
+ latents = torch.randn(
353
+ (batch_size, self.unet.in_channels, height // 8, width // 8),
354
+ generator=self.generator,
355
+ dtype=self.dtype,
356
+ device=self.device,
357
+ )
358
  prompt_embeds, added_cond_kwargs = self.prepare_prompt_input(
359
+ prompt_1, prompt_2, batch_size, height, width
360
+ )
361
 
362
  return {
363
  "latents": latents,
 
383
  added_cond_kwargs = model_inputs["added_cond_kwargs"]
384
 
385
  t = torch.tensor(1.0)
386
+ dt = 1.0 / self.num_inference_steps
387
  train_number_steps = 1000
388
+ latents = latents * (sigma(t) ** 2 + 1) ** 0.5
389
  with torch.no_grad():
390
  for i in tqdm(range(self.num_inference_steps)):
391
  latent_model_input = torch.cat([latents] * 3)
392
  sigma_t = sigma(t)
393
+ dsigma = sigma(t - dt) - sigma_t
394
+ latent_model_input /= (sigma_t**2 + 1) ** 0.5
395
  with torch.no_grad():
396
+ noise_pred = self.unet(
397
+ latent_model_input,
398
+ t * train_number_steps,
399
+ encoder_hidden_states=prompt_embeds,
400
+ added_cond_kwargs=added_cond_kwargs,
401
+ return_dict=False,
402
+ )[0]
403
+
404
+ (
405
+ noise_pred_uncond,
406
+ noise_pred_text_o,
407
+ noise_pred_text_b,
408
+ ) = noise_pred.chunk(3)
409
 
410
  # noise = torch.sqrt(2*torch.abs(dsigma)*sigma_t)*torch.randn_like(latents)
411
+ noise = torch.sqrt(2 * torch.abs(dsigma) * sigma_t) * torch.empty_like(
412
+ latents, device=self.device
413
+ ).normal_(generator=self.generator)
414
+
415
+ dx_ind = (
416
+ 2
417
+ * dsigma
418
+ * (
419
+ noise_pred_uncond
420
+ + self.guidance_scale *
421
+ (noise_pred_text_b - noise_pred_uncond)
422
+ )
423
+ + noise
424
+ )
425
+ kappa = (
426
+ torch.abs(dsigma)
427
+ * (noise_pred_text_b - noise_pred_text_o)
428
+ * (noise_pred_text_b + noise_pred_text_o)
429
+ ).sum((1, 2, 3)) - (
430
+ dx_ind * ((noise_pred_text_o - noise_pred_text_b))
431
+ ).sum(
432
+ (1, 2, 3)
433
+ )
434
+ kappa /= (
435
+ 2
436
+ * dsigma
437
+ * self.guidance_scale
438
+ * ((noise_pred_text_o - noise_pred_text_b) ** 2).sum((1, 2, 3))
439
+ )
440
+ noise_pred = noise_pred_uncond + self.guidance_scale * (
441
+ (noise_pred_text_b - noise_pred_uncond)
442
+ + kappa[:, None, None, None]
443
+ * (noise_pred_text_o - noise_pred_text_b)
444
+ )
445
 
446
  if i < self.num_inference_steps - 1:
447
+ latents += 2 * dsigma * noise_pred + noise
448
  else:
449
  latents += dsigma * noise_pred
450
 
 
464
  Callable
465
 
466
  """
467
+ latents = latents / self.vae.config.scaling_factor
468
  latents = latents.to(torch.float32)
469
  with torch.no_grad():
470
  image = self.vae.decode(latents, return_dict=False)[0]