Pedro Cuenca commited on
Commit
a5a7c8e
·
1 Parent(s): eeca189

Actually perform predictions

Browse files
Files changed (1) hide show
  1. app.py +35 -20
app.py CHANGED
@@ -1,14 +1,24 @@
1
  #!/usr/bin/env python
2
  # coding: utf-8
3
- import random
4
- from typing import List
5
  import gradio as gr
 
 
6
  from collections import defaultdict
 
7
  from functools import partial
 
 
8
  from PIL import Image
9
 
10
  SELECT_LABEL = "Select as seed"
11
 
 
 
 
 
 
 
 
12
  with gr.Blocks(css=".container { max-width: 800px; margin: auto; }") as demo:
13
  state = gr.Variable({
14
  'selected': -1,
@@ -17,15 +27,21 @@ with gr.Blocks(css=".container { max-width: 800px; margin: auto; }") as demo:
17
 
18
  def infer_seeded_image(prompt, seed):
19
  print(f"Prompt: {prompt}, seed: {seed}")
20
- return Image.open(f"sample_outputs/seeded_1.png")
21
-
22
- def infer_grid(prompt):
23
- response = defaultdict(list)
24
- for i in range(1, 7):
25
- response["images"].append(Image.open(f"sample_outputs/{i}.png"))
26
- response["seeds"].append(random.randint(0, 2 ** 32 -1))
27
-
28
- return response["images"], response["seeds"]
 
 
 
 
 
 
29
 
30
  def infer(prompt, state):
31
  """
@@ -51,15 +67,6 @@ with gr.Blocks(css=".container { max-width: 800px; margin: auto; }") as demo:
51
  boxes = [gr.Box.update(visible=v) for v in visible]
52
  return grid_images + [image_with_seed] + boxes + [state]
53
 
54
-
55
- def image_block():
56
- return gr.Image(
57
- interactive=False, show_label=False
58
- ).style(
59
- # border = (True, True, False, True),
60
- rounded = (True, True, False, False),
61
- )
62
-
63
  def update_state(selected_index: int, value, state):
64
  if value == '':
65
  others_value = None
@@ -74,6 +81,14 @@ with gr.Blocks(css=".container { max-width: 800px; margin: auto; }") as demo:
74
  state["selected"] = -1
75
  return [''] * 6 + [gr.Box.update(visible=True), gr.Box.update(visible=False)] + [state]
76
 
 
 
 
 
 
 
 
 
77
  def radio_block():
78
  radio = gr.Radio(
79
  choices=[SELECT_LABEL], interactive=True, show_label=False,
 
1
  #!/usr/bin/env python
2
  # coding: utf-8
 
 
3
  import gradio as gr
4
+ import random
5
+ import torch
6
  from collections import defaultdict
7
+ from diffusers import DiffusionPipeline
8
  from functools import partial
9
+ from itertools import zip_longest
10
+ from typing import List
11
  from PIL import Image
12
 
13
  SELECT_LABEL = "Select as seed"
14
 
15
+ MODEL_ID = "CompVis/ldm-text2im-large-256"
16
+ STEPS = 50
17
+ ETA = 0.3
18
+ GUIDANCE_SCALE = 12
19
+
20
+ ldm = DiffusionPipeline.from_pretrained(MODEL_ID)
21
+
22
  with gr.Blocks(css=".container { max-width: 800px; margin: auto; }") as demo:
23
  state = gr.Variable({
24
  'selected': -1,
 
27
 
28
  def infer_seeded_image(prompt, seed):
29
  print(f"Prompt: {prompt}, seed: {seed}")
30
+ images, _ = infer_grid(prompt, n=1, seeds=[seed])
31
+ return images[0]
32
+
33
+ def infer_grid(prompt, n=6, seeds=[]):
34
+ # Unfortunately we have to iterate instead requesting all images at once,
35
+ # because we have no way to get the generation seeds.
36
+ result = defaultdict(list)
37
+ for _, seed in zip_longest(range(n), seeds, fillvalue=None):
38
+ seed = random.randint(0, 2**32 - 1) if seed is None else seed
39
+ print(f"Setting seed {seed}")
40
+ _ = torch.manual_seed(seed)
41
+ images = ldm([prompt], num_inference_steps=50, eta=0.3, guidance_scale=6)["sample"]
42
+ result["images"].append(images[0])
43
+ result["seeds"].append(seed)
44
+ return result["images"], result["seeds"]
45
 
46
  def infer(prompt, state):
47
  """
 
67
  boxes = [gr.Box.update(visible=v) for v in visible]
68
  return grid_images + [image_with_seed] + boxes + [state]
69
 
 
 
 
 
 
 
 
 
 
70
  def update_state(selected_index: int, value, state):
71
  if value == '':
72
  others_value = None
 
81
  state["selected"] = -1
82
  return [''] * 6 + [gr.Box.update(visible=True), gr.Box.update(visible=False)] + [state]
83
 
84
+ def image_block():
85
+ return gr.Image(
86
+ interactive=False, show_label=False
87
+ ).style(
88
+ # border = (True, True, False, True),
89
+ rounded = (True, True, False, False),
90
+ )
91
+
92
  def radio_block():
93
  radio = gr.Radio(
94
  choices=[SELECT_LABEL], interactive=True, show_label=False,