helblazer811 committed on
Commit
d6a75b9
·
1 Parent(s): 48daf06

Changes to UI

Browse files
Files changed (2) hide show
  1. ConceptAttentionCallout.svg +4 -0
  2. app.py +162 -176
ConceptAttentionCallout.svg ADDED
app.py CHANGED
@@ -12,24 +12,6 @@ from concept_attention import ConceptAttentionFluxPipeline
12
  IMG_SIZE = 210
13
  COLUMNS = 5
14
 
15
- EXAMPLES = [
16
- [
17
- "A dog by a tree", # prompt
18
- "tree, dog, grass, background", # words
19
- 42, # seed
20
- ],
21
- # [
22
- # "A dragon", # prompt
23
- # "dragon, sky, rock, cloud", # words
24
- # 42, # seed
25
- # ],
26
- # [
27
- # "A hot air balloon", # prompt
28
- # "balloon, sky, water, tree", # words
29
- # 42, # seed
30
- # ]
31
- ]
32
-
33
  def update_default_concepts(prompt):
34
  default_concepts = {
35
  "A dog by a tree": ["dog", "grass", "tree", "background"],
@@ -39,7 +21,7 @@ def update_default_concepts(prompt):
39
 
40
  return gr.update(value=default_concepts.get(prompt, []))
41
 
42
- pipeline = ConceptAttentionFluxPipeline(model_name="flux-schnell", device="cuda") #, offload_model=True)
43
 
44
  def convert_pil_to_bytes(img):
45
  img = img.resize((IMG_SIZE, IMG_SIZE), resample=Image.NEAREST)
@@ -51,21 +33,14 @@ def convert_pil_to_bytes(img):
51
 
52
  @spaces.GPU(duration=60)
53
  def process_inputs(prompt, concepts, seed, layer_start_index, timestep_start_index):
54
- # print("Processing inputs")
55
- # assert layer_start_index is not None
56
- # assert timestep_start_index is not None
57
 
58
  if not prompt.strip():
59
  raise gr.exceptions.InputError("prompt", "Please enter a prompt")
60
 
61
  prompt = prompt.strip()
62
 
63
- print(concepts)
64
- # if not word_list.strip():
65
- # gr.exceptions.InputError("words", "Please enter comma-separated words")
66
-
67
- # concepts = [w.strip() for w in word_list.split(",")]
68
-
69
  if len(concepts) == 0:
70
  raise gr.exceptions.InputError("words", "Please enter at least 1 concept")
71
 
@@ -94,18 +69,17 @@ def process_inputs(prompt, concepts, seed, layer_start_index, timestep_start_ind
94
  cross_attention_heatmaps = [heatmap.resize((IMG_SIZE, IMG_SIZE), resample=Image.NEAREST) for heatmap in cross_attention_heatmaps]
95
  cross_attention_maps_and_labels = [(cross_attention_heatmaps[concept_index], concepts[concept_index]) for concept_index in range(len(concepts))]
96
 
97
- # heatmaps_and_labels = [(concept_heatmaps[concept_index], concepts[concept_index]) for concept_index in range(len(concepts))]
98
- # all_images_and_labels = [(output_image, "Generated Image")] + heatmaps_and_labels
99
- # num_rows = math.ceil(len(all_images_and_labels) / COLUMNS)
100
-
101
  return output_image, \
102
  gr.update(value=output_space_maps_and_labels, columns=len(output_space_maps_and_labels)), \
103
  gr.update(value=cross_attention_maps_and_labels, columns=len(cross_attention_maps_and_labels))
104
 
105
  with gr.Blocks(
106
  css="""
107
- .container { max-width: 1200px; margin: 0 auto; padding: 20px; }
108
- .title { text-align: center; margin-bottom: 10px; }
 
 
 
109
  .authors { text-align: center; margin-bottom: 10px; }
110
  .affiliations { text-align: center; color: #666; margin-bottom: 10px; }
111
  .abstract { text-align: center; margin-bottom: 40px; }
@@ -115,6 +89,10 @@ with gr.Blocks(
115
  justify-content: center;
116
  height: 100%; /* Ensures full height */
117
  }
 
 
 
 
118
  .input {
119
  height: 47px;
120
  }
@@ -123,159 +101,167 @@ with gr.Blocks(
123
  gap: 0px;
124
  }
125
  .input-column-label {}
126
- .gallery {
127
- # scrollbar-width: thin;
128
- # scrollbar-color: #27272A;
129
- }
130
-
131
  .run-button-column {
132
  width: 100px !important;
133
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
134
  """
135
  # ,
136
  # elem_classes="container"
137
  ) as demo:
138
- with gr.Column(elem_classes="container"):
139
- gr.Markdown("# ConceptAttention: Visualize Any Concepts in Your Generated Images", elem_classes="title")
140
- # gr.Markdown("### Alec Helbling¹, Tuna Meral², Ben Hoover¹³, Pinar Yanardag², Duen Horng (Polo) Chau¹", elem_classes="authors")
141
- # gr.Markdown("### ¹Georgia Tech · ²Virginia Tech · ³IBM Research", elem_classes="affiliations")
142
- gr.Markdown("## Interpret generative models with precise, high-quality heatmaps. Check out our paper [here](https://arxiv.org/abs/2502.04320).", elem_classes="abstract")
143
-
144
- with gr.Row(scale=1, equal_height=True):
145
- with gr.Column(scale=3, elem_classes="input-column"):
146
- gr.HTML(
147
- "Write a Prompt",
148
- elem_classes="input-column-label"
149
- )
150
- prompt = gr.Dropdown(
151
- ["A dog by a tree", "A dragon", "A hot air balloon"],
152
- # label="Prompt",
153
- container=False,
154
- # scale=3,
155
- allow_custom_value=True,
156
- elem_classes="input"
157
- )
158
-
159
- with gr.Column(scale=7, elem_classes="input-column"):
160
- gr.HTML(
161
- "Select or Write Concepts",
162
- elem_classes="input-column-label"
163
- )
164
- concepts = gr.Dropdown(
165
- ["dog", "grass", "tree", "dragon", "sky", "rock", "cloud", "balloon", "water", "background"],
166
- value=["dog", "grass", "tree", "background"],
167
- multiselect=True,
168
- label="Concepts",
169
- container=False,
170
- allow_custom_value=True,
171
- # scale=4,
172
- elem_classes="input",
173
- max_choices=5
174
- )
175
-
176
- with gr.Column(scale=1, min_width=100, elem_classes="input-column run-button-column"):
177
- gr.HTML(
178
- "​",
179
- elem_classes="input-column-label"
180
- )
181
- submit_btn = gr.Button(
182
- "Run",
183
- # scale=1,
184
- elem_classes="input"
185
- )
186
- # prompt = gr.Textbox(
187
- # label="Enter your prompt",
188
- # placeholder="Enter your prompt",
189
- # value=EXAMPLES[0][0],
190
- # scale=4,
191
- # # show_label=True,
192
- # container=False
193
- # # height="80px"
194
- # )
195
- # words = gr.Textbox(
196
- # label="Enter a list of concepts (comma-separated)",
197
- # placeholder="Enter a list of concepts (comma-separated)",
198
- # value=EXAMPLES[0][1],
199
- # scale=4,
200
- # # show_label=True,
201
- # container=False
202
- # # height="80px"
203
- # )
204
-
205
- num_rows_state = gr.State(value=1) # Initial number of rows
206
-
207
- # generated_image = gr.Image(label="Generated Image", elem_classes="input-image")
208
- # gallery = gr.Gallery(
209
- # label="Generated images",
210
- # show_label=True,
211
- # # elem_id="gallery",
212
- # columns=COLUMNS,
213
- # rows=1,
214
- # # object_fit="contain",
215
- # height="auto",
216
- # elem_classes="gallery"
217
- # )
218
-
219
- with gr.Row(elem_classes="gallery", scale=8):
220
-
221
- with gr.Column(scale=1):
222
- generated_image = gr.Image(
223
- elem_classes="generated-image",
224
- show_label=False
225
- )
226
-
227
- with gr.Column(scale=4):
228
- concept_attention_gallery = gr.Gallery(
229
- label="Concept Attention (Ours)",
230
- show_label=True,
231
- # columns=3,
232
- rows=1,
233
- object_fit="contain",
234
- height="200px",
235
- elem_classes="gallery"
236
- )
237
-
238
- cross_attention_gallery = gr.Gallery(
239
- label="Cross Attention",
240
- show_label=True,
241
- # columns=3,
242
- rows=1,
243
- object_fit="contain",
244
- height="200px",
245
- elem_classes="gallery"
246
- )
247
-
248
- with gr.Accordion("Advanced Settings", open=False):
249
- seed = gr.Slider(minimum=0, maximum=10000, step=1, label="Seed", value=42)
250
- layer_start_index = gr.Slider(minimum=0, maximum=18, step=1, label="Layer Start Index", value=10)
251
- timestep_start_index = gr.Slider(minimum=0, maximum=4, step=1, label="Timestep Start Index", value=2)
252
-
253
- submit_btn.click(
254
- fn=process_inputs,
255
- inputs=[prompt, concepts, seed, layer_start_index, timestep_start_index],
256
- outputs=[generated_image, concept_attention_gallery, cross_attention_gallery]
257
- )
258
-
259
- # gr.Examples(examples=EXAMPLES, inputs=[prompt, concepts, seed, layer_start_index, timestep_start_index], outputs=[gallery, num_rows_state], fn=process_inputs, cache_examples=False)
260
- # num_rows_state.change(
261
- # fn=lambda rows: gr.Gallery.update(rows=int(rows)),
262
- # inputs=[num_rows_state],
263
- # outputs=[gallery]
264
- # )
265
-
266
- prompt.change(update_default_concepts, inputs=[prompt], outputs=[concepts])
267
-
268
- # Automatically process the first example on launch
269
- demo.load(
270
- process_inputs,
271
- inputs=[prompt, concepts, seed, layer_start_index, timestep_start_index],
272
- outputs=[generated_image, concept_attention_gallery, cross_attention_gallery]
273
- )
274
 
275
  if __name__ == "__main__":
276
  if os.path.exists("/data-nvme/zerogpu-offload"):
277
  subprocess.run("rm -rf /data-nvme/zerogpu-offload/*", env={}, shell=True)
278
- demo.launch()
 
 
279
  # share=True,
280
  # server_name="0.0.0.0",
281
  # inbrowser=True,
 
12
  IMG_SIZE = 210
13
  COLUMNS = 5
14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  def update_default_concepts(prompt):
16
  default_concepts = {
17
  "A dog by a tree": ["dog", "grass", "tree", "background"],
 
21
 
22
  return gr.update(value=default_concepts.get(prompt, []))
23
 
24
+ pipeline = ConceptAttentionFluxPipeline(model_name="flux-schnell", device="cuda:2", offload_model=True)
25
 
26
  def convert_pil_to_bytes(img):
27
  img = img.resize((IMG_SIZE, IMG_SIZE), resample=Image.NEAREST)
 
33
 
34
  @spaces.GPU(duration=60)
35
  def process_inputs(prompt, concepts, seed, layer_start_index, timestep_start_index):
36
+ if not prompt:
37
+ raise gr.exceptions.InputError("prompt", "Please enter a prompt")
 
38
 
39
  if not prompt.strip():
40
  raise gr.exceptions.InputError("prompt", "Please enter a prompt")
41
 
42
  prompt = prompt.strip()
43
 
 
 
 
 
 
 
44
  if len(concepts) == 0:
45
  raise gr.exceptions.InputError("words", "Please enter at least 1 concept")
46
 
 
69
  cross_attention_heatmaps = [heatmap.resize((IMG_SIZE, IMG_SIZE), resample=Image.NEAREST) for heatmap in cross_attention_heatmaps]
70
  cross_attention_maps_and_labels = [(cross_attention_heatmaps[concept_index], concepts[concept_index]) for concept_index in range(len(concepts))]
71
 
 
 
 
 
72
  return output_image, \
73
  gr.update(value=output_space_maps_and_labels, columns=len(output_space_maps_and_labels)), \
74
  gr.update(value=cross_attention_maps_and_labels, columns=len(cross_attention_maps_and_labels))
75
 
76
  with gr.Blocks(
77
  css="""
78
+ .container {
79
+ max-width: 1400px;
80
+ margin: 0 auto;
81
+ padding: 20px;
82
+ }
83
  .authors { text-align: center; margin-bottom: 10px; }
84
  .affiliations { text-align: center; color: #666; margin-bottom: 10px; }
85
  .abstract { text-align: center; margin-bottom: 40px; }
 
89
  justify-content: center;
90
  height: 100%; /* Ensures full height */
91
  }
92
+ .header {
93
+ display: flex;
94
+ flex-direction: column;
95
+ }
96
  .input {
97
  height: 47px;
98
  }
 
101
  gap: 0px;
102
  }
103
  .input-column-label {}
104
+ .gallery {}
 
 
 
 
105
  .run-button-column {
106
  width: 100px !important;
107
  }
108
+ #title {
109
+ font-size: 2.4em;
110
+ text-align: center;
111
+ margin-bottom: 10px;
112
+ }
113
+ #subtitle {
114
+ font-size: 2.0em;
115
+ text-align: center;
116
+ }
117
+
118
+ #concept-attention-callout-svg {
119
+ width: 250px;
120
+ }
121
+
122
+ /* Show only on screens wider than 768px (adjust as needed) */
123
+ @media (min-width: 1024px) {
124
+ .svg-container {
125
+ min-width: 150px;
126
+ width: 200px;
127
+ padding-top: 540px;
128
+ }
129
+ }
130
+
131
+ @media (min-width: 1280px) {
132
+ .svg-container {
133
+ min-width: 200px;
134
+ width: 300px;
135
+ padding-top: 400px;
136
+ }
137
+ }
138
+ @media (min-width: 1530px) {
139
+ .svg-container {
140
+ min-width: 200px;
141
+ width: 300px;
142
+ padding-top: 370px;
143
+ }
144
+ }
145
+
146
+
147
+ @media (max-width: 1024px) {
148
+ .svg-container {
149
+ display: none;
150
+ }
151
+ }
152
+
153
  """
154
  # ,
155
  # elem_classes="container"
156
  ) as demo:
157
+ with gr.Row(elem_classes="container"):
158
+ with gr.Column(elem_classes="application", scale=15):
159
+ with gr.Row(scale=3, elem_classes="header"):
160
+ gr.HTML("<h1 id='title'> ConceptAttention: Visualize Any Concepts in Your Generated Images</h1>")
161
+ gr.HTML("<h2 id='subtitle'> Interpret generative models with precise, high-quality heatmaps. <br/> Check out our paper <a href='https://arxiv.org/abs/2502.04320'> here </a>. </h2>")
162
+
163
+ with gr.Row(scale=1, equal_height=True):
164
+ with gr.Column(scale=4, elem_classes="input-column", min_width=250):
165
+ gr.HTML(
166
+ "Write a Prompt",
167
+ elem_classes="input-column-label"
168
+ )
169
+ prompt = gr.Dropdown(
170
+ ["A dog by a tree", "A dragon", "A hot air balloon"],
171
+ container=False,
172
+ allow_custom_value=True,
173
+ elem_classes="input"
174
+ )
175
+
176
+ with gr.Column(scale=7, elem_classes="input-column"):
177
+ gr.HTML(
178
+ "Select or Write Concepts",
179
+ elem_classes="input-column-label"
180
+ )
181
+ concepts = gr.Dropdown(
182
+ ["dog", "grass", "tree", "dragon", "sky", "rock", "cloud", "balloon", "water", "background"],
183
+ value=["dog", "grass", "tree", "background"],
184
+ multiselect=True,
185
+ label="Concepts",
186
+ container=False,
187
+ allow_custom_value=True,
188
+ # scale=4,
189
+ elem_classes="input",
190
+ max_choices=5
191
+ )
192
+
193
+ with gr.Column(scale=1, min_width=100, elem_classes="input-column run-button-column"):
194
+ gr.HTML(
195
+ "&#8203;",
196
+ elem_classes="input-column-label"
197
+ )
198
+ submit_btn = gr.Button(
199
+ "Run",
200
+ elem_classes="input"
201
+ )
202
+
203
+ with gr.Row(elem_classes="gallery", scale=8):
204
+
205
+ with gr.Column(scale=1, min_width=250):
206
+ generated_image = gr.Image(
207
+ elem_classes="generated-image",
208
+ show_label=False
209
+ )
210
+
211
+ with gr.Column(scale=4):
212
+ concept_attention_gallery = gr.Gallery(
213
+ label="Concept Attention (Ours)",
214
+ show_label=True,
215
+ # columns=3,
216
+ rows=1,
217
+ object_fit="contain",
218
+ height="200px",
219
+ elem_classes="gallery",
220
+ elem_id="concept-attention-gallery"
221
+ )
222
+
223
+ cross_attention_gallery = gr.Gallery(
224
+ label="Cross Attention",
225
+ show_label=True,
226
+ # columns=3,
227
+ rows=1,
228
+ object_fit="contain",
229
+ height="200px",
230
+ elem_classes="gallery"
231
+ )
232
+
233
+ with gr.Accordion("Advanced Settings", open=False):
234
+ seed = gr.Slider(minimum=0, maximum=10000, step=1, label="Seed", value=42)
235
+ layer_start_index = gr.Slider(minimum=0, maximum=18, step=1, label="Layer Start Index", value=10)
236
+ timestep_start_index = gr.Slider(minimum=0, maximum=4, step=1, label="Timestep Start Index", value=2)
237
+
238
+ submit_btn.click(
239
+ fn=process_inputs,
240
+ inputs=[prompt, concepts, seed, layer_start_index, timestep_start_index],
241
+ outputs=[generated_image, concept_attention_gallery, cross_attention_gallery]
242
+ )
243
+
244
+ prompt.change(update_default_concepts, inputs=[prompt], outputs=[concepts])
245
+
246
+ # Automatically process the first example on launch
247
+ demo.load(
248
+ process_inputs,
249
+ inputs=[prompt, concepts, seed, layer_start_index, timestep_start_index],
250
+ outputs=[generated_image, concept_attention_gallery, cross_attention_gallery]
251
+ )
252
+
253
+ with gr.Column(scale=4, min_width=250, elem_classes="svg-container"):
254
+ concept_attention_callout_svg = gr.HTML(
255
+ "<img src='/gradio_api/file=ConceptAttentionCallout.svg' id='concept-attention-callout-svg'/>",
256
+ # container=False,
257
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
258
 
259
  if __name__ == "__main__":
260
  if os.path.exists("/data-nvme/zerogpu-offload"):
261
  subprocess.run("rm -rf /data-nvme/zerogpu-offload/*", env={}, shell=True)
262
+ demo.launch(
263
+ allowed_paths=["."]
264
+ )
265
  # share=True,
266
  # server_name="0.0.0.0",
267
  # inbrowser=True,