Spaces: Running on Zero
Commit · bde9560
Parent(s): 54f7994
Added a second UI for uploading images. However, there are currently
app.py CHANGED
@@ -21,7 +21,7 @@ def update_default_concepts(prompt):

     return gr.update(value=default_concepts.get(prompt, []))

-pipeline = ConceptAttentionFluxPipeline(model_name="flux-schnell"
+pipeline = ConceptAttentionFluxPipeline(model_name="flux-schnell", offload_model=True) # , device="cuda:0") # , offload_model=True)

 def convert_pil_to_bytes(img):
     img = img.resize((IMG_SIZE, IMG_SIZE), resample=Image.NEAREST)
@@ -32,7 +32,53 @@ def convert_pil_to_bytes(img):
     return img_str

 @spaces.GPU(duration=60)
-def
+def encode_image(image, prompt, concepts, seed, layer_start_index, noise_timestep, num_samples):
+    try:
+        if not prompt:
+            prompt = ""
+
+        prompt = prompt.strip()
+
+        if len(concepts) == 0:
+            raise gr.Error("Please enter at least 1 concept", duration=10)
+
+        if len(concepts) > 9:
+            raise gr.Error("Please enter at most 9 concepts", duration=10)
+
+        pipeline_output = pipeline.encode_image(
+            image=image,
+            prompt=prompt,
+            concepts=concepts,
+            width=1024,
+            height=1024,
+            seed=seed,
+            num_samples=num_samples,
+            noise_timestep=noise_timestep,
+            num_steps=4,
+            layer_indices=list(range(layer_start_index, 19)),
+            softmax=True if len(concepts) > 1 else False
+        )
+
+        output_image = pipeline_output.image
+
+        output_space_heatmaps = pipeline_output.concept_heatmaps
+        output_space_heatmaps = [heatmap.resize((IMG_SIZE, IMG_SIZE), resample=Image.NEAREST) for heatmap in output_space_heatmaps]
+        output_space_maps_and_labels = [(output_space_heatmaps[concept_index], concepts[concept_index]) for concept_index in range(len(concepts))]
+
+        cross_attention_heatmaps = pipeline_output.cross_attention_maps
+        cross_attention_heatmaps = [heatmap.resize((IMG_SIZE, IMG_SIZE), resample=Image.NEAREST) for heatmap in cross_attention_heatmaps]
+        cross_attention_maps_and_labels = [(cross_attention_heatmaps[concept_index], concepts[concept_index]) for concept_index in range(len(concepts))]
+
+        return output_image, \
+            gr.update(value=output_space_maps_and_labels, columns=len(output_space_maps_and_labels)), \
+            gr.update(value=cross_attention_maps_and_labels, columns=len(cross_attention_maps_and_labels))
+
+    except gr.Error as e:
+        return None, gr.update(value=[], columns=1), gr.update(value=[], columns=1)
+
+
+@spaces.GPU(duration=60)
+def generate_image(prompt, concepts, seed, layer_start_index, timestep_start_index):
     try:
         if not prompt:
             raise gr.Error("Please enter a prompt", duration=10)
@@ -212,18 +258,19 @@ with gr.Blocks(

     # with gr.Column(elem_classes="container"):

+    with gr.Row(elem_classes="container", scale=8):
+
+        with gr.Column(elem_classes="application-content", scale=10):

-
-
-
+            with gr.Row(scale=3, elem_classes="header"):
+                gr.HTML("""
+                    <h1 id='title'> ConceptAttention </h1>
+                    <h1 id='subtitle'> Visualize Any Concepts in Your Generated Images </h1>
+                    <h1 id='abstract'> Interpret diffusion models with precise, high-quality heatmaps. </h1>
+                    <h1 id='links'> <a href='https://arxiv.org/abs/2502.04320'> Paper </a> | <a href='https://github.com/helblazer811/ConceptAttention'> Code </a> </h1>
+                """)

-
-    gr.HTML("""
-        <h1 id='title'> ConceptAttention </h1>
-        <h1 id='subtitle'> Visualize Any Concepts in Your Generated Images </h1>
-        <h1 id='abstract'> Interpret diffusion models with precise, high-quality heatmaps. </h1>
-        <h1 id='links'> <a href='https://arxiv.org/abs/2502.04320'> Paper </a> | <a href='https://github.com/helblazer811/ConceptAttention'> Code </a> </h1>
-    """)
+            with gr.Tab(label="Generate Image"):

                 with gr.Row(elem_classes="input-row", scale=2):
                     with gr.Column(scale=4, elem_classes="input-column", min_width=250):
@@ -231,6 +278,7 @@ with gr.Blocks(
                             "Write a Prompt",
                             elem_classes="input-column-label"
                         )
+
                         prompt = gr.Dropdown(
                             ["A dog by a tree", "A man on the beach", "A hot air balloon"],
                             container=False,
@@ -303,7 +351,7 @@ with gr.Blocks(
                     timestep_start_index = gr.Slider(minimum=0, maximum=4, step=1, label="Timestep Start Index", value=2)

                 submit_btn.click(
-                    fn=
+                    fn=generate_image,
                     inputs=[prompt, concepts, seed, layer_start_index, timestep_start_index],
                     outputs=[generated_image, concept_attention_gallery, cross_attention_gallery]
                 )
@@ -312,24 +360,126 @@ with gr.Blocks(

                 # Automatically process the first example on launch
                 demo.load(
-
+                    generate_image,
                     inputs=[prompt, concepts, seed, layer_start_index, timestep_start_index],
                     outputs=[generated_image, concept_attention_gallery, cross_attention_gallery]
                 )

-        with gr.
+            with gr.Tab(label="Explain a Real Image"):
+
+                with gr.Row(elem_classes="input-row", scale=2):
+                    with gr.Column(scale=4, elem_classes="input-column", min_width=250):
+                        gr.HTML(
+                            "Write a Prompt (Optional)",
+                            elem_classes="input-column-label"
+                        )
+                        # prompt = gr.Dropdown(
+                        #     ["A dog by a tree", "A man on the beach", "A hot air balloon"],
+                        #     container=False,
+                        #     allow_custom_value=True,
+                        #     elem_classes="input"
+                        # )
+
+                        prompt = gr.Textbox(
+                            placeholder="Write a prompt (Optional)",
+                            container=False,
+                            elem_classes="input"
+                        )
+
+                    with gr.Column(scale=7, elem_classes="input-column"):
+                        gr.HTML(
+                            "Select or Write Concepts",
+                            elem_classes="input-column-label"
+                        )
+                        concepts = gr.Dropdown(
+                            ["dog", "grass", "tree", "dragon", "sky", "rock", "cloud", "balloon", "water", "background"],
+                            value=["dog", "grass", "tree", "background"],
+                            multiselect=True,
+                            label="Concepts",
+                            container=False,
+                            allow_custom_value=True,
+                            # scale=4,
+                            elem_classes="input",
+                            max_choices=5
+                        )
+
+                    with gr.Column(scale=1, min_width=100, elem_classes="input-column run-button-column"):
+                        gr.HTML(
+                            "​",
+                            elem_classes="input-column-label"
+                        )
+                        submit_btn = gr.Button(
+                            "Run",
+                            elem_classes="input"
+                        )
+
+                with gr.Row(elem_classes="gallery-container", scale=8):
+
+                    with gr.Column(scale=1, min_width=250):
+                        input_image = gr.Image(
+                            elem_classes="generated-image",
+                            show_label=False,
+                            interactive=True
+                        )
+
+                    with gr.Column(scale=4):
+                        concept_attention_gallery = gr.Gallery(
+                            label="Concept Attention (Ours)",
+                            show_label=True,
+                            # columns=3,
+                            rows=1,
+                            object_fit="contain",
+                            height="200px",
+                            elem_classes="gallery",
+                            elem_id="concept-attention-gallery",
+                            # scale=4
+                        )
+
+                        cross_attention_gallery = gr.Gallery(
+                            label="Cross Attention",
+                            show_label=True,
+                            # columns=3,
+                            rows=1,
+                            object_fit="contain",
+                            height="200px",
+                            elem_classes="gallery",
+                            # scale=4
+                        )
+
+                with gr.Accordion("Advanced Settings", open=False):
+                    seed = gr.Slider(minimum=0, maximum=10000, step=1, label="Seed", value=42)
+                    num_samples = gr.Slider(minimum=1, maximum=10, step=1, label="Number of Samples", value=4)
+                    layer_start_index = gr.Slider(minimum=0, maximum=18, step=1, label="Layer Start Index", value=10)
+                    noise_timestep = gr.Slider(minimum=0, maximum=4, step=1, label="Noise Timestep", value=2)
+
+                submit_btn.click(
+                    fn=encode_image,
+                    inputs=[input_image, prompt, concepts, seed, layer_start_index, noise_timestep, num_samples],
+                    outputs=[input_image, concept_attention_gallery, cross_attention_gallery]
+                )

-
-
+                # # Automatically process the first example on launch
+                # demo.load(
+                #     encode_image,
+                #     inputs=[input_image, prompt, concepts, seed, layer_start_index, noise_timestep, num_samples],
+                #     outputs=[input_image, concept_attention_gallery, cross_attention_gallery]
+                # )

-
-
-
-
-
+        with gr.Column(scale=2, min_width=200, elem_classes="svg-column"):
+
+            with gr.Row(scale=8):
+                gr.HTML("<div></div>")
+
+            with gr.Row(scale=4, elem_classes="svg-container"):
+                concept_attention_callout_svg = gr.HTML(
+                    "<img src='/gradio_api/file=ConceptAttentionCallout.svg' id='concept-attention-callout-svg'/>",
+                    # container=False,
+                )
+
+            with gr.Row(scale=4):
+                gr.HTML("<div></div>")

-
-            gr.HTML("<div></div>")
+

 if __name__ == "__main__":
     if os.path.exists("/data-nvme/zerogpu-offload"):
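
For reference, a minimal sketch of driving the new upload path outside the Gradio UI, calling pipeline.encode_image with the same arguments the "Explain a Real Image" tab wires up. The concept_attention import path, the input file name, and the example prompt are illustrative assumptions, not part of this commit; the numeric arguments mirror the slider defaults above.

from PIL import Image
from concept_attention import ConceptAttentionFluxPipeline  # assumed import path

# Same construction as the updated module-level pipeline in this commit.
pipeline = ConceptAttentionFluxPipeline(model_name="flux-schnell", offload_model=True)

image = Image.open("dog.png").convert("RGB")  # hypothetical input image
concepts = ["dog", "grass", "tree", "background"]

pipeline_output = pipeline.encode_image(
    image=image,
    prompt="A dog by a tree",  # the new UI treats the prompt as optional
    concepts=concepts,
    width=1024,
    height=1024,
    seed=42,
    num_samples=4,                      # "Number of Samples" slider default
    noise_timestep=2,                   # "Noise Timestep" slider default
    num_steps=4,
    layer_indices=list(range(10, 19)),  # "Layer Start Index" default of 10
    softmax=len(concepts) > 1,          # matches the app: softmax only for multiple concepts
)

# The heatmaps are PIL images, one per concept, as consumed by the galleries.
for heatmap, concept in zip(pipeline_output.concept_heatmaps, concepts):
    heatmap.save(f"concept_heatmap_{concept}.png")

Unlike generate_image, this path takes the uploaded image plus noise_timestep and num_samples, which is why the new tab wires those two extra sliders into its click event.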