# Page header captured together with this file (not part of the program):
#   helblazer811's picture
#   Dynamically update gallery num rows.
#   ade193e
#   raw
#   history blame
#   5.92 kB
import spaces
import gradio as gr
from PIL import Image
import math
from concept_attention import ConceptAttentionFluxPipeline
# Edge length, in pixels, that every concept heatmap is resized to before display.
IMG_SIZE = 250

# Number of gallery columns; the gallery's row count is derived from it.
COLUMNS = 5

# Example inputs, one row per example: [prompt, comma-separated concepts, seed].
EXAMPLES = [
    ["A dog by a tree", "tree, dog, grass, background", 42],
    ["A dragon", "dragon, sky, rock, cloud", 42],
    ["A hot air balloon", "balloon, sky, water, tree", 42],
]
# Build the ConceptAttention Flux pipeline once, at import time, for the
# "flux-schnell" model on the "cuda" device. process_inputs uses this
# module-level instance for every request.
pipeline = ConceptAttentionFluxPipeline(model_name="flux-schnell", device="cuda")
@spaces.GPU(duration=60)
def process_inputs(prompt, word_list, seed, layer_start_index, timestep_start_index):
    """Generate an image for ``prompt`` plus one heatmap per requested concept.

    Args:
        prompt: Text prompt handed to the pipeline; surrounding whitespace
            is stripped first.
        word_list: Comma-separated concepts, e.g. ``"tree, dog, grass"``.
            Blank entries (as in ``"a,,b"``) are ignored.
        seed: Seed forwarded unchanged to ``pipeline.generate_image``.
        layer_start_index: First layer index to use; layers
            ``layer_start_index`` through 18 are passed to the pipeline.
        timestep_start_index: First timestep index to use; timesteps
            ``timestep_start_index`` through 3 are passed to the pipeline.

    Returns:
        A ``gr.Gallery`` whose value is the generated image followed by one
        labelled heatmap per concept, with enough rows to show them all.

    Raises:
        ValueError: If either start index is ``None``.
        gr.exceptions.InputError: If no concept, or more than 9 concepts,
            were supplied.
    """
    print("Processing inputs")

    # Explicit checks instead of ``assert`` so they survive ``python -O``.
    if layer_start_index is None:
        raise ValueError("layer_start_index must not be None")
    if timestep_start_index is None:
        raise ValueError("timestep_start_index must not be None")

    prompt = prompt.strip()

    # NOTE(review): confirm ``gr.exceptions.InputError`` exists in the deployed
    # Gradio version; ``gr.Error`` is the documented user-facing error type.
    if not word_list.strip():
        # The original built this exception without raising it, so empty
        # input silently fell through to the pipeline.
        raise gr.exceptions.InputError("words", "Please enter comma-separated words")

    # ``str.split`` always yields at least one item, so drop blank entries to
    # make the "at least 1 concept" check below meaningful.
    stripped = [word.strip() for word in word_list.split(",")]
    concepts = [word for word in stripped if word]

    if not concepts:
        raise gr.exceptions.InputError("words", "Please enter at least 1 concept")
    if len(concepts) > 9:
        raise gr.exceptions.InputError("words", "Please enter at most 9 concepts")

    pipeline_output = pipeline.generate_image(
        prompt=prompt,
        concepts=concepts,
        width=1024,
        height=1024,
        seed=seed,
        timesteps=list(range(timestep_start_index, 4)),
        num_inference_steps=4,
        layer_indices=list(range(layer_start_index, 19)),
        # Softmax is only enabled when there is more than one concept.
        softmax=len(concepts) > 1,
    )

    output_image = pipeline_output.image
    concept_heatmaps = [
        heatmap.resize((IMG_SIZE, IMG_SIZE), resample=Image.NEAREST)
        for heatmap in pipeline_output.concept_heatmaps
    ]

    # Pair each heatmap with its concept name so the gallery can caption it.
    heatmaps_and_labels = list(zip(concept_heatmaps, concepts))
    all_images_and_labels = [(output_image, "Generated Image")] + heatmaps_and_labels

    num_rows = math.ceil(len(all_images_and_labels) / COLUMNS)

    # ``value`` must be passed explicitly: without it the returned component
    # only updates layout properties and the images are never displayed.
    return gr.Gallery(
        value=all_images_and_labels,
        label="Generated images",
        show_label=True,
        columns=[COLUMNS],
        rows=[num_rows],
        object_fit="contain",
    )
# Page layout: title/authors/abstract, a row with the prompt, the concept list
# and the Run button, the results gallery, and an "Advanced Settings" accordion.
with gr.Blocks(
    css="""
    .container { max-width: 1200px; margin: 0 auto; padding: 20px; }
    .title { text-align: center; margin-bottom: 10px; }
    .authors { text-align: center; margin-bottom: 10px; }
    .affiliations { text-align: center; color: #666; margin-bottom: 10px; }
    .abstract { text-align: center; margin-bottom: 40px; }
    """
) as demo:
    with gr.Column(elem_classes="container"):
        gr.Markdown(
            "# ConceptAttention: Diffusion Transformers Learn Highly Interpretable Features",
            elem_classes="title",
        )
        gr.Markdown(
            "### Alec Helbling¹, Tuna Meral², Ben Hoover¹³, Pinar Yanardag², Duen Horng (Polo) Chau¹",
            elem_classes="authors",
        )
        gr.Markdown(
            "### ¹Georgia Tech · ²Virginia Tech · ³IBM Research",
            elem_classes="affiliations",
        )
        # The abstract text is kept flush-left so the string passed to
        # Markdown is exactly the original one.
        gr.Markdown(
            """
We introduce ConceptAttention, an approach to interpreting the intermediate representations of diffusion transformers.
The user just gives a list of textual concepts and ConceptAttention will produce a set of saliency maps depicting
the location and intensity of these concepts in generated images. Check out our paper: [here](https://arxiv.org/abs/2502.04320).
""",
            elem_classes="abstract",
        )

        with gr.Row(scale=1):
            prompt = gr.Textbox(
                label="Enter your prompt",
                placeholder="Enter your prompt",
                value=EXAMPLES[0][0],
                scale=4,
                show_label=True,
                container=False,
            )
            words = gr.Textbox(
                label="Enter a list of concepts (comma-separated)",
                placeholder="Enter a list of concepts (comma-separated)",
                value=EXAMPLES[0][1],
                scale=4,
                show_label=True,
                container=False,
            )
            submit_btn = gr.Button(
                "Run",
                # min_width is a pixel count (int), not a CSS length string.
                min_width=100,
                scale=1,
            )

        # Starts with a single row; process_inputs returns a gallery whose
        # row count matches the number of images produced.
        gallery = gr.Gallery(
            label="Generated images",
            show_label=True,
            columns=[COLUMNS],
            rows=[1],
            object_fit="contain",
        )

        with gr.Accordion("Advanced Settings", open=False):
            seed = gr.Slider(minimum=0, maximum=10000, step=1, label="Seed", value=42)
            layer_start_index = gr.Slider(
                minimum=0, maximum=18, step=1, label="Layer Start Index", value=10
            )
            # process_inputs uses range(timestep_start_index, 4), so the last
            # usable start index is 3; a value of 4 would select no timesteps.
            timestep_start_index = gr.Slider(
                minimum=0, maximum=3, step=1, label="Timestep Start Index", value=2
            )

        submit_btn.click(
            fn=process_inputs,
            inputs=[prompt, words, seed, layer_start_index, timestep_start_index],
            outputs=[gallery],
        )

        # Each EXAMPLES row supplies only prompt, concepts and seed; the two
        # start-index sliders keep whatever value they currently have.
        gr.Examples(
            examples=EXAMPLES,
            inputs=[prompt, words, seed, layer_start_index, timestep_start_index],
            outputs=[gallery],
            fn=process_inputs,
            cache_examples=False,
        )

        # Automatically process the first example on launch
        # demo.load(process_inputs, inputs=[prompt, words, seed, layer_start_index, timestep_start_index], outputs=[gallery])
# Launch the Gradio app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch(max_threads=1)
    # Launch options that were tried previously and are currently disabled:
    #   share=True,
    #   server_name="0.0.0.0",
    #   inbrowser=True,
    #   share=False,
    #   server_port=6754,
    #   quiet=True,
    #   max_threads=1