Spaces:
Running
on
Zero
Running
on
Zero
Commit
·
4f83196
1
Parent(s):
e6d4da8
App changes
Browse files
app.py
CHANGED
@@ -1,44 +1,26 @@
|
|
1 |
import base64
|
2 |
import io
|
3 |
-
|
4 |
import spaces
|
5 |
import gradio as gr
|
6 |
from PIL import Image
|
7 |
-
import requests
|
8 |
-
import numpy as np
|
9 |
-
import PIL
|
10 |
|
11 |
from concept_attention import ConceptAttentionFluxPipeline
|
12 |
|
13 |
-
# concept_attention_default_args = {
|
14 |
-
# "model_name": "flux-schnell",
|
15 |
-
# "device": "cuda",
|
16 |
-
# "layer_indices": list(range(10, 19)),
|
17 |
-
# "timesteps": list(range(2, 4)),
|
18 |
-
# "num_samples": 4,
|
19 |
-
# "num_inference_steps": 4
|
20 |
-
# }
|
21 |
IMG_SIZE = 250
|
22 |
|
23 |
-
def download_image(url):
|
24 |
-
return Image.open(io.BytesIO(requests.get(url).content))
|
25 |
-
|
26 |
EXAMPLES = [
|
27 |
[
|
28 |
"A dog by a tree", # prompt
|
29 |
-
download_image("https://github.com/helblazer811/ConceptAttention/blob/master/images/dog_by_tree.png?raw=true"),
|
30 |
"tree, dog, grass, background", # words
|
31 |
42, # seed
|
32 |
],
|
33 |
[
|
34 |
"A dragon", # prompt
|
35 |
-
download_image("https://github.com/helblazer811/ConceptAttention/blob/master/images/dragon_image.png?raw=true"),
|
36 |
"dragon, sky, rock, cloud", # words
|
37 |
42, # seed
|
38 |
],
|
39 |
-
|
40 |
"A hot air balloon", # prompt
|
41 |
-
download_image("https://github.com/helblazer811/ConceptAttention/blob/master/images/hot_air_balloon.png?raw=true"),
|
42 |
"balloon, sky, water, tree", # words
|
43 |
42, # seed
|
44 |
]
|
@@ -47,67 +29,68 @@ EXAMPLES = [
|
|
47 |
pipeline = ConceptAttentionFluxPipeline(model_name="flux-schnell", device="cuda")
|
48 |
|
49 |
@spaces.GPU(duration=60)
|
50 |
-
def process_inputs(prompt,
|
51 |
print("Processing inputs")
|
|
|
|
|
|
|
52 |
prompt = prompt.strip()
|
53 |
if not word_list.strip():
|
54 |
-
|
55 |
|
56 |
concepts = [w.strip() for w in word_list.split(",")]
|
57 |
|
58 |
-
if
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
layer_indices=list(range(layer_start_index, 19)),
|
76 |
-
)
|
77 |
-
|
78 |
-
else:
|
79 |
-
pipeline_output = pipeline.generate_image(
|
80 |
-
prompt=prompt,
|
81 |
-
concepts=concepts,
|
82 |
-
width=1024,
|
83 |
-
height=1024,
|
84 |
-
seed=seed,
|
85 |
-
timesteps=list(range(timestep_start_index, 4)),
|
86 |
-
num_inference_steps=4,
|
87 |
-
layer_indices=list(range(layer_start_index, 19)),
|
88 |
-
)
|
89 |
|
90 |
output_image = pipeline_output.image
|
91 |
concept_heatmaps = pipeline_output.concept_heatmaps
|
|
|
92 |
|
93 |
-
|
94 |
-
|
95 |
-
img = heatmap.resize((IMG_SIZE, IMG_SIZE), resample=Image.NEAREST)
|
96 |
-
buffered = io.BytesIO()
|
97 |
-
img.save(buffered, format="PNG")
|
98 |
-
img_str = base64.b64encode(buffered.getvalue()).decode()
|
99 |
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
|
108 |
-
|
109 |
-
|
|
|
|
|
|
|
110 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
111 |
|
112 |
with gr.Blocks(
|
113 |
css="""
|
@@ -115,10 +98,8 @@ with gr.Blocks(
|
|
115 |
.title { text-align: center; margin-bottom: 10px; }
|
116 |
.authors { text-align: center; margin-bottom: 10px; }
|
117 |
.affiliations { text-align: center; color: #666; margin-bottom: 10px; }
|
118 |
-
.content { display: grid; grid-template-columns: 1fr 1fr; gap: 20px; }
|
119 |
-
.section { }
|
120 |
-
.input-image { width: 100%; height: 200px; }
|
121 |
.abstract { text-align: center; margin-bottom: 40px; }
|
|
|
122 |
"""
|
123 |
) as demo:
|
124 |
with gr.Column(elem_classes="container"):
|
@@ -134,41 +115,54 @@ with gr.Blocks(
|
|
134 |
elem_classes="abstract"
|
135 |
)
|
136 |
|
137 |
-
with gr.Row(elem_classes="
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
160 |
|
161 |
submit_btn.click(
|
162 |
fn=process_inputs,
|
163 |
-
inputs=[prompt,
|
|
|
164 |
)
|
165 |
-
# .then(
|
166 |
-
# fn=lambda component: gr.update(value=None),
|
167 |
-
# inputs=[image_input],
|
168 |
-
# outputs=[]
|
169 |
-
# )
|
170 |
|
171 |
-
gr.Examples(examples=EXAMPLES, inputs=[prompt,
|
|
|
|
|
|
|
|
|
172 |
|
173 |
if __name__ == "__main__":
|
174 |
demo.launch(max_threads=1)
|
|
|
1 |
import base64
|
2 |
import io
|
|
|
3 |
import spaces
|
4 |
import gradio as gr
|
5 |
from PIL import Image
|
|
|
|
|
|
|
6 |
|
7 |
from concept_attention import ConceptAttentionFluxPipeline
|
8 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
IMG_SIZE = 250
|
10 |
|
|
|
|
|
|
|
11 |
EXAMPLES = [
|
12 |
[
|
13 |
"A dog by a tree", # prompt
|
|
|
14 |
"tree, dog, grass, background", # words
|
15 |
42, # seed
|
16 |
],
|
17 |
[
|
18 |
"A dragon", # prompt
|
|
|
19 |
"dragon, sky, rock, cloud", # words
|
20 |
42, # seed
|
21 |
],
|
22 |
+
[
|
23 |
"A hot air balloon", # prompt
|
|
|
24 |
"balloon, sky, water, tree", # words
|
25 |
42, # seed
|
26 |
]
|
|
|
29 |
pipeline = ConceptAttentionFluxPipeline(model_name="flux-schnell", device="cuda")
|
30 |
|
31 |
@spaces.GPU(duration=60)
|
32 |
+
def process_inputs(prompt, word_list, seed, layer_start_index, timestep_start_index):
|
33 |
print("Processing inputs")
|
34 |
+
assert layer_start_index is not None
|
35 |
+
assert timestep_start_index is not None
|
36 |
+
|
37 |
prompt = prompt.strip()
|
38 |
if not word_list.strip():
|
39 |
+
gr.exceptions.InputError("words", "Please enter comma-separated words")
|
40 |
|
41 |
concepts = [w.strip() for w in word_list.split(",")]
|
42 |
|
43 |
+
if len(concepts) == 0:
|
44 |
+
raise gr.exceptions.InputError("words", "Please enter at least 1 concept")
|
45 |
+
|
46 |
+
if len(concepts) > 9:
|
47 |
+
raise gr.exceptions.InputError("words", "Please enter at most 9 concepts")
|
48 |
+
|
49 |
+
pipeline_output = pipeline.generate_image(
|
50 |
+
prompt=prompt,
|
51 |
+
concepts=concepts,
|
52 |
+
width=1024,
|
53 |
+
height=1024,
|
54 |
+
seed=seed,
|
55 |
+
timesteps=list(range(timestep_start_index, 4)),
|
56 |
+
num_inference_steps=4,
|
57 |
+
layer_indices=list(range(layer_start_index, 19)),
|
58 |
+
softmax=True if len(concepts) > 1 else False
|
59 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
60 |
|
61 |
output_image = pipeline_output.image
|
62 |
concept_heatmaps = pipeline_output.concept_heatmaps
|
63 |
+
concept_heatmaps = [heatmap.resize((IMG_SIZE, IMG_SIZE), resample=Image.NEAREST) for heatmap in concept_heatmaps]
|
64 |
|
65 |
+
heatmaps_and_labels = [(concept_heatmaps[concept_index], concepts[concept_index]) for concept_index in range(len(concepts))]
|
66 |
+
all_images_and_labels = [(output_image, "Generated Image")] + heatmaps_and_labels
|
|
|
|
|
|
|
|
|
67 |
|
68 |
+
# combined_html = "<div style='display: flex; flex-wrap: wrap; justify-content: center;'>"
|
69 |
+
# # Show the output image
|
70 |
+
# combined_html += f"""
|
71 |
+
# <div style='text-align: center; margin: 5px; padding: 5px;'>
|
72 |
+
# <img src='data:image/png;base64,{output_image}' style='width: {IMG_SIZE}px; display: inline-block; height: {IMG_SIZE}px;'>
|
73 |
+
# </div>
|
74 |
+
# """
|
75 |
|
76 |
+
# for concept, heatmap in zip(concepts, concept_heatmaps):
|
77 |
+
# img = heatmap.resize((IMG_SIZE, IMG_SIZE), resample=Image.NEAREST)
|
78 |
+
# buffered = io.BytesIO()
|
79 |
+
# img.save(buffered, format="PNG")
|
80 |
+
# img_str = base64.b64encode(buffered.getvalue()).decode()
|
81 |
|
82 |
+
# html = f"""
|
83 |
+
# <div style='text-align: center; margin: 5px; padding: 5px; overflow-x: auto; white-space: nowrap;'>
|
84 |
+
# <h1 style='margin-bottom: 10px;'>{concept}</h1>
|
85 |
+
# <img src='data:image/png;base64,{img_str}' style='width: {IMG_SIZE}px; display: inline-block; height: {IMG_SIZE}px;'>
|
86 |
+
# </div>
|
87 |
+
# """
|
88 |
+
|
89 |
+
# combined_html += html
|
90 |
+
|
91 |
+
# combined_html += "</div>"
|
92 |
+
|
93 |
+
return all_images_and_labels
|
94 |
|
95 |
with gr.Blocks(
|
96 |
css="""
|
|
|
98 |
.title { text-align: center; margin-bottom: 10px; }
|
99 |
.authors { text-align: center; margin-bottom: 10px; }
|
100 |
.affiliations { text-align: center; color: #666; margin-bottom: 10px; }
|
|
|
|
|
|
|
101 |
.abstract { text-align: center; margin-bottom: 40px; }
|
102 |
+
.input-row { height: 60px; }
|
103 |
"""
|
104 |
) as demo:
|
105 |
with gr.Column(elem_classes="container"):
|
|
|
115 |
elem_classes="abstract"
|
116 |
)
|
117 |
|
118 |
+
with gr.Row(equal_height=True, elem_classes="input-row"):
|
119 |
+
prompt = gr.Textbox(
|
120 |
+
label="Enter your prompt",
|
121 |
+
placeholder="Enter your prompt",
|
122 |
+
value=EXAMPLES[0][0],
|
123 |
+
scale=4,
|
124 |
+
# show_label=False
|
125 |
+
)
|
126 |
+
words = gr.Textbox(
|
127 |
+
label="Enter a list of concepts (comma-separated)",
|
128 |
+
placeholder="Enter a list of concepts (comma-separated)",
|
129 |
+
value=EXAMPLES[0][1],
|
130 |
+
scale=4,
|
131 |
+
# show_label=False
|
132 |
+
)
|
133 |
+
submit_btn = gr.Button(
|
134 |
+
"Run",
|
135 |
+
min_width="100px",
|
136 |
+
scale=1
|
137 |
+
)
|
138 |
+
|
139 |
+
# generated_image = gr.Image(label="Generated Image", elem_classes="input-image")
|
140 |
+
gallery = gr.Gallery(
|
141 |
+
label="Generated images",
|
142 |
+
show_label=True,
|
143 |
+
elem_id="gallery",
|
144 |
+
columns=[5],
|
145 |
+
# rows=[1],
|
146 |
+
object_fit="contain",
|
147 |
+
# height="auto"
|
148 |
+
)
|
149 |
+
with gr.Accordion("Advanced Settings", open=False):
|
150 |
+
seed = gr.Slider(minimum=0, maximum=10000, step=1, label="Seed", value=42)
|
151 |
+
layer_start_index = gr.Slider(minimum=0, maximum=18, step=1, label="Layer Start Index", value=10)
|
152 |
+
timestep_start_index = gr.Slider(minimum=0, maximum=4, step=1, label="Timestep Start Index", value=2)
|
153 |
+
|
154 |
|
155 |
submit_btn.click(
|
156 |
fn=process_inputs,
|
157 |
+
inputs=[prompt, words, seed, layer_start_index, timestep_start_index],
|
158 |
+
outputs=[gallery]
|
159 |
)
|
|
|
|
|
|
|
|
|
|
|
160 |
|
161 |
+
gr.Examples(examples=EXAMPLES, inputs=[prompt, words, seed, layer_start_index, timestep_start_index], outputs=[gallery], fn=process_inputs, cache_examples=True)
|
162 |
+
|
163 |
+
# Automatically process the first example on launch
|
164 |
+
demo.load(process_inputs, inputs=[prompt, words, seed, layer_start_index, timestep_start_index], outputs=[gallery])
|
165 |
+
|
166 |
|
167 |
if __name__ == "__main__":
|
168 |
demo.launch(max_threads=1)
|