QHL067 committed on
Commit 20ddbb6 · 1 Parent(s): f9567e5
Files changed (1)
  1. app.py +188 -80
app.py CHANGED
@@ -5,124 +5,231 @@ from absl import app
  from ml_collections import config_flags
  import os

- import ml_collections
  import torch
- from torch import multiprocessing as mp
- import torch.nn as nn
- import accelerate
- import utils
- import tempfile
- from absl import logging
- import builtins
- import einops
- import math
- import numpy as np
- import time
- from PIL import Image
  import random

- from diffusion.flow_matching import FlowMatching, ODEFlowMatchingSolver, ODEEulerFlowMatchingSolver
- from tools.clip_score import ClipSocre
  import libs.autoencoder
  from libs.clip import FrozenCLIPEmbedder
- from libs.t5 import T5Embedder


- def unpreprocess(x):
-     x = 0.5 * (x + 1.)
-     x.clamp_(0., 1.)
-     return x

- def batch_decode(_z, decode, batch_size=10):
-     """
-     The VAE decoder requires large GPU memory. To run the interpolation model on GPUs with 24 GB or smaller RAM, you can use this code to reduce memory usage for the VAE.
-     It works by splitting the input tensor into smaller chunks.
-     """
      num_samples = _z.size(0)
      decoded_batches = []

      for i in range(0, num_samples, batch_size):
-         batch = _z[i:i + batch_size]
          decoded_batch = decode(batch)
          decoded_batches.append(decoded_batch)

-     image_unprocessed = torch.cat(decoded_batches, dim=0)
-     return image_unprocessed

- def get_caption(llm, text_model, prompt_dict, batch_size):
-
      if batch_size == 3:
-         # only addition or only subtraction
-         assert len(prompt_dict) == 2
-         _batch_con = list(prompt_dict.values()) + [' ']
      elif batch_size == 4:
-         # addition and subtraction
-         assert len(prompt_dict) == 3
-         _batch_con = list(prompt_dict.values()) + [' ']
      elif batch_size >= 5:
-         # linear interpolation
-         assert len(prompt_dict) == 2
-         _batch_con = [prompt_dict['prompt_1']] + [' '] * (batch_size-2) + [prompt_dict['prompt_2']]

      if llm == "clip":
-         _latent, _latent_and_others = text_model.encode(_batch_con)
-         _con = _latent_and_others['token_embedding'].detach()
      elif llm == "t5":
-         _latent, _latent_and_others = text_model.get_text_embeddings(_batch_con)
-         _con = (_latent_and_others['token_embedding'] * 10.0).detach()
      else:
-         raise NotImplementedError
-     _con_mask = _latent_and_others['token_mask'].detach()
-     _batch_token = _latent_and_others['tokens'].detach()
-     _batch_caption = _batch_con
-     return (_con, _con_mask, _batch_token, _batch_caption)

- import spaces #[uncomment to use ZeroGPU]
- from diffusers import DiffusionPipeline
- import torch

- device = "cuda" if torch.cuda.is_available() else "cpu"
- model_repo_id = "stabilityai/sdxl-turbo" # Replace to the model you would like to use

- if torch.cuda.is_available():
-     torch_dtype = torch.float16
- else:
-     torch_dtype = torch.float32

- # pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
- # pipe = pipe.to(device)

  MAX_SEED = np.iinfo(np.int32).max
- MAX_IMAGE_SIZE = 1024


  @spaces.GPU #[uncomment to use ZeroGPU]
  def infer(
      prompt1,
      prompt2,
-     negative_prompt,
      seed,
      randomize_seed,
      guidance_scale,
      num_inference_steps,
      progress=gr.Progress(track_tqdm=True),
  ):
      if randomize_seed:
          seed = random.randint(0, MAX_SEED)

-     generator = torch.Generator().manual_seed(seed)

-     # image = pipe(
-     # prompt=prompt,
-     # negative_prompt=negative_prompt,
-     # guidance_scale=guidance_scale,
-     # num_inference_steps=num_inference_steps,
-     # width=width,
-     # height=height,
-     # generator=generator,
-     # ).images[0]

-     # return image, seed


  # examples = [
@@ -171,13 +278,6 @@ with gr.Blocks(css=css) as demo:
          result = gr.Image(label="Result", show_label=False)

          with gr.Accordion("Advanced Settings", open=False):
-             negative_prompt = gr.Text(
-                 label="Negative prompt",
-                 max_lines=1,
-                 placeholder="Enter a negative prompt",
-                 visible=False,
-             )
-
              seed = gr.Slider(
                  label="Seed",
                  minimum=0,
@@ -205,6 +305,14 @@ with gr.Blocks(css=css) as demo:
                  value=50, # Replace with defaults that work for your model
              )

          gr.Examples(examples=examples, inputs=[prompt1, prompt2])
      gr.on(
          triggers=[run_button.click, prompt1.submit, prompt2.submit],
@@ -212,11 +320,11 @@ with gr.Blocks(css=css) as demo:
          inputs=[
              prompt1,
              prompt2,
-             negative_prompt,
              seed,
              randomize_seed,
              guidance_scale,
              num_inference_steps,
          ],
          outputs=[result, seed],
      )
 
  from ml_collections import config_flags
  import os

+ import spaces #[uncomment to use ZeroGPU]
  import torch
+
+
+ import os
  import random

+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+ from torchvision.utils import save_image
+
+ from absl import logging
+ import ml_collections
+
+ from diffusion.flow_matching import ODEEulerFlowMatchingSolver
+ import utils
  import libs.autoencoder
  from libs.clip import FrozenCLIPEmbedder
+ from configs import t2i_512px_clip_dimr


+ def unpreprocess(x: torch.Tensor) -> torch.Tensor:
+     x = 0.5 * (x + 1.0)
+     x.clamp_(0.0, 1.0)
+     return x

+ def cosine_similarity_torch(latent1: torch.Tensor, latent2: torch.Tensor) -> torch.Tensor:
+     latent1_flat = latent1.view(-1)
+     latent2_flat = latent2.view(-1)
+     cosine_similarity = F.cosine_similarity(
+         latent1_flat.unsqueeze(0), latent2_flat.unsqueeze(0), dim=1
+     )
+     return cosine_similarity
+
+ def kl_divergence(latent1: torch.Tensor, latent2: torch.Tensor) -> torch.Tensor:
+     latent1_prob = F.softmax(latent1, dim=-1)
+     latent2_prob = F.softmax(latent2, dim=-1)
+     latent1_log_prob = torch.log(latent1_prob)
+     kl_div = F.kl_div(latent1_log_prob, latent2_prob, reduction="batchmean")
+     return kl_div
+
+ def batch_decode(_z: torch.Tensor, decode, batch_size: int = 10) -> torch.Tensor:
      num_samples = _z.size(0)
      decoded_batches = []

      for i in range(0, num_samples, batch_size):
+         batch = _z[i : i + batch_size]
          decoded_batch = decode(batch)
          decoded_batches.append(decoded_batch)

+     return torch.cat(decoded_batches, dim=0)

+ def get_caption(llm: str, text_model, prompt_dict: dict, batch_size: int):
      if batch_size == 3:
+         # Only addition or only subtraction mode.
+         assert len(prompt_dict) == 2, "Expected 2 prompts for batch_size 3."
+         batch_prompts = list(prompt_dict.values()) + [" "]
      elif batch_size == 4:
+         # Addition and subtraction mode.
+         assert len(prompt_dict) == 3, "Expected 3 prompts for batch_size 4."
+         batch_prompts = list(prompt_dict.values()) + [" "]
      elif batch_size >= 5:
+         # Linear interpolation mode.
+         assert len(prompt_dict) == 2, "Expected 2 prompts for linear interpolation."
+         batch_prompts = [prompt_dict["prompt_1"]] + [" "] * (batch_size - 2) + [prompt_dict["prompt_2"]]
+     else:
+         raise ValueError(f"Unsupported batch_size: {batch_size}")

      if llm == "clip":
+         latent, latent_and_others = text_model.encode(batch_prompts)
+         context = latent_and_others["token_embedding"].detach()
      elif llm == "t5":
+         latent, latent_and_others = text_model.get_text_embeddings(batch_prompts)
+         context = (latent_and_others["token_embedding"] * 10.0).detach()
      else:
+         raise NotImplementedError(f"Language model {llm} not supported.")

+     token_mask = latent_and_others["token_mask"].detach()
+     tokens = latent_and_others["tokens"].detach()
+     captions = batch_prompts
+
+     return context, token_mask, tokens, captions

+ # Load configuration and initialize models.
+ config_dict = t2i_512px_clip_dimr.get_config()
+ config = ml_collections.ConfigDict(config_dict)

+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ logging.info(f"Using device: {device}")

+ # Freeze configuration.
+ config = ml_collections.FrozenConfigDict(config)

+ torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
  MAX_SEED = np.iinfo(np.int32).max
+ MAX_IMAGE_SIZE = 1024 # Currently not used.
+
+ # Load the main diffusion model.
+ nnet_path = os.path.join("..", "..", "ckpt", "released_model", "t2i_512px_clip_dimr.pth")
+ nnet = utils.get_nnet(**config.nnet)
+ nnet = nnet.to(device)
+ state_dict = torch.load(nnet_path, map_location=device)
+ nnet.load_state_dict(state_dict)
+ nnet.eval()
+
+ # Initialize text model.
+ llm = "clip"
+ clip = FrozenCLIPEmbedder()
+ clip.eval()
+ clip.to(device)
+
+ # Load autoencoder.
+ autoencoder = libs.autoencoder.get_model(**config.autoencoder)
+ autoencoder.to(device)
+
+
+ @torch.cuda.amp.autocast()
+ def encode(_batch: torch.Tensor) -> torch.Tensor:
+     """Encode a batch of images using the autoencoder."""
+     return autoencoder.encode(_batch)
+
+
+ @torch.cuda.amp.autocast()
+ def decode(_batch: torch.Tensor) -> torch.Tensor:
+     """Decode a batch of latent vectors using the autoencoder."""
+     return autoencoder.decode(_batch)


  @spaces.GPU #[uncomment to use ZeroGPU]
  def infer(
      prompt1,
      prompt2,
      seed,
      randomize_seed,
      guidance_scale,
      num_inference_steps,
+     num_of_interpolation,
+     save_gpu_memory=True,
      progress=gr.Progress(track_tqdm=True),
  ):
      if randomize_seed:
          seed = random.randint(0, MAX_SEED)

+     torch.manual_seed(seed)
+     if device.type == "cuda":
+         torch.cuda.manual_seed_all(seed)

+     # Only support interpolation in this implementation.
+     prompt_dict = {"prompt_1": prompt1, "prompt_2": prompt2}
+     for key, value in prompt_dict.items():
+         assert value is not None, f"{key} must not be None."
+     assert num_of_interpolation >= 5, "For linear interpolation, please sample at least five images."

+     # Get text embeddings and tokens.
+     _context, _token_mask, _token, _caption = get_caption(
+         llm, clip, prompt_dict=prompt_dict, batch_size=num_of_interpolation
+     )
+
+     with torch.no_grad():
+         _z_gaussian = torch.randn(num_of_interpolation, *config.z_shape, device=device)
+         _z_x0, _mu, _log_var = nnet(
+             _context, text_encoder=True, shape=_z_gaussian.shape, mask=_token_mask
+         )
+         _z_init = _z_x0.reshape(_z_gaussian.shape)
+
+         # Prepare the initial latent representations based on the number of interpolations.
+         if num_of_interpolation == 3:
+             # Addition or subtraction mode.
+             if config.prompt_a is not None:
+                 assert config.prompt_s is None, "Only one of prompt_a or prompt_s should be provided."
+                 z_init_temp = _z_init[0] + _z_init[1]
+             elif config.prompt_s is not None:
+                 assert config.prompt_a is None, "Only one of prompt_a or prompt_s should be provided."
+                 z_init_temp = _z_init[0] - _z_init[1]
+             else:
+                 raise NotImplementedError("Either prompt_a or prompt_s must be provided for 3-sample mode.")
+             mean = z_init_temp.mean()
+             std = z_init_temp.std()
+             _z_init[2] = (z_init_temp - mean) / std
+
+         elif num_of_interpolation == 4:
+             z_init_temp = _z_init[0] + _z_init[1] - _z_init[2]
+             mean = z_init_temp.mean()
+             std = z_init_temp.std()
+             _z_init[3] = (z_init_temp - mean) / std
+
+         elif num_of_interpolation >= 5:
+             tensor_a = _z_init[0]
+             tensor_b = _z_init[-1]
+             num_interpolations = num_of_interpolation - 2
+             interpolations = [
+                 tensor_a + (tensor_b - tensor_a) * (i / (num_interpolations + 1))
+                 for i in range(1, num_interpolations + 1)
+             ]
+             _z_init = torch.stack([tensor_a] + interpolations + [tensor_b], dim=0)
+
+         else:
+             raise ValueError("Unsupported number of interpolations.")
+
+         assert guidance_scale > 1, "Guidance scale must be greater than 1."
+
+         has_null_indicator = hasattr(config.nnet.model_args, "cfg_indicator")
+         ode_solver = ODEEulerFlowMatchingSolver(
+             nnet,
+             bdv_model_fn=None,
+             step_size_type="step_in_dsigma",
+             guidance_scale=guidance_scale,
+         )
+         _z, _ = ode_solver.sample(
+             x_T=_z_init,
+             batch_size=num_of_interpolation,
+             sample_steps=num_inference_steps,
+             unconditional_guidance_scale=guidance_scale,
+             has_null_indicator=has_null_indicator,
+         )
+
+         if save_gpu_memory:
+             image_unprocessed = batch_decode(_z, decode)
+         else:
+             image_unprocessed = decode(_z)
+
+         samples = unpreprocess(image_unprocessed).contiguous()[0]
+
+     return samples, seed


  # examples = [
 
          result = gr.Image(label="Result", show_label=False)

          with gr.Accordion("Advanced Settings", open=False):
              seed = gr.Slider(
                  label="Seed",
                  minimum=0,
 
                  value=50, # Replace with defaults that work for your model
              )

+             num_of_interpolation = gr.Slider(
+                 label="Number of images for interpolation",
+                 minimum=5,
+                 maximum=50,
+                 step=1,
+                 value=10, # Replace with defaults that work for your model
+             )
+
          gr.Examples(examples=examples, inputs=[prompt1, prompt2])
      gr.on(
          triggers=[run_button.click, prompt1.submit, prompt2.submit],
 
          inputs=[
              prompt1,
              prompt2,
              seed,
              randomize_seed,
              guidance_scale,
              num_inference_steps,
+             num_of_interpolation,
          ],
          outputs=[result, seed],
      )
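
A note on the two pieces of logic that changed most above. The old batch_decode docstring (removed in this commit) explained why decoding is chunked: the VAE decoder is the GPU-memory bottleneck, so infer() decodes the interpolated latents in small batches when save_gpu_memory is set, while the num_of_interpolation >= 5 branch builds evenly spaced latents between the two prompt endpoints. The snippet below is a minimal, self-contained sketch of those two steps only; interpolate_latents and decode_fn are illustrative names, and decode_fn is a toy stand-in for the app's autoencoder decoder.

import torch
import torch.nn.functional as F


def interpolate_latents(z_a: torch.Tensor, z_b: torch.Tensor, total: int) -> torch.Tensor:
    # Two endpoints plus evenly spaced latents in between, mirroring the >= 5 branch in infer().
    assert total >= 5, "For linear interpolation, sample at least five images."
    num_inner = total - 2
    inner = [z_a + (z_b - z_a) * (i / (num_inner + 1)) for i in range(1, num_inner + 1)]
    return torch.stack([z_a] + inner + [z_b], dim=0)


def batch_decode(z: torch.Tensor, decode_fn, batch_size: int = 10) -> torch.Tensor:
    # Decode in chunks of batch_size so peak decoder memory stays bounded.
    decoded = [decode_fn(z[i : i + batch_size]) for i in range(0, z.size(0), batch_size)]
    return torch.cat(decoded, dim=0)


if __name__ == "__main__":
    # Toy stand-in for the VAE decoder: map a 4x64x64 latent to a 3x512x512 image tensor.
    def decode_fn(batch: torch.Tensor) -> torch.Tensor:
        return F.interpolate(batch[:, :3], scale_factor=8, mode="nearest")

    z_a, z_b = torch.randn(4, 64, 64), torch.randn(4, 64, 64)
    latents = interpolate_latents(z_a, z_b, total=10)        # (10, 4, 64, 64)
    images = batch_decode(latents, decode_fn, batch_size=4)  # decoded in chunks of 4
    print(images.shape)                                      # torch.Size([10, 3, 512, 512])

In the actual app the endpoints are the text-conditioned initial latents produced by nnet for prompt_1 and prompt_2, and the real decode is the autocast-wrapped autoencoder.decode defined above.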