QHL067 committed on
Commit
c47371b
·
1 Parent(s): 585b792
Files changed (1) hide show
  1. app.py +35 -23
app.py CHANGED
@@ -150,6 +150,7 @@ def infer(
150
  guidance_scale,
151
  num_inference_steps,
152
  num_of_interpolation,
 
153
  save_gpu_memory=True,
154
  progress=gr.Progress(track_tqdm=True),
155
  ):
@@ -164,7 +165,10 @@ def infer(
164
  prompt_dict = {"prompt_1": prompt1, "prompt_2": prompt2}
165
  for key, value in prompt_dict.items():
166
  assert value is not None, f"{key} must not be None."
167
- assert num_of_interpolation >= 5, "For linear interpolation, please sample at least five images."
 
 
 
168
 
169
  # Get text embeddings and tokens.
170
  _context, _token_mask, _token, _caption = get_caption(
@@ -181,10 +185,10 @@ def infer(
181
  # Prepare the initial latent representations based on the number of interpolations.
182
  if num_of_interpolation == 3:
183
  # Addition or subtraction mode.
184
- if config.prompt_a is not None:
185
  assert config.prompt_s is None, "Only one of prompt_a or prompt_s should be provided."
186
  z_init_temp = _z_init[0] + _z_init[1]
187
- elif config.prompt_s is not None:
188
  assert config.prompt_a is None, "Only one of prompt_a or prompt_s should be provided."
189
  z_init_temp = _z_init[0] - _z_init[1]
190
  else:
@@ -194,10 +198,7 @@ def infer(
194
  _z_init[2] = (z_init_temp - mean) / std
195
 
196
  elif num_of_interpolation == 4:
197
- z_init_temp = _z_init[0] + _z_init[1] - _z_init[2]
198
- mean = z_init_temp.mean()
199
- std = z_init_temp.std()
200
- _z_init[3] = (z_init_temp - mean) / std
201
 
202
  elif num_of_interpolation >= 5:
203
  tensor_a = _z_init[0]
@@ -244,21 +245,25 @@ def infer(
244
  to_pil = ToPILImage()
245
  pil_images = [to_pil(img) for img in samples]
246
 
247
- first_image = pil_images[0]
248
- last_image = pil_images[-1]
 
 
 
 
249
 
250
- gif_buffer = io.BytesIO()
251
- pil_images[0].save(gif_buffer, format="GIF", save_all=True, append_images=pil_images[1:], duration=200, loop=0)
252
- gif_buffer.seek(0)
253
- gif_bytes = gif_buffer.read()
254
 
255
- # Save the GIF bytes to a temporary file and get its path
256
- temp_gif = tempfile.NamedTemporaryFile(delete=False, suffix=".gif")
257
- temp_gif.write(gif_bytes)
258
- temp_gif.close()
259
- gif_path = temp_gif.name
260
 
261
- return first_image, last_image, gif_path, seed
262
  # return first_image, last_image, seed
263
 
264
 
@@ -269,7 +274,7 @@ def infer(
269
  # ]
270
 
271
  def infer_tab1(prompt1, prompt2, seed, randomize_seed, guidance_scale, num_inference_steps, num_of_interpolation):
272
- default_op = "Addition"
273
  return infer(prompt1, prompt2, seed, randomize_seed, guidance_scale, num_inference_steps, num_of_interpolation, default_op)
274
 
275
  # Wrapper for Tab 2: Uses operation_mode and fixes num_of_interpolation to 3.
@@ -281,6 +286,10 @@ examples_1 = [
281
  ["A robot cooking dinner in the kitchen", "An orange cat wearing sunglasses on a ship"],
282
  ]
283
 
 
 
 
 
284
  css = """
285
  #col-container {
286
  margin: 0 auto;
@@ -464,9 +473,10 @@ with gr.Blocks(css=css) as demo:
464
  prompt2_tab1 = gr.Text(placeholder="Prompt for second image", label="Prompt 2")
465
  seed_tab1 = gr.Slider(minimum=0, maximum=MAX_SEED, step=1, value=0, label="Seed")
466
  randomize_seed_tab1 = gr.Checkbox(label="Randomize seed", value=True)
467
- guidance_scale_tab1 = gr.Slider(minimum=0.0, maximum=10.0, step=0.1, value=7.0, label="Guidance Scale")
468
- num_inference_steps_tab1 = gr.Slider(minimum=1, maximum=50, step=1, value=25, label="Number of Inference Steps")
469
- num_of_interpolation_tab1 = gr.Slider(minimum=5, maximum=50, step=1, value=10, label="Number of Images for Interpolation")
 
470
  run_button_tab1 = gr.Button("Run")
471
 
472
  first_image_output_tab1 = gr.Image(label="Image of the first prompt")
@@ -520,6 +530,8 @@ with gr.Blocks(css=css) as demo:
520
  outputs=[first_image_output_tab2, last_image_output_tab2, gif_output_tab2, seed_tab2]
521
  )
522
 
 
 
523
 
524
  if __name__ == "__main__":
525
  demo.launch()
 
150
  guidance_scale,
151
  num_inference_steps,
152
  num_of_interpolation,
153
+ operation_mode,
154
  save_gpu_memory=True,
155
  progress=gr.Progress(track_tqdm=True),
156
  ):
 
165
  prompt_dict = {"prompt_1": prompt1, "prompt_2": prompt2}
166
  for key, value in prompt_dict.items():
167
  assert value is not None, f"{key} must not be None."
168
+ if operation_mode != 'Interpolation':
169
+ assert num_of_interpolation >= 5, "For linear interpolation, please sample at least five images."
170
+ else:
171
+ assert num_of_interpolation == 3, "For arithmetic, please sample three images."
172
 
173
  # Get text embeddings and tokens.
174
  _context, _token_mask, _token, _caption = get_caption(
 
185
  # Prepare the initial latent representations based on the number of interpolations.
186
  if num_of_interpolation == 3:
187
  # Addition or subtraction mode.
188
+ if operation_mode == 'Addition':
189
  assert config.prompt_s is None, "Only one of prompt_a or prompt_s should be provided."
190
  z_init_temp = _z_init[0] + _z_init[1]
191
+ elif operation_mode == 'Subtraction':
192
  assert config.prompt_a is None, "Only one of prompt_a or prompt_s should be provided."
193
  z_init_temp = _z_init[0] - _z_init[1]
194
  else:
 
198
  _z_init[2] = (z_init_temp - mean) / std
199
 
200
  elif num_of_interpolation == 4:
201
+ raise ValueError("Unsupported number of interpolations.")
 
 
 
202
 
203
  elif num_of_interpolation >= 5:
204
  tensor_a = _z_init[0]
 
245
  to_pil = ToPILImage()
246
  pil_images = [to_pil(img) for img in samples]
247
 
248
+ if num_of_interpolation == 3:
249
+ return pil_images[0], pil_images[1], pil_images[2], seed
250
+
251
+ else:
252
+ first_image = pil_images[0]
253
+ last_image = pil_images[-1]
254
 
255
+ gif_buffer = io.BytesIO()
256
+ pil_images[0].save(gif_buffer, format="GIF", save_all=True, append_images=pil_images[1:], duration=200, loop=0)
257
+ gif_buffer.seek(0)
258
+ gif_bytes = gif_buffer.read()
259
 
260
+ # Save the GIF bytes to a temporary file and get its path
261
+ temp_gif = tempfile.NamedTemporaryFile(delete=False, suffix=".gif")
262
+ temp_gif.write(gif_bytes)
263
+ temp_gif.close()
264
+ gif_path = temp_gif.name
265
 
266
+ return first_image, last_image, gif_path, seed
267
  # return first_image, last_image, seed
268
 
269
 
 
274
  # ]
275
 
276
  def infer_tab1(prompt1, prompt2, seed, randomize_seed, guidance_scale, num_inference_steps, num_of_interpolation):
277
+ default_op = "Interpolation"
278
  return infer(prompt1, prompt2, seed, randomize_seed, guidance_scale, num_inference_steps, num_of_interpolation, default_op)
279
 
280
  # Wrapper for Tab 2: Uses operation_mode and fixes num_of_interpolation to 3.
 
286
  ["A robot cooking dinner in the kitchen", "An orange cat wearing sunglasses on a ship"],
287
  ]
288
 
289
+ examples_2 = [
290
+ ["A corgi in the park", "red hat"],
291
+ ]
292
+
293
  css = """
294
  #col-container {
295
  margin: 0 auto;
 
473
  prompt2_tab1 = gr.Text(placeholder="Prompt for second image", label="Prompt 2")
474
  seed_tab1 = gr.Slider(minimum=0, maximum=MAX_SEED, step=1, value=0, label="Seed")
475
  randomize_seed_tab1 = gr.Checkbox(label="Randomize seed", value=True)
476
+ with gr.Accordion("Advanced Settings", open=False):
477
+ guidance_scale_tab1 = gr.Slider(minimum=0.0, maximum=10.0, step=0.1, value=7.0, label="Guidance Scale")
478
+ num_inference_steps_tab1 = gr.Slider(minimum=1, maximum=50, step=1, value=25, label="Number of Inference Steps")
479
+ num_of_interpolation_tab1 = gr.Slider(minimum=5, maximum=50, step=1, value=10, label="Number of Images for Interpolation")
480
  run_button_tab1 = gr.Button("Run")
481
 
482
  first_image_output_tab1 = gr.Image(label="Image of the first prompt")
 
530
  outputs=[first_image_output_tab2, last_image_output_tab2, gif_output_tab2, seed_tab2]
531
  )
532
 
533
+ gr.Examples(examples=examples_2, inputs=[prompt1_tab2, prompt2_tab2])
534
+
535
 
536
  if __name__ == "__main__":
537
  demo.launch()