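"""Gradio demo for ConceptAttention on flux-schnell: generates an image from a
prompt (or encodes an uploaded one) and displays per-concept saliency heatmaps."""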
import base64
import io
import spaces
import gradio as gr
from PIL import Image
from concept_attention import ConceptAttentionFluxPipeline
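
# Shared defaults for the ConceptAttention pipeline calls below.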
concept_attention_default_args = {
    "model_name": "flux-schnell",
    "device": "cuda",
    "layer_indices": list(range(10, 19)),
    "timesteps": list(range(4)),
    "num_samples": 4,
    "num_inference_steps": 4,
}
IMG_SIZE = 250  # side length (px) of each rendered heatmap tile
EXAMPLES = [
    [
        "A fluffy cat sitting on a windowsill",  # prompt
        "cat.jpg",  # image
        "fur, whiskers, eyes",  # words
        42,  # seed
    ],
    # ["Mountain landscape with lake", "cat.jpg", "sky, trees, water", 123],
    # ["Portrait of a young woman", "monkey.png", "face, hair, eyes", 456],
]
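
# Load the model once at import time; process_inputs reuses this pipeline.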
pipeline = ConceptAttentionFluxPipeline(
    model_name=concept_attention_default_args["model_name"],
    device=concept_attention_default_args["device"],
)
@spaces.GPU(duration=60)
def process_inputs(prompt, input_image, word_list, seed):
    """Run ConceptAttention on a generated or uploaded image and return the
    image plus an HTML strip of per-concept saliency heatmaps."""
    print("Processing inputs")
    prompt = prompt.strip()
    if not prompt:
        return None, "Please enter a prompt"
    if not word_list.strip():
        return None, "Please enter comma-separated words"
    # Drop empty entries left by stray or trailing commas.
    concepts = [w.strip() for w in word_list.split(",") if w.strip()]
    if input_image is not None:
        # An image was uploaded: resize it and extract concept heatmaps from it.
        input_image = Image.fromarray(input_image)
        input_image = input_image.convert("RGB")
        input_image = input_image.resize((1024, 1024))

        pipeline_output = pipeline.encode_image(
            image=input_image,
            concepts=concepts,
            prompt=prompt,
            width=1024,
            height=1024,
            seed=seed,
            num_samples=concept_attention_default_args["num_samples"],
        )
    else:
        # No image uploaded: generate one from the prompt instead.
        pipeline_output = pipeline.generate_image(
            prompt=prompt,
            concepts=concepts,
            width=1024,
            height=1024,
            seed=seed,
            timesteps=concept_attention_default_args["timesteps"],
            num_inference_steps=concept_attention_default_args["num_inference_steps"],
        )
    output_image = pipeline_output.image
    concept_heatmaps = pipeline_output.concept_heatmaps

    # Render each heatmap as a base64-embedded <img> labeled with its concept.
    html_elements = []
    for concept, heatmap in zip(concepts, concept_heatmaps):
        img = heatmap.resize((IMG_SIZE, IMG_SIZE), resample=Image.NEAREST)
        buffered = io.BytesIO()
        img.save(buffered, format="PNG")
        img_str = base64.b64encode(buffered.getvalue()).decode()

        html = f"""
        <div style='text-align: center; margin: 5px; padding: 5px; overflow-x: auto; white-space: nowrap;'>
            <h1 style='margin-bottom: 10px;'>{concept}</h1>
            <img src='data:image/png;base64,{img_str}' style='width: {IMG_SIZE}px; height: {IMG_SIZE}px; display: inline-block;'>
        </div>
        """
        html_elements.append(html)

    combined_html = (
        "<div style='display: flex; flex-wrap: wrap; justify-content: center;'>"
        + "".join(html_elements)
        + "</div>"
    )
    return output_image, combined_html
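
# Build the Gradio UI: inputs (prompt, concept words, seed, optional image)
# on the left, the output image on the right, and saliency maps below.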
with gr.Blocks(
    css="""
        .container { max-width: 1200px; margin: 0 auto; padding: 20px; }
        .title { text-align: center; margin-bottom: 10px; }
        .authors { text-align: center; margin-bottom: 20px; }
        .affiliations { text-align: center; color: #666; margin-bottom: 40px; }
        .content { display: grid; grid-template-columns: 1fr 1fr; gap: 20px; }
        .section { border: 2px solid #ddd; border-radius: 10px; padding: 20px; }
    """
) as demo:
    with gr.Column(elem_classes="container"):
        gr.Markdown("# ConceptAttention: Diffusion Transformers Learn Highly Interpretable Features", elem_classes="title")
        gr.Markdown("**Alec Helbling**¹, **Tuna Meral**², **Ben Hoover**¹³, **Pinar Yanardag**², **Duen Horng (Polo) Chau**¹", elem_classes="authors")
        gr.Markdown("¹Georgia Tech · ²Virginia Tech · ³IBM Research", elem_classes="affiliations")

        with gr.Row(elem_classes="content"):
            with gr.Column(elem_classes="section"):
                gr.Markdown("### Input")
                prompt = gr.Textbox(label="Enter your prompt")
                words = gr.Textbox(label="Enter words (comma-separated)")
                seed = gr.Slider(minimum=0, maximum=10000, step=1, label="Seed", value=42)
                gr.HTML("<div style='text-align: center;'> <h1> Or </h1> </div>")
                image_input = gr.Image(type="numpy", label="Upload image (optional)")

            with gr.Column(elem_classes="section"):
                gr.Markdown("### Output")
                output_image = gr.Image(type="numpy", label="Output image")

        with gr.Row():
            submit_btn = gr.Button("Process")

        with gr.Row(elem_classes="section"):
            saliency_display = gr.HTML(label="Saliency Maps")

    submit_btn.click(
        fn=process_inputs,
        inputs=[prompt, image_input, words, seed],
        outputs=[output_image, saliency_display],
    )

    gr.Examples(
        examples=EXAMPLES,
        inputs=[prompt, image_input, words, seed],
        outputs=[output_image, saliency_display],
        fn=process_inputs,
        cache_examples=False,
    )
if __name__ == "__main__":
    demo.launch()
    # Alternative launch options, kept for local debugging:
    # demo.launch(
    #     share=True,
    #     server_name="0.0.0.0",
    #     inbrowser=True,
    #     # share=False,
    #     server_port=6754,
    #     quiet=True,
    #     max_threads=1,
    # )