helblazer811 committed on
Commit
4f83196
·
1 Parent(s): e6d4da8

App changes

Browse files
Files changed (1) hide show
  1. app.py +94 -100
app.py CHANGED
@@ -1,44 +1,26 @@
1
  import base64
2
  import io
3
-
4
  import spaces
5
  import gradio as gr
6
  from PIL import Image
7
- import requests
8
- import numpy as np
9
- import PIL
10
 
11
  from concept_attention import ConceptAttentionFluxPipeline
12
 
13
- # concept_attention_default_args = {
14
- # "model_name": "flux-schnell",
15
- # "device": "cuda",
16
- # "layer_indices": list(range(10, 19)),
17
- # "timesteps": list(range(2, 4)),
18
- # "num_samples": 4,
19
- # "num_inference_steps": 4
20
- # }
21
  IMG_SIZE = 250
22
 
23
- def download_image(url):
24
- return Image.open(io.BytesIO(requests.get(url).content))
25
-
26
  EXAMPLES = [
27
  [
28
  "A dog by a tree", # prompt
29
- download_image("https://github.com/helblazer811/ConceptAttention/blob/master/images/dog_by_tree.png?raw=true"),
30
  "tree, dog, grass, background", # words
31
  42, # seed
32
  ],
33
  [
34
  "A dragon", # prompt
35
- download_image("https://github.com/helblazer811/ConceptAttention/blob/master/images/dragon_image.png?raw=true"),
36
  "dragon, sky, rock, cloud", # words
37
  42, # seed
38
  ],
39
- [
40
  "A hot air balloon", # prompt
41
- download_image("https://github.com/helblazer811/ConceptAttention/blob/master/images/hot_air_balloon.png?raw=true"),
42
  "balloon, sky, water, tree", # words
43
  42, # seed
44
  ]
@@ -47,67 +29,68 @@ EXAMPLES = [
47
  pipeline = ConceptAttentionFluxPipeline(model_name="flux-schnell", device="cuda")
48
 
49
  @spaces.GPU(duration=60)
50
- def process_inputs(prompt, input_image, word_list, seed, num_samples, layer_start_index, timestep_start_index):
51
  print("Processing inputs")
 
 
 
52
  prompt = prompt.strip()
53
  if not word_list.strip():
54
- return None, "Please enter comma-separated words"
55
 
56
  concepts = [w.strip() for w in word_list.split(",")]
57
 
58
- if input_image is not None:
59
- if isinstance(input_image, np.ndarray):
60
- input_image = Image.fromarray(input_image)
61
- input_image = input_image.convert("RGB")
62
- input_image = input_image.resize((1024, 1024))
63
- elif isinstance(input_image, PIL.Image.Image):
64
- input_image = input_image.convert("RGB")
65
- input_image = input_image.resize((1024, 1024))
66
-
67
- pipeline_output = pipeline.encode_image(
68
- image=input_image,
69
- concepts=concepts,
70
- prompt=prompt,
71
- width=1024,
72
- height=1024,
73
- seed=seed,
74
- num_samples=num_samples,
75
- layer_indices=list(range(layer_start_index, 19)),
76
- )
77
-
78
- else:
79
- pipeline_output = pipeline.generate_image(
80
- prompt=prompt,
81
- concepts=concepts,
82
- width=1024,
83
- height=1024,
84
- seed=seed,
85
- timesteps=list(range(timestep_start_index, 4)),
86
- num_inference_steps=4,
87
- layer_indices=list(range(layer_start_index, 19)),
88
- )
89
 
90
  output_image = pipeline_output.image
91
  concept_heatmaps = pipeline_output.concept_heatmaps
 
92
 
93
- html_elements = []
94
- for concept, heatmap in zip(concepts, concept_heatmaps):
95
- img = heatmap.resize((IMG_SIZE, IMG_SIZE), resample=Image.NEAREST)
96
- buffered = io.BytesIO()
97
- img.save(buffered, format="PNG")
98
- img_str = base64.b64encode(buffered.getvalue()).decode()
99
 
100
- html = f"""
101
- <div style='text-align: center; margin: 5px; padding: 5px; overflow-x: auto; white-space: nowrap;'>
102
- <h1 style='margin-bottom: 10px;'>{concept}</h1>
103
- <img src='data:image/png;base64,{img_str}' style='width: {IMG_SIZE}px; display: inline-block; height: {IMG_SIZE}px;'>
104
- </div>
105
- """
106
- html_elements.append(html)
107
 
108
- combined_html = "<div style='display: flex; flex-wrap: wrap; justify-content: center;'>" + "".join(html_elements) + "</div>"
109
- return output_image, combined_html, None # None fills input_image with None
 
 
 
110
 
 
 
 
 
 
 
 
 
 
 
 
 
111
 
112
  with gr.Blocks(
113
  css="""
@@ -115,10 +98,8 @@ with gr.Blocks(
115
  .title { text-align: center; margin-bottom: 10px; }
116
  .authors { text-align: center; margin-bottom: 10px; }
117
  .affiliations { text-align: center; color: #666; margin-bottom: 10px; }
118
- .content { display: grid; grid-template-columns: 1fr 1fr; gap: 20px; }
119
- .section { }
120
- .input-image { width: 100%; height: 200px; }
121
  .abstract { text-align: center; margin-bottom: 40px; }
 
122
  """
123
  ) as demo:
124
  with gr.Column(elem_classes="container"):
@@ -134,41 +115,54 @@ with gr.Blocks(
134
  elem_classes="abstract"
135
  )
136
 
137
- with gr.Row(elem_classes="content"):
138
- with gr.Column(elem_classes="section"):
139
- gr.Markdown("### Input")
140
- prompt = gr.Textbox(label="Enter your prompt")
141
- words = gr.Textbox(label="Enter a list of concepts (comma-separated)")
142
- # gr.HTML("<div style='text-align: center;'> <h3> Or </h3> </div>")
143
- image_input = gr.Image(type="numpy", label="Upload image (optional)", elem_classes="input-image")
144
- # Set up advanced options
145
- with gr.Accordion("Advanced Settings", open=False):
146
- seed = gr.Slider(minimum=0, maximum=10000, step=1, label="Seed", value=42)
147
- num_samples = gr.Slider(minimum=1, maximum=10, step=1, label="Number of Samples", value=4)
148
- layer_start_index = gr.Slider(minimum=0, maximum=18, step=1, label="Layer Start Index", value=10)
149
- timestep_start_index = gr.Slider(minimum=0, maximum=4, step=1, label="Timestep Start Index", value=2)
150
-
151
- with gr.Column(elem_classes="section"):
152
- gr.Markdown("### Output")
153
- output_image = gr.Image(type="numpy", label="Output image")
154
-
155
- with gr.Row():
156
- submit_btn = gr.Button("Process")
157
-
158
- with gr.Row(elem_classes="section"):
159
- saliency_display = gr.HTML(label="Saliency Maps")
 
 
 
 
 
 
 
 
 
 
 
 
 
160
 
161
  submit_btn.click(
162
  fn=process_inputs,
163
- inputs=[prompt, image_input, words, seed, num_samples, layer_start_index, timestep_start_index], outputs=[output_image, saliency_display, image_input]
 
164
  )
165
- # .then(
166
- # fn=lambda component: gr.update(value=None),
167
- # inputs=[image_input],
168
- # outputs=[]
169
- # )
170
 
171
- gr.Examples(examples=EXAMPLES, inputs=[prompt, image_input, words, seed], outputs=[output_image, saliency_display], fn=process_inputs, cache_examples=False)
 
 
 
 
172
 
173
  if __name__ == "__main__":
174
  demo.launch(max_threads=1)
 
1
  import base64
2
  import io
 
3
  import spaces
4
  import gradio as gr
5
  from PIL import Image
 
 
 
6
 
7
  from concept_attention import ConceptAttentionFluxPipeline
8
 
 
 
 
 
 
 
 
 
9
  IMG_SIZE = 250
10
 
 
 
 
11
  EXAMPLES = [
12
  [
13
  "A dog by a tree", # prompt
 
14
  "tree, dog, grass, background", # words
15
  42, # seed
16
  ],
17
  [
18
  "A dragon", # prompt
 
19
  "dragon, sky, rock, cloud", # words
20
  42, # seed
21
  ],
22
+ [
23
  "A hot air balloon", # prompt
 
24
  "balloon, sky, water, tree", # words
25
  42, # seed
26
  ]
 
29
# Build the pipeline once at import time so every GPU request reuses it.
pipeline = ConceptAttentionFluxPipeline(model_name="flux-schnell", device="cuda")


@spaces.GPU(duration=60)
def process_inputs(prompt, word_list, seed, layer_start_index, timestep_start_index):
    """Generate an image for ``prompt`` plus one saliency heatmap per concept.

    Args:
        prompt: Text prompt for image generation.
        word_list: Comma-separated concept words to attribute attention to.
        seed: RNG seed for reproducible generation.
        layer_start_index: First layer index aggregated into the heatmaps
            (layers ``layer_start_index`` .. 18 are used).
        timestep_start_index: First timestep aggregated into the heatmaps
            (timesteps ``timestep_start_index`` .. 3 are used).

    Returns:
        A list of ``(PIL.Image, caption)`` pairs for a ``gr.Gallery``: the
        generated image first, then one resized heatmap per concept.

    Raises:
        gr.exceptions.InputError: If the concept list is empty or contains
            more than 9 entries.
    """
    print("Processing inputs")
    # Slider components always supply values; these guard against a wiring bug.
    assert layer_start_index is not None
    assert timestep_start_index is not None

    prompt = prompt.strip()
    if not word_list.strip():
        # BUG FIX: the exception was previously constructed but never raised,
        # so empty input fell through to the pipeline as a single "" concept.
        raise gr.exceptions.InputError("words", "Please enter comma-separated words")

    # Drop blank entries so trailing/double commas don't yield empty concepts
    # (also makes the emptiness check below actually reachable).
    concepts = [w.strip() for w in word_list.split(",") if w.strip()]

    if not concepts:
        raise gr.exceptions.InputError("words", "Please enter at least 1 concept")

    if len(concepts) > 9:
        raise gr.exceptions.InputError("words", "Please enter at most 9 concepts")

    pipeline_output = pipeline.generate_image(
        prompt=prompt,
        concepts=concepts,
        width=1024,
        height=1024,
        seed=seed,
        timesteps=list(range(timestep_start_index, 4)),
        num_inference_steps=4,
        layer_indices=list(range(layer_start_index, 19)),
        # Normalizing across concepts only makes sense with more than one.
        softmax=len(concepts) > 1,
    )

    output_image = pipeline_output.image
    # Downscale heatmaps for display; NEAREST keeps the attention grid crisp.
    concept_heatmaps = [
        heatmap.resize((IMG_SIZE, IMG_SIZE), resample=Image.NEAREST)
        for heatmap in pipeline_output.concept_heatmaps
    ]

    # Gallery format: (image, caption) pairs, generated image first.
    return [(output_image, "Generated Image")] + list(zip(concept_heatmaps, concepts))
94
 
95
  with gr.Blocks(
96
  css="""
 
98
  .title { text-align: center; margin-bottom: 10px; }
99
  .authors { text-align: center; margin-bottom: 10px; }
100
  .affiliations { text-align: center; color: #666; margin-bottom: 10px; }
 
 
 
101
  .abstract { text-align: center; margin-bottom: 40px; }
102
+ .input-row { height: 60px; }
103
  """
104
  ) as demo:
105
  with gr.Column(elem_classes="container"):
 
115
  elem_classes="abstract"
116
  )
117
 
118
+ with gr.Row(equal_height=True, elem_classes="input-row"):
119
+ prompt = gr.Textbox(
120
+ label="Enter your prompt",
121
+ placeholder="Enter your prompt",
122
+ value=EXAMPLES[0][0],
123
+ scale=4,
124
+ # show_label=False
125
+ )
126
+ words = gr.Textbox(
127
+ label="Enter a list of concepts (comma-separated)",
128
+ placeholder="Enter a list of concepts (comma-separated)",
129
+ value=EXAMPLES[0][1],
130
+ scale=4,
131
+ # show_label=False
132
+ )
133
+ submit_btn = gr.Button(
134
+ "Run",
135
+ min_width="100px",
136
+ scale=1
137
+ )
138
+
139
+ # generated_image = gr.Image(label="Generated Image", elem_classes="input-image")
140
+ gallery = gr.Gallery(
141
+ label="Generated images",
142
+ show_label=True,
143
+ elem_id="gallery",
144
+ columns=[5],
145
+ # rows=[1],
146
+ object_fit="contain",
147
+ # height="auto"
148
+ )
149
+ with gr.Accordion("Advanced Settings", open=False):
150
+ seed = gr.Slider(minimum=0, maximum=10000, step=1, label="Seed", value=42)
151
+ layer_start_index = gr.Slider(minimum=0, maximum=18, step=1, label="Layer Start Index", value=10)
152
+ timestep_start_index = gr.Slider(minimum=0, maximum=4, step=1, label="Timestep Start Index", value=2)
153
+
154
 
155
  submit_btn.click(
156
  fn=process_inputs,
157
+ inputs=[prompt, words, seed, layer_start_index, timestep_start_index],
158
+ outputs=[gallery]
159
  )
 
 
 
 
 
160
 
161
+ gr.Examples(examples=EXAMPLES, inputs=[prompt, words, seed, layer_start_index, timestep_start_index], outputs=[gallery], fn=process_inputs, cache_examples=True)
162
+
163
+ # Automatically process the first example on launch
164
+ demo.load(process_inputs, inputs=[prompt, words, seed, layer_start_index, timestep_start_index], outputs=[gallery])
165
+
166
 
167
if __name__ == "__main__":
    # max_threads=1 — presumably to keep requests serialized on the single
    # shared GPU; confirm before raising it.
    demo.launch(max_threads=1)