helblazer811 committed
Commit bde9560 · 1 Parent(s): 54f7994

Added a second UI for uploading images. However, there are currently

Files changed (1):
  app.py  +174 −24
app.py CHANGED
@@ -21,7 +21,7 @@ def update_default_concepts(prompt):
 
     return gr.update(value=default_concepts.get(prompt, []))
 
-pipeline = ConceptAttentionFluxPipeline(model_name="flux-schnell") # , offload_model=True) # , device="cuda:2", offload_model=True)
+pipeline = ConceptAttentionFluxPipeline(model_name="flux-schnell", offload_model=True) # , device="cuda:0") # , offload_model=True)
 
 def convert_pil_to_bytes(img):
     img = img.resize((IMG_SIZE, IMG_SIZE), resample=Image.NEAREST)
@@ -32,7 +32,53 @@ def convert_pil_to_bytes(img):
     return img_str
 
 @spaces.GPU(duration=60)
-def process_inputs(prompt, concepts, seed, layer_start_index, timestep_start_index):
+def encode_image(image, prompt, concepts, seed, layer_start_index, noise_timestep, num_samples):
+    try:
+        if not prompt:
+            prompt = ""
+
+        prompt = prompt.strip()
+
+        if len(concepts) == 0:
+            raise gr.Error("Please enter at least 1 concept", duration=10)
+
+        if len(concepts) > 9:
+            raise gr.Error("Please enter at most 9 concepts", duration=10)
+
+        pipeline_output = pipeline.encode_image(
+            image=image,
+            prompt=prompt,
+            concepts=concepts,
+            width=1024,
+            height=1024,
+            seed=seed,
+            num_samples=num_samples,
+            noise_timestep=noise_timestep,
+            num_steps=4,
+            layer_indices=list(range(layer_start_index, 19)),
+            softmax=True if len(concepts) > 1 else False
+        )
+
+        output_image = pipeline_output.image
+
+        output_space_heatmaps = pipeline_output.concept_heatmaps
+        output_space_heatmaps = [heatmap.resize((IMG_SIZE, IMG_SIZE), resample=Image.NEAREST) for heatmap in output_space_heatmaps]
+        output_space_maps_and_labels = [(output_space_heatmaps[concept_index], concepts[concept_index]) for concept_index in range(len(concepts))]
+
+        cross_attention_heatmaps = pipeline_output.cross_attention_maps
+        cross_attention_heatmaps = [heatmap.resize((IMG_SIZE, IMG_SIZE), resample=Image.NEAREST) for heatmap in cross_attention_heatmaps]
+        cross_attention_maps_and_labels = [(cross_attention_heatmaps[concept_index], concepts[concept_index]) for concept_index in range(len(concepts))]
+
+        return output_image, \
+            gr.update(value=output_space_maps_and_labels, columns=len(output_space_maps_and_labels)), \
+            gr.update(value=cross_attention_maps_and_labels, columns=len(cross_attention_maps_and_labels))
+
+    except gr.Error as e:
+        return None, gr.update(value=[], columns=1), gr.update(value=[], columns=1)
+
+
+@spaces.GPU(duration=60)
+def generate_image(prompt, concepts, seed, layer_start_index, timestep_start_index):
     try:
         if not prompt:
             raise gr.Error("Please enter a prompt", duration=10)
@@ -212,18 +258,19 @@ with gr.Blocks(
 
     # with gr.Column(elem_classes="container"):
 
+    with gr.Row(elem_classes="container", scale=8):
+
+        with gr.Column(elem_classes="application-content", scale=10):
 
-    with gr.Row(elem_classes="container", scale=8):
-
-        with gr.Column(elem_classes="application-content", scale=10):
+            with gr.Row(scale=3, elem_classes="header"):
+                gr.HTML("""
+                    <h1 id='title'> ConceptAttention </h1>
+                    <h1 id='subtitle'> Visualize Any Concepts in Your Generated Images </h1>
+                    <h1 id='abstract'> Interpret diffusion models with precise, high-quality heatmaps. </h1>
+                    <h1 id='links'> <a href='https://arxiv.org/abs/2502.04320'> Paper </a> | <a href='https://github.com/helblazer811/ConceptAttention'> Code </a> </h1>
+                """)
 
-            with gr.Row(scale=3, elem_classes="header"):
-                gr.HTML("""
-                    <h1 id='title'> ConceptAttention </h1>
-                    <h1 id='subtitle'> Visualize Any Concepts in Your Generated Images </h1>
-                    <h1 id='abstract'> Interpret diffusion models with precise, high-quality heatmaps. </h1>
-                    <h1 id='links'> <a href='https://arxiv.org/abs/2502.04320'> Paper </a> | <a href='https://github.com/helblazer811/ConceptAttention'> Code </a> </h1>
-                """)
+            with gr.Tab(label="Generate Image"):
 
                 with gr.Row(elem_classes="input-row", scale=2):
                     with gr.Column(scale=4, elem_classes="input-column", min_width=250):
@@ -231,6 +278,7 @@ with gr.Blocks(
                             "Write a Prompt",
                             elem_classes="input-column-label"
                         )
+
                         prompt = gr.Dropdown(
                             ["A dog by a tree", "A man on the beach", "A hot air balloon"],
                             container=False,
@@ -303,7 +351,7 @@ with gr.Blocks(
                     timestep_start_index = gr.Slider(minimum=0, maximum=4, step=1, label="Timestep Start Index", value=2)
 
                 submit_btn.click(
-                    fn=process_inputs,
+                    fn=generate_image,
                     inputs=[prompt, concepts, seed, layer_start_index, timestep_start_index],
                     outputs=[generated_image, concept_attention_gallery, cross_attention_gallery]
                 )
@@ -312,24 +360,126 @@ with gr.Blocks(
 
                 # Automatically process the first example on launch
                 demo.load(
-                    process_inputs,
+                    generate_image,
                     inputs=[prompt, concepts, seed, layer_start_index, timestep_start_index],
                     outputs=[generated_image, concept_attention_gallery, cross_attention_gallery]
                 )
 
-        with gr.Column(scale=2, min_width=200, elem_classes="svg-column"):
+            with gr.Tab(label="Explain a Real Image"):
+
+                with gr.Row(elem_classes="input-row", scale=2):
+                    with gr.Column(scale=4, elem_classes="input-column", min_width=250):
+                        gr.HTML(
+                            "Write a Prompt (Optional)",
+                            elem_classes="input-column-label"
+                        )
+                        # prompt = gr.Dropdown(
+                        #     ["A dog by a tree", "A man on the beach", "A hot air balloon"],
+                        #     container=False,
+                        #     allow_custom_value=True,
+                        #     elem_classes="input"
+                        # )
+
+                        prompt = gr.Textbox(
+                            placeholder="Write a prompt (Optional)",
+                            container=False,
+                            elem_classes="input"
+                        )
+
+                    with gr.Column(scale=7, elem_classes="input-column"):
+                        gr.HTML(
+                            "Select or Write Concepts",
+                            elem_classes="input-column-label"
+                        )
+                        concepts = gr.Dropdown(
+                            ["dog", "grass", "tree", "dragon", "sky", "rock", "cloud", "balloon", "water", "background"],
+                            value=["dog", "grass", "tree", "background"],
+                            multiselect=True,
+                            label="Concepts",
+                            container=False,
+                            allow_custom_value=True,
+                            # scale=4,
+                            elem_classes="input",
+                            max_choices=5
+                        )
+
+                    with gr.Column(scale=1, min_width=100, elem_classes="input-column run-button-column"):
+                        gr.HTML(
+                            "&#8203;",
+                            elem_classes="input-column-label"
+                        )
+                        submit_btn = gr.Button(
+                            "Run",
+                            elem_classes="input"
+                        )
+
+                with gr.Row(elem_classes="gallery-container", scale=8):
+
+                    with gr.Column(scale=1, min_width=250):
+                        input_image = gr.Image(
+                            elem_classes="generated-image",
+                            show_label=False,
+                            interactive=True
+                        )
+
+                    with gr.Column(scale=4):
+                        concept_attention_gallery = gr.Gallery(
+                            label="Concept Attention (Ours)",
+                            show_label=True,
+                            # columns=3,
+                            rows=1,
+                            object_fit="contain",
+                            height="200px",
+                            elem_classes="gallery",
+                            elem_id="concept-attention-gallery",
+                            # scale=4
+                        )
+
+                        cross_attention_gallery = gr.Gallery(
+                            label="Cross Attention",
+                            show_label=True,
+                            # columns=3,
+                            rows=1,
+                            object_fit="contain",
+                            height="200px",
+                            elem_classes="gallery",
+                            # scale=4
+                        )
+
+                with gr.Accordion("Advanced Settings", open=False):
+                    seed = gr.Slider(minimum=0, maximum=10000, step=1, label="Seed", value=42)
+                    num_samples = gr.Slider(minimum=1, maximum=10, step=1, label="Number of Samples", value=4)
+                    layer_start_index = gr.Slider(minimum=0, maximum=18, step=1, label="Layer Start Index", value=10)
+                    noise_timestep = gr.Slider(minimum=0, maximum=4, step=1, label="Noise Timestep", value=2)
+
+                submit_btn.click(
+                    fn=encode_image,
+                    inputs=[input_image, prompt, concepts, seed, layer_start_index, noise_timestep, num_samples],
+                    outputs=[input_image, concept_attention_gallery, cross_attention_gallery]
+                )
 
-            with gr.Row(scale=8):
-                gr.HTML("<div></div>")
+                # # Automatically process the first example on launch
+                # demo.load(
+                #     encode_image,
+                #     inputs=[input_image, prompt, concepts, seed, layer_start_index, noise_timestep, num_samples],
+                #     outputs=[input_image, concept_attention_gallery, cross_attention_gallery]
+                # )
 
-            with gr.Row(scale=4, elem_classes="svg-container"):
-                concept_attention_callout_svg = gr.HTML(
-                    "<img src='/gradio_api/file=ConceptAttentionCallout.svg' id='concept-attention-callout-svg'/>",
-                    # container=False,
-                )
+        with gr.Column(scale=2, min_width=200, elem_classes="svg-column"):
+
+            with gr.Row(scale=8):
+                gr.HTML("<div></div>")
+
+            with gr.Row(scale=4, elem_classes="svg-container"):
+                concept_attention_callout_svg = gr.HTML(
+                    "<img src='/gradio_api/file=ConceptAttentionCallout.svg' id='concept-attention-callout-svg'/>",
+                    # container=False,
+                )
+
+            with gr.Row(scale=4):
+                gr.HTML("<div></div>")
 
-            with gr.Row(scale=4):
-                gr.HTML("<div></div>")
+
 
 if __name__ == "__main__":
     if os.path.exists("/data-nvme/zerogpu-offload"):
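
Note: the new "Explain a Real Image" tab drives the same pipeline as the generation tab, but through `pipeline.encode_image` on a user-supplied image. As a reference for the wiring above, a minimal sketch of calling that path outside Gradio, assuming the `from concept_attention import ConceptAttentionFluxPipeline` import path and only the `encode_image` keyword arguments and output fields that appear in this diff (`image`, `concept_heatmaps`, `cross_attention_maps`); the input and output file names are hypothetical:

from PIL import Image
from concept_attention import ConceptAttentionFluxPipeline  # import path assumed

pipeline = ConceptAttentionFluxPipeline(model_name="flux-schnell", offload_model=True)

input_image = Image.open("dog.png")  # hypothetical input image
concepts = ["dog", "grass", "tree", "background"]

# Values mirror the tab's "Advanced Settings" defaults:
# seed=42, num_samples=4, layer_start_index=10, noise_timestep=2.
output = pipeline.encode_image(
    image=input_image,
    prompt="A dog by a tree",  # optional in the UI; an empty string is accepted
    concepts=concepts,
    width=1024,
    height=1024,
    seed=42,
    num_samples=4,
    noise_timestep=2,
    num_steps=4,
    layer_indices=list(range(10, 19)),  # layers 10-18, as in the app
    softmax=len(concepts) > 1,
)

# One heatmap per concept, in the same order as the concepts list.
output.image.save("encoded.png")
for concept, heatmap in zip(concepts, output.concept_heatmaps):
    heatmap.save(f"concept_heatmap_{concept}.png")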