helblazer811 commited on
Commit
ade193e
·
1 Parent(s): 0f8eff7

Dynamically update gallery num rows.

Browse files
Files changed (1) hide show
  1. app.py +14 -29
app.py CHANGED
@@ -1,12 +1,12 @@
1
- import base64
2
- import io
3
  import spaces
4
  import gradio as gr
5
  from PIL import Image
 
6
 
7
  from concept_attention import ConceptAttentionFluxPipeline
8
 
9
  IMG_SIZE = 250
 
10
 
11
  EXAMPLES = [
12
  [
@@ -65,32 +65,17 @@ def process_inputs(prompt, word_list, seed, layer_start_index, timestep_start_in
65
  heatmaps_and_labels = [(concept_heatmaps[concept_index], concepts[concept_index]) for concept_index in range(len(concepts))]
66
  all_images_and_labels = [(output_image, "Generated Image")] + heatmaps_and_labels
67
 
68
- # combined_html = "<div style='display: flex; flex-wrap: wrap; justify-content: center;'>"
69
- # # Show the output image
70
- # combined_html += f"""
71
- # <div style='text-align: center; margin: 5px; padding: 5px;'>
72
- # <img src='data:image/png;base64,{output_image}' style='width: {IMG_SIZE}px; display: inline-block; height: {IMG_SIZE}px;'>
73
- # </div>
74
- # """
75
 
76
- # for concept, heatmap in zip(concepts, concept_heatmaps):
77
- # img = heatmap.resize((IMG_SIZE, IMG_SIZE), resample=Image.NEAREST)
78
- # buffered = io.BytesIO()
79
- # img.save(buffered, format="PNG")
80
- # img_str = base64.b64encode(buffered.getvalue()).decode()
81
-
82
- # html = f"""
83
- # <div style='text-align: center; margin: 5px; padding: 5px; overflow-x: auto; white-space: nowrap;'>
84
- # <h1 style='margin-bottom: 10px;'>{concept}</h1>
85
- # <img src='data:image/png;base64,{img_str}' style='width: {IMG_SIZE}px; display: inline-block; height: {IMG_SIZE}px;'>
86
- # </div>
87
- # """
88
-
89
- # combined_html += html
90
-
91
- # combined_html += "</div>"
92
 
93
- return all_images_and_labels
94
 
95
  with gr.Blocks(
96
  css="""
@@ -143,9 +128,9 @@ with gr.Blocks(
143
  gallery = gr.Gallery(
144
  label="Generated images",
145
  show_label=True,
146
- elem_id="gallery",
147
- columns=[5],
148
- # rows=[1],
149
  object_fit="contain",
150
  # height="auto"
151
  )
 
 
 
1
  import spaces
2
  import gradio as gr
3
  from PIL import Image
4
+ import math
5
 
6
  from concept_attention import ConceptAttentionFluxPipeline
7
 
8
  IMG_SIZE = 250
9
+ COLUMNS = 5
10
 
11
  EXAMPLES = [
12
  [
 
65
  heatmaps_and_labels = [(concept_heatmaps[concept_index], concepts[concept_index]) for concept_index in range(len(concepts))]
66
  all_images_and_labels = [(output_image, "Generated Image")] + heatmaps_and_labels
67
 
68
+ num_rows = math.ceil(len(all_images_and_labels) / COLUMNS)
 
 
 
 
 
 
69
 
70
+ gallery = gr.Gallery(
71
+ label="Generated images",
72
+ show_label=True,
73
+ columns=[COLUMNS],
74
+ rows=[num_rows],
75
+ object_fit="contain"
76
+ )
 
 
 
 
 
 
 
 
 
77
 
78
+ return gallery
79
 
80
  with gr.Blocks(
81
  css="""
 
128
  gallery = gr.Gallery(
129
  label="Generated images",
130
  show_label=True,
131
+ # elem_id="gallery",
132
+ columns=[COLUMNS],
133
+ rows=[1],
134
  object_fit="contain",
135
  # height="auto"
136
  )