helblazer811 committed on
Commit
54f7994
·
1 Parent(s): dc35ac5

Fixes to error checking

Browse files
Files changed (1) hide show
  1. app.py +47 -43
app.py CHANGED
@@ -21,7 +21,7 @@ def update_default_concepts(prompt):
21
 
22
  return gr.update(value=default_concepts.get(prompt, []))
23
 
24
- pipeline = ConceptAttentionFluxPipeline(model_name="flux-schnell") # , device="cuda:2", offload_model=True)
25
 
26
  def convert_pil_to_bytes(img):
27
  img = img.resize((IMG_SIZE, IMG_SIZE), resample=Image.NEAREST)
@@ -33,45 +33,49 @@ def convert_pil_to_bytes(img):
33
 
34
  @spaces.GPU(duration=60)
35
  def process_inputs(prompt, concepts, seed, layer_start_index, timestep_start_index):
36
- if not prompt:
37
- raise gr.exceptions.InputError("prompt", "Please enter a prompt")
38
-
39
- if not prompt.strip():
40
- raise gr.exceptions.InputError("prompt", "Please enter a prompt")
41
-
42
- prompt = prompt.strip()
43
-
44
- if len(concepts) == 0:
45
- raise gr.exceptions.InputError("words", "Please enter at least 1 concept")
46
-
47
- if len(concepts) > 9:
48
- raise gr.exceptions.InputError("words", "Please enter at most 9 concepts")
49
-
50
- pipeline_output = pipeline.generate_image(
51
- prompt=prompt,
52
- concepts=concepts,
53
- width=1024,
54
- height=1024,
55
- seed=seed,
56
- timesteps=list(range(timestep_start_index, 4)),
57
- num_inference_steps=4,
58
- layer_indices=list(range(layer_start_index, 19)),
59
- softmax=True if len(concepts) > 1 else False
60
- )
61
-
62
- output_image = pipeline_output.image
63
-
64
- output_space_heatmaps = pipeline_output.concept_heatmaps
65
- output_space_heatmaps = [heatmap.resize((IMG_SIZE, IMG_SIZE), resample=Image.NEAREST) for heatmap in output_space_heatmaps]
66
- output_space_maps_and_labels = [(output_space_heatmaps[concept_index], concepts[concept_index]) for concept_index in range(len(concepts))]
67
-
68
- cross_attention_heatmaps = pipeline_output.cross_attention_maps
69
- cross_attention_heatmaps = [heatmap.resize((IMG_SIZE, IMG_SIZE), resample=Image.NEAREST) for heatmap in cross_attention_heatmaps]
70
- cross_attention_maps_and_labels = [(cross_attention_heatmaps[concept_index], concepts[concept_index]) for concept_index in range(len(concepts))]
71
-
72
- return output_image, \
73
- gr.update(value=output_space_maps_and_labels, columns=len(output_space_maps_and_labels)), \
74
- gr.update(value=cross_attention_maps_and_labels, columns=len(cross_attention_maps_and_labels))
 
 
 
 
75
 
76
  with gr.Blocks(
77
  css="""
@@ -100,7 +104,7 @@ with gr.Blocks(
100
  }
101
  .input-column-label {}
102
  .gallery {
103
- height: 200px;
104
  }
105
  .run-button-column {
106
  width: 100px !important;
@@ -276,7 +280,7 @@ with gr.Blocks(
276
  # columns=3,
277
  rows=1,
278
  object_fit="contain",
279
- # height="200px",
280
  elem_classes="gallery",
281
  elem_id="concept-attention-gallery",
282
  # scale=4
@@ -288,7 +292,7 @@ with gr.Blocks(
288
  # columns=3,
289
  rows=1,
290
  object_fit="contain",
291
- # height="200px",
292
  elem_classes="gallery",
293
  # scale=4
294
  )
 
21
 
22
  return gr.update(value=default_concepts.get(prompt, []))
23
 
24
+ pipeline = ConceptAttentionFluxPipeline(model_name="flux-schnell") # , offload_model=True) # , device="cuda:2", offload_model=True)
25
 
26
  def convert_pil_to_bytes(img):
27
  img = img.resize((IMG_SIZE, IMG_SIZE), resample=Image.NEAREST)
 
33
 
34
  @spaces.GPU(duration=60)
35
  def process_inputs(prompt, concepts, seed, layer_start_index, timestep_start_index):
36
+ try:
37
+ if not prompt:
38
+ raise gr.Error("Please enter a prompt", duration=10)
39
+
40
+ if not prompt.strip():
41
+ raise gr.Error("Please enter a prompt", duration=10)
42
+
43
+ prompt = prompt.strip()
44
+
45
+ if len(concepts) == 0:
46
+ raise gr.Error("Please enter at least 1 concept", duration=10)
47
+
48
+ if len(concepts) > 9:
49
+ raise gr.Error("Please enter at most 9 concepts", duration=10)
50
+
51
+ pipeline_output = pipeline.generate_image(
52
+ prompt=prompt,
53
+ concepts=concepts,
54
+ width=1024,
55
+ height=1024,
56
+ seed=seed,
57
+ timesteps=list(range(timestep_start_index, 4)),
58
+ num_inference_steps=4,
59
+ layer_indices=list(range(layer_start_index, 19)),
60
+ softmax=True if len(concepts) > 1 else False
61
+ )
62
+
63
+ output_image = pipeline_output.image
64
+
65
+ output_space_heatmaps = pipeline_output.concept_heatmaps
66
+ output_space_heatmaps = [heatmap.resize((IMG_SIZE, IMG_SIZE), resample=Image.NEAREST) for heatmap in output_space_heatmaps]
67
+ output_space_maps_and_labels = [(output_space_heatmaps[concept_index], concepts[concept_index]) for concept_index in range(len(concepts))]
68
+
69
+ cross_attention_heatmaps = pipeline_output.cross_attention_maps
70
+ cross_attention_heatmaps = [heatmap.resize((IMG_SIZE, IMG_SIZE), resample=Image.NEAREST) for heatmap in cross_attention_heatmaps]
71
+ cross_attention_maps_and_labels = [(cross_attention_heatmaps[concept_index], concepts[concept_index]) for concept_index in range(len(concepts))]
72
+
73
+ return output_image, \
74
+ gr.update(value=output_space_maps_and_labels, columns=len(output_space_maps_and_labels)), \
75
+ gr.update(value=cross_attention_maps_and_labels, columns=len(cross_attention_maps_and_labels))
76
+
77
+ except gr.Error as e:
78
+ return None, gr.update(value=[], columns=1), gr.update(value=[], columns=1)
79
 
80
  with gr.Blocks(
81
  css="""
 
104
  }
105
  .input-column-label {}
106
  .gallery {
107
+ height: 220px;
108
  }
109
  .run-button-column {
110
  width: 100px !important;
 
280
  # columns=3,
281
  rows=1,
282
  object_fit="contain",
283
+ height="200px",
284
  elem_classes="gallery",
285
  elem_id="concept-attention-gallery",
286
  # scale=4
 
292
  # columns=3,
293
  rows=1,
294
  object_fit="contain",
295
+ height="200px",
296
  elem_classes="gallery",
297
  # scale=4
298
  )