mskrt commited on
Commit
045bb1e
verified
1 Parent(s): cc7a552

Upload pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline.py +83 -44
pipeline.py CHANGED
@@ -38,8 +38,10 @@ def int_beta(t):
38
  t :
39
  t
40
  """
41
- a, b = get_scaled_coeffs()
42
- return ((a+b*t)**3-a**3)/(3*b)
 
 
43
  def sigma(t):
44
  """sigma.
45
 
@@ -48,7 +50,9 @@ def sigma(t):
48
  t :
49
  t
50
  """
51
- return torch.expm1(int_beta(t))**0.5
 
 
52
  def sigma_orig(t):
53
  """sigma_orig.
54
 
@@ -57,13 +61,13 @@ def sigma_orig(t):
57
  t :
58
  t
59
  """
60
- return (-torch.expm1(-int_beta(t)))**0.5
 
61
 
62
  class SuperDiffSDXLPipeline(DiffusionPipeline, ConfigMixin):
63
  """SuperDiffSDXLPipeline."""
64
 
65
  def __init__(self, unet: Callable, vae: Callable, text_encoder: Callable, text_encoder_2: Callable, tokenizer: Callable, tokenizer_2: Callable) -> None:
66
-
67
  """__init__.
68
 
69
  Parameters
@@ -87,16 +91,16 @@ class SuperDiffSDXLPipeline(DiffusionPipeline, ConfigMixin):
87
 
88
  """
89
  super().__init__()
90
- device = "cuda" if torch.cuda.is_available() else "cpu"
91
- dtype=torch.float16
92
 
93
  vae.to(device)
94
  unet.to(device)
95
  text_encoder.to(device)
96
  text_encoder_2.to(device)
97
 
98
- self.register_modules(unet=unet,
99
- vae=vae,
100
  text_encoder=text_encoder,
101
  text_encoder_2=text_encoder_2,
102
  tokenizer=tokenizer,
@@ -119,34 +123,50 @@ class SuperDiffSDXLPipeline(DiffusionPipeline, ConfigMixin):
119
  width :
120
  width
121
  """
122
- text_input = self.tokenizer(prompt_o* batch_size, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt")
123
- text_input_2 = self.tokenizer_2(prompt_o* batch_size, padding="max_length", max_length=self.tokenizer_2.model_max_length, truncation=True, return_tensors="pt")
 
 
124
  with torch.no_grad():
125
- text_embeddings = self.text_encoder(text_input.input_ids.to(self.device), output_hidden_states=True)
126
- text_embeddings_2 = self.text_encoder_2(text_input_2.input_ids.to(self.device), output_hidden_states=True)
127
- prompt_embeds_o = torch.concat((text_embeddings.hidden_states[-2], text_embeddings_2.hidden_states[-2]), dim=-1)
 
 
 
128
  pooled_prompt_embeds_o = text_embeddings_2[0]
129
  negative_prompt_embeds = torch.zeros_like(prompt_embeds_o)
130
- negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds_o)
131
-
132
- text_input = self.tokenizer(prompt_b* batch_size, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt")
133
- text_input_2 = self.tokenizer_2(prompt_b* batch_size, padding="max_length", max_length=self.tokenizer_2.model_max_length, truncation=True, return_tensors="pt")
 
 
 
134
  with torch.no_grad():
135
- text_embeddings = self.text_encoder(text_input.input_ids.to(self.device), output_hidden_states=True)
136
- text_embeddings_2 = self.text_encoder_2(text_input_2.input_ids.to(self.device), output_hidden_states=True)
137
- prompt_embeds_b = torch.concat((text_embeddings.hidden_states[-2], text_embeddings_2.hidden_states[-2]), dim=-1)
 
 
 
138
  pooled_prompt_embeds_b = text_embeddings_2[0]
139
- add_time_ids_o = torch.tensor([(height,width,0,0,height,width)])
140
- add_time_ids_b = torch.tensor([(height,width,0,0,height,width)])
141
- negative_add_time_ids = torch.tensor([(height,width,0,0,height,width)])
142
- prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds_o, prompt_embeds_b], dim=0)
143
- add_text_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds_o, pooled_prompt_embeds_b], dim=0)
144
- add_time_ids = torch.cat([negative_add_time_ids, add_time_ids_o, add_time_ids_b], dim=0)
145
-
 
 
 
 
146
  prompt_embeds = prompt_embeds.to(self.device)
147
  add_text_embeds = add_text_embeds.to(self.device)
148
  add_time_ids = add_time_ids.to(self.device).repeat(batch_size, 1)
149
- added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
 
150
  return prompt_embeds, added_cond_kwargs
151
 
152
  @torch.no_grad
@@ -217,6 +237,15 @@ class SuperDiffSDXLPipeline(DiffusionPipeline, ConfigMixin):
217
  def v(_x, _e): return self.model(
218
  """v.
219
 
 
 
 
 
 
 
 
 
 
220
  Parameters
221
  ----------
222
  _x :
@@ -280,8 +309,10 @@ class SuperDiffSDXLPipeline(DiffusionPipeline, ConfigMixin):
280
  self.seed
281
  ) # Seed generator to create the initial latent noise
282
 
283
- latents = torch.randn((batch_size, self.unet.in_channels, height // 8, width // 8), generator=self.generator, dtype=self.dtype, device=self.device,)
284
- prompt_embeds, added_cond_kwargs = self.prepare_prompt_input(prompt_1, prompt_2, batch_size, height, width)
 
 
285
 
286
  return {
287
  "latents": latents,
@@ -317,18 +348,26 @@ class SuperDiffSDXLPipeline(DiffusionPipeline, ConfigMixin):
317
  dsigma = sigma(t-dt) - sigma_t
318
  latent_model_input /= (sigma_t**2+1)**0.5
319
  with torch.no_grad():
320
- noise_pred = self.unet(latent_model_input, t*train_number_steps, encoder_hidden_states=prompt_embeds, added_cond_kwargs=added_cond_kwargs, return_dict=False)[0]
321
-
322
- noise_pred_uncond, noise_pred_text_o, noise_pred_text_b = noise_pred.chunk(3)
323
-
 
 
324
  # noise = torch.sqrt(2*torch.abs(dsigma)*sigma_t)*torch.randn_like(latents)
325
- noise = torch.sqrt(2*torch.abs(dsigma)*sigma_t)*torch.empty_like(latents, device=self.device).normal_(generator=self.generator)
326
-
327
- dx_ind = 2*dsigma*(noise_pred_uncond + self.guidance_scale*(noise_pred_text_b - noise_pred_uncond)) + noise
328
- kappa = (torch.abs(dsigma)*(noise_pred_text_b-noise_pred_text_o)*(noise_pred_text_b+noise_pred_text_o)).sum((1,2,3))-(dx_ind*((noise_pred_text_o-noise_pred_text_b))).sum((1,2,3))
329
- kappa /= 2*dsigma*self.guidance_scale*((noise_pred_text_o-noise_pred_text_b)**2).sum((1,2,3))
330
- noise_pred = noise_pred_uncond + self.guidance_scale*((noise_pred_text_b - noise_pred_uncond) + kappa[:,None,None,None]*(noise_pred_text_o-noise_pred_text_b))
331
-
 
 
 
 
 
 
332
  if i < self.num_inference_steps - 1:
333
  latents += 2*dsigma * noise_pred + noise
334
  else:
@@ -354,7 +393,7 @@ class SuperDiffSDXLPipeline(DiffusionPipeline, ConfigMixin):
354
  latents = latents.to(torch.float32)
355
  with torch.no_grad():
356
  image = self.vae.decode(latents, return_dict=False)[0]
357
-
358
  image = (image / 2 + 0.5).clamp(0, 1)
359
  image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
360
  images = (image * 255).round().astype("uint8")
@@ -389,7 +428,7 @@ class SuperDiffSDXLPipeline(DiffusionPipeline, ConfigMixin):
389
  height
390
  width : int
391
  width
392
- guidance_scale : int
393
  guidance_scale
394
 
395
  Returns
 
38
  t :
39
  t
40
  """
41
+ a, b = get_scaled_coeffs()
42
+ return ((a+b*t)**3-a**3)/(3*b)
43
+
44
+
45
  def sigma(t):
46
  """sigma.
47
 
 
50
  t :
51
  t
52
  """
53
+ return torch.expm1(int_beta(t))**0.5
54
+
55
+
56
  def sigma_orig(t):
57
  """sigma_orig.
58
 
 
61
  t :
62
  t
63
  """
64
+ return (-torch.expm1(-int_beta(t)))**0.5
65
+
66
 
67
  class SuperDiffSDXLPipeline(DiffusionPipeline, ConfigMixin):
68
  """SuperDiffSDXLPipeline."""
69
 
70
  def __init__(self, unet: Callable, vae: Callable, text_encoder: Callable, text_encoder_2: Callable, tokenizer: Callable, tokenizer_2: Callable) -> None:
 
71
  """__init__.
72
 
73
  Parameters
 
91
 
92
  """
93
  super().__init__()
94
+ device = "cuda" if torch.cuda.is_available() else "cpu"
95
+ dtype = torch.float16
96
 
97
  vae.to(device)
98
  unet.to(device)
99
  text_encoder.to(device)
100
  text_encoder_2.to(device)
101
 
102
+ self.register_modules(unet=unet,
103
+ vae=vae,
104
  text_encoder=text_encoder,
105
  text_encoder_2=text_encoder_2,
106
  tokenizer=tokenizer,
 
123
  width :
124
  width
125
  """
126
+ text_input = self.tokenizer(prompt_o * batch_size, padding="max_length",
127
+ max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt")
128
+ text_input_2 = self.tokenizer_2(prompt_o * batch_size, padding="max_length",
129
+ max_length=self.tokenizer_2.model_max_length, truncation=True, return_tensors="pt")
130
  with torch.no_grad():
131
+ text_embeddings = self.text_encoder(
132
+ text_input.input_ids.to(self.device), output_hidden_states=True)
133
+ text_embeddings_2 = self.text_encoder_2(
134
+ text_input_2.input_ids.to(self.device), output_hidden_states=True)
135
+ prompt_embeds_o = torch.concat(
136
+ (text_embeddings.hidden_states[-2], text_embeddings_2.hidden_states[-2]), dim=-1)
137
  pooled_prompt_embeds_o = text_embeddings_2[0]
138
  negative_prompt_embeds = torch.zeros_like(prompt_embeds_o)
139
+ negative_pooled_prompt_embeds = torch.zeros_like(
140
+ pooled_prompt_embeds_o)
141
+
142
+ text_input = self.tokenizer(prompt_b * batch_size, padding="max_length",
143
+ max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt")
144
+ text_input_2 = self.tokenizer_2(prompt_b * batch_size, padding="max_length",
145
+ max_length=self.tokenizer_2.model_max_length, truncation=True, return_tensors="pt")
146
  with torch.no_grad():
147
+ text_embeddings = self.text_encoder(
148
+ text_input.input_ids.to(self.device), output_hidden_states=True)
149
+ text_embeddings_2 = self.text_encoder_2(
150
+ text_input_2.input_ids.to(self.device), output_hidden_states=True)
151
+ prompt_embeds_b = torch.concat(
152
+ (text_embeddings.hidden_states[-2], text_embeddings_2.hidden_states[-2]), dim=-1)
153
  pooled_prompt_embeds_b = text_embeddings_2[0]
154
+ add_time_ids_o = torch.tensor([(height, width, 0, 0, height, width)])
155
+ add_time_ids_b = torch.tensor([(height, width, 0, 0, height, width)])
156
+ negative_add_time_ids = torch.tensor(
157
+ [(height, width, 0, 0, height, width)])
158
+ prompt_embeds = torch.cat(
159
+ [negative_prompt_embeds, prompt_embeds_o, prompt_embeds_b], dim=0)
160
+ add_text_embeds = torch.cat(
161
+ [negative_pooled_prompt_embeds, pooled_prompt_embeds_o, pooled_prompt_embeds_b], dim=0)
162
+ add_time_ids = torch.cat(
163
+ [negative_add_time_ids, add_time_ids_o, add_time_ids_b], dim=0)
164
+
165
  prompt_embeds = prompt_embeds.to(self.device)
166
  add_text_embeds = add_text_embeds.to(self.device)
167
  add_time_ids = add_time_ids.to(self.device).repeat(batch_size, 1)
168
+ added_cond_kwargs = {
169
+ "text_embeds": add_text_embeds, "time_ids": add_time_ids}
170
  return prompt_embeds, added_cond_kwargs
171
 
172
  @torch.no_grad
 
237
  def v(_x, _e): return self.model(
238
  """v.
239
 
240
+ Parameters
241
+ ----------
242
+ _x :
243
+ _x
244
+ _e :
245
+ _e
246
+ """
247
+ """v.
248
+
249
  Parameters
250
  ----------
251
  _x :
 
309
  self.seed
310
  ) # Seed generator to create the initial latent noise
311
 
312
+ latents = torch.randn((batch_size, self.unet.in_channels, height // 8, width // 8),
313
+ generator=self.generator, dtype=self.dtype, device=self.device,)
314
+ prompt_embeds, added_cond_kwargs = self.prepare_prompt_input(
315
+ prompt_1, prompt_2, batch_size, height, width)
316
 
317
  return {
318
  "latents": latents,
 
348
  dsigma = sigma(t-dt) - sigma_t
349
  latent_model_input /= (sigma_t**2+1)**0.5
350
  with torch.no_grad():
351
+ noise_pred = self.unet(latent_model_input, t*train_number_steps, encoder_hidden_states=prompt_embeds,
352
+ added_cond_kwargs=added_cond_kwargs, return_dict=False)[0]
353
+
354
+ noise_pred_uncond, noise_pred_text_o, noise_pred_text_b = noise_pred.chunk(
355
+ 3)
356
+
357
  # noise = torch.sqrt(2*torch.abs(dsigma)*sigma_t)*torch.randn_like(latents)
358
+ noise = torch.sqrt(2*torch.abs(dsigma)*sigma_t)*torch.empty_like(
359
+ latents, device=self.device).normal_(generator=self.generator)
360
+
361
+ dx_ind = 2*dsigma*(noise_pred_uncond + self.guidance_scale *
362
+ (noise_pred_text_b - noise_pred_uncond)) + noise
363
+ kappa = (torch.abs(dsigma)*(noise_pred_text_b-noise_pred_text_o)*(noise_pred_text_b+noise_pred_text_o)
364
+ ).sum((1, 2, 3))-(dx_ind*((noise_pred_text_o-noise_pred_text_b))).sum((1, 2, 3))
365
+ kappa /= 2*dsigma*self.guidance_scale * \
366
+ ((noise_pred_text_o-noise_pred_text_b)**2).sum((1, 2, 3))
367
+ noise_pred = noise_pred_uncond + self.guidance_scale * \
368
+ ((noise_pred_text_b - noise_pred_uncond) +
369
+ kappa[:, None, None, None]*(noise_pred_text_o-noise_pred_text_b))
370
+
371
  if i < self.num_inference_steps - 1:
372
  latents += 2*dsigma * noise_pred + noise
373
  else:
 
393
  latents = latents.to(torch.float32)
394
  with torch.no_grad():
395
  image = self.vae.decode(latents, return_dict=False)[0]
396
+
397
  image = (image / 2 + 0.5).clamp(0, 1)
398
  image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
399
  images = (image * 255).round().astype("uint8")
 
428
  height
429
  width : int
430
  width
431
+ guidance_scale : float
432
  guidance_scale
433
 
434
  Returns