QHL067 committed
Commit 7a6b9c3 · 1 Parent(s): 09d905b
Files changed (1)
  1. app.py +245 -403
app.py CHANGED
@@ -1,402 +1,249 @@
- # import gradio as gr

- # from absl import flags
- # from absl import app
- # from ml_collections import config_flags
- # import os

- # import spaces #[uncomment to use ZeroGPU]
- # import torch


- # import os
- # import random

- # import numpy as np
- # import torch
- # import torch.nn.functional as F
- # from torchvision.utils import save_image
- # from huggingface_hub import hf_hub_download

- # from absl import logging
- # import ml_collections

- # from diffusion.flow_matching import ODEEulerFlowMatchingSolver
- # import utils
- # import libs.autoencoder
- # from libs.clip import FrozenCLIPEmbedder
- # from configs import t2i_512px_clip_dimr


- # def unpreprocess(x: torch.Tensor) -> torch.Tensor:
- #     x = 0.5 * (x + 1.0)
- #     x.clamp_(0.0, 1.0)
- #     return x

- # def cosine_similarity_torch(latent1: torch.Tensor, latent2: torch.Tensor) -> torch.Tensor:
- #     latent1_flat = latent1.view(-1)
- #     latent2_flat = latent2.view(-1)
- #     cosine_similarity = F.cosine_similarity(
- #         latent1_flat.unsqueeze(0), latent2_flat.unsqueeze(0), dim=1
- #     )
- #     return cosine_similarity
-
- # def kl_divergence(latent1: torch.Tensor, latent2: torch.Tensor) -> torch.Tensor:
- #     latent1_prob = F.softmax(latent1, dim=-1)
- #     latent2_prob = F.softmax(latent2, dim=-1)
- #     latent1_log_prob = torch.log(latent1_prob)
- #     kl_div = F.kl_div(latent1_log_prob, latent2_prob, reduction="batchmean")
- #     return kl_div
-
- # def batch_decode(_z: torch.Tensor, decode, batch_size: int = 10) -> torch.Tensor:
- #     num_samples = _z.size(0)
- #     decoded_batches = []
-
- #     for i in range(0, num_samples, batch_size):
- #         batch = _z[i : i + batch_size]
- #         decoded_batch = decode(batch)
- #         decoded_batches.append(decoded_batch)
-
- #     return torch.cat(decoded_batches, dim=0)
-
- # def get_caption(llm: str, text_model, prompt_dict: dict, batch_size: int):
- #     if batch_size == 3:
- #         # Only addition or only subtraction mode.
- #         assert len(prompt_dict) == 2, "Expected 2 prompts for batch_size 3."
- #         batch_prompts = list(prompt_dict.values()) + [" "]
- #     elif batch_size == 4:
- #         # Addition and subtraction mode.
- #         assert len(prompt_dict) == 3, "Expected 3 prompts for batch_size 4."
- #         batch_prompts = list(prompt_dict.values()) + [" "]
- #     elif batch_size >= 5:
- #         # Linear interpolation mode.
- #         assert len(prompt_dict) == 2, "Expected 2 prompts for linear interpolation."
- #         batch_prompts = [prompt_dict["prompt_1"]] + [" "] * (batch_size - 2) + [prompt_dict["prompt_2"]]
- #     else:
- #         raise ValueError(f"Unsupported batch_size: {batch_size}")
-
- #     if llm == "clip":
- #         latent, latent_and_others = text_model.encode(batch_prompts)
- #         context = latent_and_others["token_embedding"].detach()
- #     elif llm == "t5":
- #         latent, latent_and_others = text_model.get_text_embeddings(batch_prompts)
- #         context = (latent_and_others["token_embedding"] * 10.0).detach()
- #     else:
- #         raise NotImplementedError(f"Language model {llm} not supported.")
-
- #     token_mask = latent_and_others["token_mask"].detach()
- #     tokens = latent_and_others["tokens"].detach()
- #     captions = batch_prompts
-
- #     return context, token_mask, tokens, captions
-
- # # Load configuration and initialize models.
- # config_dict = t2i_512px_clip_dimr.get_config()
- # config = ml_collections.ConfigDict(config_dict)
-
- # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
- # logging.info(f"Using device: {device}")
-
- # # Freeze configuration.
- # config = ml_collections.FrozenConfigDict(config)
-
- # torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
- # MAX_SEED = np.iinfo(np.int32).max
- # MAX_IMAGE_SIZE = 1024 # Currently not used.
-
- # # Load the main diffusion model.
- # repo_id = "QHL067/CrossFlow"
- # filename = "pretrained_models/t2i_512px_clip_dimr.pth"
- # checkpoint_path = hf_hub_download(repo_id=repo_id, filename=filename)
- # nnet = utils.get_nnet(**config.nnet)
- # nnet = nnet.to(device)
- # state_dict = torch.load(checkpoint_path, map_location=device)
- # nnet.load_state_dict(state_dict)
- # nnet.eval()
-
- # # Initialize text model.
- # llm = "clip"
- # clip = FrozenCLIPEmbedder()
- # clip.eval()
- # clip.to(device)
-
- # # Load autoencoder.
- # autoencoder = libs.autoencoder.get_model(**config.autoencoder)
- # autoencoder.to(device)
-
-
- # @torch.cuda.amp.autocast()
- # def encode(_batch: torch.Tensor) -> torch.Tensor:
- #     """Encode a batch of images using the autoencoder."""
- #     return autoencoder.encode(_batch)
-
-
- # @torch.cuda.amp.autocast()
- # def decode(_batch: torch.Tensor) -> torch.Tensor:
- #     """Decode a batch of latent vectors using the autoencoder."""
- #     return autoencoder.decode(_batch)
-
-
- # @spaces.GPU #[uncomment to use ZeroGPU]
- # def infer(
- #     prompt1,
- #     prompt2,
- #     seed,
- #     randomize_seed,
- #     guidance_scale,
- #     num_inference_steps,
- #     num_of_interpolation,
- #     save_gpu_memory=True,
- #     progress=gr.Progress(track_tqdm=True),
- # ):
- #     if randomize_seed:
- #         seed = random.randint(0, MAX_SEED)
-
- #     torch.manual_seed(seed)
- #     if device.type == "cuda":
- #         torch.cuda.manual_seed_all(seed)
-
- #     # Only support interpolation in this implementation.
- #     prompt_dict = {"prompt_1": prompt1, "prompt_2": prompt2}
- #     for key, value in prompt_dict.items():
- #         assert value is not None, f"{key} must not be None."
- #     assert num_of_interpolation >= 5, "For linear interpolation, please sample at least five images."
-
- #     # Get text embeddings and tokens.
- #     _context, _token_mask, _token, _caption = get_caption(
- #         llm, clip, prompt_dict=prompt_dict, batch_size=num_of_interpolation
- #     )
-
- #     with torch.no_grad():
- #         _z_gaussian = torch.randn(num_of_interpolation, *config.z_shape, device=device)
- #         _z_x0, _mu, _log_var = nnet(
- #             _context, text_encoder=True, shape=_z_gaussian.shape, mask=_token_mask
- #         )
- #         _z_init = _z_x0.reshape(_z_gaussian.shape)
-
- #         # Prepare the initial latent representations based on the number of interpolations.
- #         if num_of_interpolation == 3:
- #             # Addition or subtraction mode.
- #             if config.prompt_a is not None:
- #                 assert config.prompt_s is None, "Only one of prompt_a or prompt_s should be provided."
- #                 z_init_temp = _z_init[0] + _z_init[1]
- #             elif config.prompt_s is not None:
- #                 assert config.prompt_a is None, "Only one of prompt_a or prompt_s should be provided."
- #                 z_init_temp = _z_init[0] - _z_init[1]
- #             else:
- #                 raise NotImplementedError("Either prompt_a or prompt_s must be provided for 3-sample mode.")
- #             mean = z_init_temp.mean()
- #             std = z_init_temp.std()
- #             _z_init[2] = (z_init_temp - mean) / std
-
- #         elif num_of_interpolation == 4:
- #             z_init_temp = _z_init[0] + _z_init[1] - _z_init[2]
- #             mean = z_init_temp.mean()
- #             std = z_init_temp.std()
- #             _z_init[3] = (z_init_temp - mean) / std
-
- #         elif num_of_interpolation >= 5:
- #             tensor_a = _z_init[0]
- #             tensor_b = _z_init[-1]
- #             num_interpolations = num_of_interpolation - 2
- #             interpolations = [
- #                 tensor_a + (tensor_b - tensor_a) * (i / (num_interpolations + 1))
- #                 for i in range(1, num_interpolations + 1)
- #             ]
- #             _z_init = torch.stack([tensor_a] + interpolations + [tensor_b], dim=0)
-
- #         else:
- #             raise ValueError("Unsupported number of interpolations.")
-
- #         assert guidance_scale > 1, "Guidance scale must be greater than 1."
-
- #         has_null_indicator = hasattr(config.nnet.model_args, "cfg_indicator")
- #         ode_solver = ODEEulerFlowMatchingSolver(
- #             nnet,
- #             bdv_model_fn=None,
- #             step_size_type="step_in_dsigma",
- #             guidance_scale=guidance_scale,
- #         )
- #         _z, _ = ode_solver.sample(
- #             x_T=_z_init,
- #             batch_size=num_of_interpolation,
- #             sample_steps=num_inference_steps,
- #             unconditional_guidance_scale=guidance_scale,
- #             has_null_indicator=has_null_indicator,
- #         )
-
- #         if save_gpu_memory:
- #             image_unprocessed = batch_decode(_z, decode)
- #         else:
- #             image_unprocessed = decode(_z)
-
- #         samples = unpreprocess(image_unprocessed).contiguous()[0]
-
- #         # return samples, seed
- #         return seed
-
-
- # # examples = [
- # #     "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
- # #     "An astronaut riding a green horse",
- # #     "A delicious ceviche cheesecake slice",
- # # ]

- # examples = [
- #     ["A dog cooking dinner in the kitchen", "An orange cat wearing sunglasses on a ship"],
- # ]

- # css = """
- # #col-container {
- #     margin: 0 auto;
- #     max-width: 640px;
- # }
- # """
-
- # with gr.Blocks(css=css) as demo:
- #     with gr.Column(elem_id="col-container"):
- #         gr.Markdown(" # CrossFlow")
- #         gr.Markdown(" CrossFlow directly transforms text representations into images for text-to-image generation, enabling interpolation in the input text latent space.")
-
- #         with gr.Row():
- #             prompt1 = gr.Text(
- #                 label="Prompt_1",
- #                 show_label=False,
- #                 max_lines=1,
- #                 placeholder="Enter your prompt for the first image",
- #                 container=False,
- #             )
-
- #         with gr.Row():
- #             prompt2 = gr.Text(
- #                 label="Prompt_2",
- #                 show_label=False,
- #                 max_lines=1,
- #                 placeholder="Enter your prompt for the second image",
- #                 container=False,
- #             )
-
- #         with gr.Row():
- #             run_button = gr.Button("Run", scale=0, variant="primary")
-
- #         result = gr.Image(label="Result", show_label=False)
-
- #         with gr.Accordion("Advanced Settings", open=False):
- #             seed = gr.Slider(
- #                 label="Seed",
- #                 minimum=0,
- #                 maximum=MAX_SEED,
- #                 step=1,
- #                 value=0,
- #             )
-
- #             randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
-
- #             with gr.Row():
- #                 guidance_scale = gr.Slider(
- #                     label="Guidance scale",
- #                     minimum=0.0,
- #                     maximum=10.0,
- #                     step=0.1,
- #                     value=7.0, # Replace with defaults that work for your model
- #                 )
- #             with gr.Row():
- #                 num_inference_steps = gr.Slider(
- #                     label="Number of inference steps",
- #                     minimum=1,
- #                     maximum=50,
- #                     step=1,
- #                     value=50, # Replace with defaults that work for your model
- #                 )
- #             with gr.Row():
- #                 num_of_interpolation = gr.Slider(
- #                     label="Number of images for interpolation",
- #                     minimum=5,
- #                     maximum=50,
- #                     step=1,
- #                     value=10, # Replace with defaults that work for your model
- #                 )
-
- #         gr.Examples(examples=examples, inputs=[prompt1, prompt2])
- #     gr.on(
- #         triggers=[run_button.click, prompt1.submit, prompt2.submit],
- #         fn=infer,
- #         inputs=[
- #             prompt1,
- #             prompt2,
- #             seed,
- #             randomize_seed,
- #             guidance_scale,
- #             num_inference_steps,
- #             num_of_interpolation,
- #         ],
- #         # outputs=[result, seed],
- #         outputs=[seed],
- #     )
-
- # if __name__ == "__main__":
- #     demo.launch()

- import gradio as gr
- import numpy as np
- import random

- # import spaces #[uncomment to use ZeroGPU]
- from diffusers import DiffusionPipeline
- import torch

- device = "cuda" if torch.cuda.is_available() else "cpu"
- model_repo_id = "stabilityai/sdxl-turbo" # Replace to the model you would like to use

- if torch.cuda.is_available():
-     torch_dtype = torch.float16
- else:
-     torch_dtype = torch.float32

- pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
- pipe = pipe.to(device)
-
- MAX_SEED = np.iinfo(np.int32).max
- MAX_IMAGE_SIZE = 1024


- # @spaces.GPU #[uncomment to use ZeroGPU]
  def infer(
-     prompt,
-     negative_prompt,
      seed,
      randomize_seed,
-     width,
-     height,
      guidance_scale,
      num_inference_steps,
      progress=gr.Progress(track_tqdm=True),
  ):
      if randomize_seed:
          seed = random.randint(0, MAX_SEED)

-     generator = torch.Generator().manual_seed(seed)

-     image = pipe(
-         prompt=prompt,
-         negative_prompt=negative_prompt,
-         guidance_scale=guidance_scale,
-         num_inference_steps=num_inference_steps,
-         width=width,
-         height=height,
-         generator=generator,
-     ).images[0]

-     print('image.shape')
-     print(image.shape)

-     return image, seed


  examples = [
-     "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
-     "An astronaut riding a green horse",
-     "A delicious ceviche cheesecake slice",
  ]

  css = """
@@ -408,29 +255,33 @@ css = """

  with gr.Blocks(css=css) as demo:
      with gr.Column(elem_id="col-container"):
-         gr.Markdown(" # Text-to-Image Gradio Template")

          with gr.Row():
-             prompt = gr.Text(
-                 label="Prompt",
                  show_label=False,
                  max_lines=1,
-                 placeholder="Enter your prompt",
                  container=False,
              )

              run_button = gr.Button("Run", scale=0, variant="primary")

          result = gr.Image(label="Result", show_label=False)

          with gr.Accordion("Advanced Settings", open=False):
-             negative_prompt = gr.Text(
-                 label="Negative prompt",
-                 max_lines=1,
-                 placeholder="Enter a negative prompt",
-                 visible=False,
-             )
-
              seed = gr.Slider(
                  label="Seed",
                  minimum=0,
@@ -441,56 +292,47 @@ with gr.Blocks(css=css) as demo:

              randomize_seed = gr.Checkbox(label="Randomize seed", value=True)

-             with gr.Row():
-                 width = gr.Slider(
-                     label="Width",
-                     minimum=256,
-                     maximum=MAX_IMAGE_SIZE,
-                     step=32,
-                     value=1024, # Replace with defaults that work for your model
-                 )
-
-                 height = gr.Slider(
-                     label="Height",
-                     minimum=256,
-                     maximum=MAX_IMAGE_SIZE,
-                     step=32,
-                     value=1024, # Replace with defaults that work for your model
-                 )
-
              with gr.Row():
                  guidance_scale = gr.Slider(
                      label="Guidance scale",
                      minimum=0.0,
                      maximum=10.0,
                      step=0.1,
-                     value=0.0, # Replace with defaults that work for your model
                  )
-
                  num_inference_steps = gr.Slider(
                      label="Number of inference steps",
                      minimum=1,
                      maximum=50,
                      step=1,
-                     value=2, # Replace with defaults that work for your model
                  )

-         gr.Examples(examples=examples, inputs=[prompt])
      gr.on(
-         triggers=[run_button.click, prompt.submit],
          fn=infer,
          inputs=[
-             prompt,
-             negative_prompt,
              seed,
              randomize_seed,
-             width,
-             height,
              guidance_scale,
              num_inference_steps,
          ],
-         outputs=[result, seed],
      )

  if __name__ == "__main__":
-     demo.launch()
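The template code removed above generated images through diffusers' DiffusionPipeline. As a reference point, a minimal self-contained sketch of that call pattern looks roughly as follows (illustrative only, not part of app.py: the seed, prompt, and output filename are arbitrary choices here, while the checkpoint id and the guidance/step/size defaults are taken from the removed code):

import torch
from diffusers import DiffusionPipeline

# Same checkpoint and dtype selection as the removed template code.
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

pipe = DiffusionPipeline.from_pretrained("stabilityai/sdxl-turbo", torch_dtype=torch_dtype).to(device)
generator = torch.Generator().manual_seed(42)  # arbitrary seed for reproducibility

image = pipe(
    prompt="An astronaut riding a green horse",  # one of the template's example prompts
    guidance_scale=0.0,        # template default for sdxl-turbo
    num_inference_steps=2,     # template default for sdxl-turbo
    width=1024,
    height=1024,
    generator=generator,
).images[0]
image.save("sample.png")  # hypothetical output path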
 
+ import gradio as gr

+ from absl import flags
+ from absl import app
+ from ml_collections import config_flags
+ import os

+ import spaces #[uncomment to use ZeroGPU]
+ import torch


+ import os
+ import random

+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+ from torchvision.utils import save_image
+ from huggingface_hub import hf_hub_download

+ from absl import logging
+ import ml_collections

+ from diffusion.flow_matching import ODEEulerFlowMatchingSolver
+ import utils
+ import libs.autoencoder
+ from libs.clip import FrozenCLIPEmbedder
+ from configs import t2i_512px_clip_dimr


+ def unpreprocess(x: torch.Tensor) -> torch.Tensor:
+     x = 0.5 * (x + 1.0)
+     x.clamp_(0.0, 1.0)
+     return x

+ def cosine_similarity_torch(latent1: torch.Tensor, latent2: torch.Tensor) -> torch.Tensor:
+     latent1_flat = latent1.view(-1)
+     latent2_flat = latent2.view(-1)
+     cosine_similarity = F.cosine_similarity(
+         latent1_flat.unsqueeze(0), latent2_flat.unsqueeze(0), dim=1
+     )
+     return cosine_similarity
+
+ def kl_divergence(latent1: torch.Tensor, latent2: torch.Tensor) -> torch.Tensor:
+     latent1_prob = F.softmax(latent1, dim=-1)
+     latent2_prob = F.softmax(latent2, dim=-1)
+     latent1_log_prob = torch.log(latent1_prob)
+     kl_div = F.kl_div(latent1_log_prob, latent2_prob, reduction="batchmean")
+     return kl_div
+
+ def batch_decode(_z: torch.Tensor, decode, batch_size: int = 10) -> torch.Tensor:
+     num_samples = _z.size(0)
+     decoded_batches = []
+
+     for i in range(0, num_samples, batch_size):
+         batch = _z[i : i + batch_size]
+         decoded_batch = decode(batch)
+         decoded_batches.append(decoded_batch)
+
+     return torch.cat(decoded_batches, dim=0)
+
+ def get_caption(llm: str, text_model, prompt_dict: dict, batch_size: int):
+     if batch_size == 3:
+         # Only addition or only subtraction mode.
+         assert len(prompt_dict) == 2, "Expected 2 prompts for batch_size 3."
+         batch_prompts = list(prompt_dict.values()) + [" "]
+     elif batch_size == 4:
+         # Addition and subtraction mode.
+         assert len(prompt_dict) == 3, "Expected 3 prompts for batch_size 4."
+         batch_prompts = list(prompt_dict.values()) + [" "]
+     elif batch_size >= 5:
+         # Linear interpolation mode.
+         assert len(prompt_dict) == 2, "Expected 2 prompts for linear interpolation."
+         batch_prompts = [prompt_dict["prompt_1"]] + [" "] * (batch_size - 2) + [prompt_dict["prompt_2"]]
+     else:
+         raise ValueError(f"Unsupported batch_size: {batch_size}")
+
+     if llm == "clip":
+         latent, latent_and_others = text_model.encode(batch_prompts)
+         context = latent_and_others["token_embedding"].detach()
+     elif llm == "t5":
+         latent, latent_and_others = text_model.get_text_embeddings(batch_prompts)
+         context = (latent_and_others["token_embedding"] * 10.0).detach()
+     else:
+         raise NotImplementedError(f"Language model {llm} not supported.")
+
+     token_mask = latent_and_others["token_mask"].detach()
+     tokens = latent_and_others["tokens"].detach()
+     captions = batch_prompts
+
+     return context, token_mask, tokens, captions
+
+ # Load configuration and initialize models.
+ config_dict = t2i_512px_clip_dimr.get_config()
+ config = ml_collections.ConfigDict(config_dict)
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ logging.info(f"Using device: {device}")
+
+ # Freeze configuration.
+ config = ml_collections.FrozenConfigDict(config)
+
+ torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
+ MAX_SEED = np.iinfo(np.int32).max
+ MAX_IMAGE_SIZE = 1024 # Currently not used.

+ # Load the main diffusion model.
+ repo_id = "QHL067/CrossFlow"
+ filename = "pretrained_models/t2i_512px_clip_dimr.pth"
+ checkpoint_path = hf_hub_download(repo_id=repo_id, filename=filename)
+ nnet = utils.get_nnet(**config.nnet)
+ nnet = nnet.to(device)
+ state_dict = torch.load(checkpoint_path, map_location=device)
+ nnet.load_state_dict(state_dict)
+ nnet.eval()

+ # Initialize text model.
+ llm = "clip"
+ clip = FrozenCLIPEmbedder()
+ clip.eval()
+ clip.to(device)

+ # Load autoencoder.
+ autoencoder = libs.autoencoder.get_model(**config.autoencoder)
+ autoencoder.to(device)


+ @torch.cuda.amp.autocast()
+ def encode(_batch: torch.Tensor) -> torch.Tensor:
+     """Encode a batch of images using the autoencoder."""
+     return autoencoder.encode(_batch)


+ @torch.cuda.amp.autocast()
+ def decode(_batch: torch.Tensor) -> torch.Tensor:
+     """Decode a batch of latent vectors using the autoencoder."""
+     return autoencoder.decode(_batch)


+ @spaces.GPU #[uncomment to use ZeroGPU]
  def infer(
+     prompt1,
+     prompt2,
      seed,
      randomize_seed,
      guidance_scale,
      num_inference_steps,
+     num_of_interpolation,
+     save_gpu_memory=True,
      progress=gr.Progress(track_tqdm=True),
  ):
      if randomize_seed:
          seed = random.randint(0, MAX_SEED)

+     torch.manual_seed(seed)
+     if device.type == "cuda":
+         torch.cuda.manual_seed_all(seed)

+     # Only support interpolation in this implementation.
+     prompt_dict = {"prompt_1": prompt1, "prompt_2": prompt2}
+     for key, value in prompt_dict.items():
+         assert value is not None, f"{key} must not be None."
+     assert num_of_interpolation >= 5, "For linear interpolation, please sample at least five images."

+     # Get text embeddings and tokens.
+     _context, _token_mask, _token, _caption = get_caption(
+         llm, clip, prompt_dict=prompt_dict, batch_size=num_of_interpolation
+     )

+     with torch.no_grad():
+         _z_gaussian = torch.randn(num_of_interpolation, *config.z_shape, device=device)
+         _z_x0, _mu, _log_var = nnet(
+             _context, text_encoder=True, shape=_z_gaussian.shape, mask=_token_mask
+         )
+         _z_init = _z_x0.reshape(_z_gaussian.shape)
+
+         # Prepare the initial latent representations based on the number of interpolations.
+         if num_of_interpolation == 3:
+             # Addition or subtraction mode.
+             if config.prompt_a is not None:
+                 assert config.prompt_s is None, "Only one of prompt_a or prompt_s should be provided."
+                 z_init_temp = _z_init[0] + _z_init[1]
+             elif config.prompt_s is not None:
+                 assert config.prompt_a is None, "Only one of prompt_a or prompt_s should be provided."
+                 z_init_temp = _z_init[0] - _z_init[1]
+             else:
+                 raise NotImplementedError("Either prompt_a or prompt_s must be provided for 3-sample mode.")
+             mean = z_init_temp.mean()
+             std = z_init_temp.std()
+             _z_init[2] = (z_init_temp - mean) / std
+
+         elif num_of_interpolation == 4:
+             z_init_temp = _z_init[0] + _z_init[1] - _z_init[2]
+             mean = z_init_temp.mean()
+             std = z_init_temp.std()
+             _z_init[3] = (z_init_temp - mean) / std
+
+         elif num_of_interpolation >= 5:
+             tensor_a = _z_init[0]
+             tensor_b = _z_init[-1]
+             num_interpolations = num_of_interpolation - 2
+             interpolations = [
+                 tensor_a + (tensor_b - tensor_a) * (i / (num_interpolations + 1))
+                 for i in range(1, num_interpolations + 1)
+             ]
+             _z_init = torch.stack([tensor_a] + interpolations + [tensor_b], dim=0)
+
+         else:
+             raise ValueError("Unsupported number of interpolations.")
+
+         assert guidance_scale > 1, "Guidance scale must be greater than 1."
+
+         has_null_indicator = hasattr(config.nnet.model_args, "cfg_indicator")
+         ode_solver = ODEEulerFlowMatchingSolver(
+             nnet,
+             bdv_model_fn=None,
+             step_size_type="step_in_dsigma",
+             guidance_scale=guidance_scale,
+         )
+         _z, _ = ode_solver.sample(
+             x_T=_z_init,
+             batch_size=num_of_interpolation,
+             sample_steps=num_inference_steps,
+             unconditional_guidance_scale=guidance_scale,
+             has_null_indicator=has_null_indicator,
+         )
+
+         if save_gpu_memory:
+             image_unprocessed = batch_decode(_z, decode)
+         else:
+             image_unprocessed = decode(_z)
+
+         samples = unpreprocess(image_unprocessed).contiguous()[0]
+
+         # return samples, seed
+         return seed


+ # examples = [
+ #     "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
+ #     "An astronaut riding a green horse",
+ #     "A delicious ceviche cheesecake slice",
+ # ]
+
  examples = [
+     ["A dog cooking dinner in the kitchen", "An orange cat wearing sunglasses on a ship"],
  ]

  css = """

  with gr.Blocks(css=css) as demo:
      with gr.Column(elem_id="col-container"):
+         gr.Markdown(" # CrossFlow")
+         gr.Markdown(" CrossFlow directly transforms text representations into images for text-to-image generation, enabling interpolation in the input text latent space.")

          with gr.Row():
+             prompt1 = gr.Text(
+                 label="Prompt_1",
                  show_label=False,
                  max_lines=1,
+                 placeholder="Enter your prompt for the first image",
+                 container=False,
+             )
+
+         with gr.Row():
+             prompt2 = gr.Text(
+                 label="Prompt_2",
+                 show_label=False,
+                 max_lines=1,
+                 placeholder="Enter your prompt for the second image",
                  container=False,
              )

+         with gr.Row():
              run_button = gr.Button("Run", scale=0, variant="primary")

          result = gr.Image(label="Result", show_label=False)

          with gr.Accordion("Advanced Settings", open=False):
              seed = gr.Slider(
                  label="Seed",
                  minimum=0,

              randomize_seed = gr.Checkbox(label="Randomize seed", value=True)

              with gr.Row():
                  guidance_scale = gr.Slider(
                      label="Guidance scale",
                      minimum=0.0,
                      maximum=10.0,
                      step=0.1,
+                     value=7.0, # Replace with defaults that work for your model
                  )
+             with gr.Row():
                  num_inference_steps = gr.Slider(
                      label="Number of inference steps",
                      minimum=1,
                      maximum=50,
                      step=1,
+                     value=50, # Replace with defaults that work for your model
+                 )
+             with gr.Row():
+                 num_of_interpolation = gr.Slider(
+                     label="Number of images for interpolation",
+                     minimum=5,
+                     maximum=50,
+                     step=1,
+                     value=10, # Replace with defaults that work for your model
                  )

+         gr.Examples(examples=examples, inputs=[prompt1, prompt2])
      gr.on(
+         triggers=[run_button.click, prompt1.submit, prompt2.submit],
          fn=infer,
          inputs=[
+             prompt1,
+             prompt2,
              seed,
              randomize_seed,
              guidance_scale,
              num_inference_steps,
+             num_of_interpolation,
          ],
+         # outputs=[result, seed],
+         outputs=[seed],
      )

  if __name__ == "__main__":
+     demo.launch()
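For reference, the num_of_interpolation >= 5 branch of the new infer builds its batch of initial latents by linear interpolation between the first and last text-derived latents, keeping both endpoints. A standalone sketch of just that step (illustrative only, not part of app.py; the helper name and the latent shape are assumptions):

import torch

def interpolate_latents(tensor_a: torch.Tensor, tensor_b: torch.Tensor, num_total: int) -> torch.Tensor:
    # Mirrors the >= 5 branch above: num_total - 2 evenly spaced latents
    # are placed strictly between the two endpoints, which are kept as-is.
    assert num_total >= 5, "For linear interpolation, sample at least five images."
    num_inner = num_total - 2
    inner = [
        tensor_a + (tensor_b - tensor_a) * (i / (num_inner + 1))
        for i in range(1, num_inner + 1)
    ]
    return torch.stack([tensor_a] + inner + [tensor_b], dim=0)

# Example with an arbitrary latent shape: 10 latents between two endpoints.
a, b = torch.randn(4, 64, 64), torch.randn(4, 64, 64)
z_init = interpolate_latents(a, b, num_total=10)
print(z_init.shape)  # torch.Size([10, 4, 64, 64])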