pawlo2013 commited on
Commit
d51b792
Β·
1 Parent(s): f3fb43b

added progress bar

Browse files
Files changed (2) hide show
  1. app.py +23 -50
  2. load_model.py +7 -4
app.py CHANGED
@@ -10,8 +10,6 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
10
  device = "mps" if torch.backends.mps.is_available() else device
11
 
12
  image_size = 128
13
- upscale = False
14
- clicked = False
15
 
16
 
17
  transform = transforms.Compose(
@@ -23,49 +21,26 @@ transform = transforms.Compose(
23
  )
24
 
25
 
26
- def make_scribbles(sketch, scribbles):
27
- # get the value that occurs most often in the scribbles
28
- sketch = transforms.Resize((image_size, image_size))(sketch)
29
- scribbles = transforms.Resize((image_size, image_size))(scribbles)
30
-
31
- grey_tensor = torch.tensor(0.49803922, device=device)
32
-
33
- grey_tensor = grey_tensor.expand(3, image_size, image_size)
34
-
35
- sketch = transforms.ToTensor()(sketch).to(device)
36
- scribbles = transforms.ToTensor()(scribbles).to(device)
37
-
38
- scribble_where_grey_mask = torch.eq(scribbles, grey_tensor)
39
-
40
- merged = torch.where(scribble_where_grey_mask, sketch, scribbles)
41
-
42
- return transforms.Lambda(lambda t: (t * 2) - 1)(sketch), transforms.Lambda(
43
- lambda t: (t * 2) - 1
44
- )(merged)
45
-
46
-
47
- def process_images(sketch, scribbles, sampling_steps, is_scribbles, seed_nr, upscale):
48
- global clicked
49
- clicked = True
50
  w, h = sketch.size
51
 
52
- if is_scribbles:
53
- sketch, scribbles = make_scribbles(sketch, scribbles)
54
-
55
- else:
56
- sketch = transform(sketch.convert("RGB"))
57
- scribbles = transform(scribbles.convert("RGB"))
58
 
59
  if upscale:
60
- output = transforms.Resize((h, w))(
61
- sample(sketch, scribbles, sampling_steps, seed_nr)
62
  )
63
- clicked = False
64
- return output
65
  else:
66
- output = sample(sketch, scribbles, sampling_steps, seed_nr)
67
- clicked = False
68
- return output
69
 
70
 
71
  theme = gr.themes.Monochrome()
@@ -87,7 +62,7 @@ with gr.Blocks(theme=theme) as demo:
87
  "By default the scribbles are assumed to be merged with the sketch, if they appear on a grey background check the box below. "
88
  "</p>"
89
  )
90
- is_scribbles = gr.Checkbox(label="Is Scribbles", value=False)
91
  with gr.Column():
92
  output = gr.Image(type="pil", label="Output")
93
  upscale_info = gr.Markdown(
@@ -96,14 +71,12 @@ with gr.Blocks(theme=theme) as demo:
96
  "</p>"
97
  )
98
  upscale_button = gr.Checkbox(label="Stretch", value=False)
 
99
  with gr.Row():
100
  with gr.Column():
101
  seed_slider = gr.Number(
102
- label="Random Seed 🎲",
103
- value=random.randint(
104
- 1,
105
- 1000,
106
- ),
107
  )
108
 
109
  with gr.Column():
@@ -111,12 +84,12 @@ with gr.Blocks(theme=theme) as demo:
111
  minimum=1,
112
  maximum=250,
113
  step=1,
114
- label="DDPM Sampling Steps πŸ”„",
115
  value=50,
116
  )
117
 
118
  with gr.Row():
119
- generate_button = gr.Button(value="Generate", interactive=not clicked)
120
  with gr.Row():
121
  generate_info = gr.Markdown(
122
  "<p style='text-align: center; font-size: 16px;'>"
@@ -130,13 +103,13 @@ with gr.Blocks(theme=theme) as demo:
130
  sketch_input,
131
  scribbles_input,
132
  sampling_slider,
133
- is_scribbles,
134
  seed_slider,
135
  upscale_button,
136
  ],
137
  outputs=output,
138
- show_progress=True,
 
139
  )
140
 
141
  if __name__ == "__main__":
142
- demo.launch(max_threads=1)
 
10
  device = "mps" if torch.backends.mps.is_available() else device
11
 
12
  image_size = 128
 
 
13
 
14
 
15
  transform = transforms.Compose(
 
21
  )
22
 
23
 
24
+ def process_images(
25
+ sketch,
26
+ scribbles,
27
+ sampling_steps,
28
+ seed_nr,
29
+ upscale,
30
+ progress=gr.Progress(),
31
+ ):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  w, h = sketch.size
33
 
34
+ sketch = transform(sketch.convert("RGB"))
35
+ scribbles = transform(scribbles.convert("RGB"))
 
 
 
 
36
 
37
  if upscale:
38
+ return transforms.Resize((h, w))(
39
+ sample(sketch, scribbles, sampling_steps, seed_nr, progress)
40
  )
41
+
 
42
  else:
43
+ return sample(sketch, scribbles, sampling_steps, seed_nr, progress)
 
 
44
 
45
 
46
  theme = gr.themes.Monochrome()
 
62
  "By default the scribbles are assumed to be merged with the sketch, if they appear on a grey background check the box below. "
63
  "</p>"
64
  )
65
+
66
  with gr.Column():
67
  output = gr.Image(type="pil", label="Output")
68
  upscale_info = gr.Markdown(
 
71
  "</p>"
72
  )
73
  upscale_button = gr.Checkbox(label="Stretch", value=False)
74
+
75
  with gr.Row():
76
  with gr.Column():
77
  seed_slider = gr.Number(
78
+ label="Random Seed 🎲 (if the image generated is not to your liking, simply use another seed)",
79
+ value=random.randint(0, 10000),
 
 
 
80
  )
81
 
82
  with gr.Column():
 
84
  minimum=1,
85
  maximum=250,
86
  step=1,
87
+ label="DDPM Sampling Steps πŸ”„ (the higher the number of steps the higher the quality of the images)",
88
  value=50,
89
  )
90
 
91
  with gr.Row():
92
+ generate_button = gr.Button(value="Generate")
93
  with gr.Row():
94
  generate_info = gr.Markdown(
95
  "<p style='text-align: center; font-size: 16px;'>"
 
103
  sketch_input,
104
  scribbles_input,
105
  sampling_slider,
 
106
  seed_slider,
107
  upscale_button,
108
  ],
109
  outputs=output,
110
+ concurrency_limit=1,
111
+ trigger_mode="once",
112
  )
113
 
114
  if __name__ == "__main__":
115
+ demo.queue().launch(max_threads=1)
load_model.py CHANGED
@@ -8,6 +8,7 @@ from torchvision import transforms
8
  import pathlib
9
  from torchvision.utils import save_image
10
  from safetensors.torch import load_model, save_model
 
11
 
12
 
13
  denoising_timesteps = 4000
@@ -61,7 +62,7 @@ else:
61
  raise Exception("No model files found in the folder.")
62
 
63
 
64
- def sample(sketch, scribbles, sampling_steps, seed_nr):
65
  torch.manual_seed(seed_nr)
66
 
67
  noise_scheduler = DDPMScheduler(
@@ -80,9 +81,9 @@ def sample(sketch, scribbles, sampling_steps, seed_nr):
80
 
81
  noise_for_plain = torch.randn_like(sketch, device=device)
82
 
83
- for i, t in tqdm(
84
- enumerate(noise_scheduler.timesteps),
85
- total=len(noise_scheduler.timesteps),
86
  ):
87
  noise_for_plain = noise_scheduler.scale_model_input(noise_for_plain, t).to(
88
  device
@@ -105,6 +106,8 @@ def sample(sketch, scribbles, sampling_steps, seed_nr):
105
  noise_for_plain,
106
  ).prev_sample
107
 
 
 
108
  sample = torch.clamp((noise_for_plain / 2) + 0.5, 0, 1)
109
 
110
  return transforms.ToPILImage()(sample[0].cpu())
 
8
  import pathlib
9
  from torchvision.utils import save_image
10
  from safetensors.torch import load_model, save_model
11
+ import time as tm
12
 
13
 
14
  denoising_timesteps = 4000
 
62
  raise Exception("No model files found in the folder.")
63
 
64
 
65
+ def sample(sketch, scribbles, sampling_steps, seed_nr, progress):
66
  torch.manual_seed(seed_nr)
67
 
68
  noise_scheduler = DDPMScheduler(
 
81
 
82
  noise_for_plain = torch.randn_like(sketch, device=device)
83
 
84
+ for t in progress.tqdm(
85
+ noise_scheduler.timesteps,
86
+ desc="Sampling",
87
  ):
88
  noise_for_plain = noise_scheduler.scale_model_input(noise_for_plain, t).to(
89
  device
 
106
  noise_for_plain,
107
  ).prev_sample
108
 
109
+ tm.sleep(0.01)
110
+
111
  sample = torch.clamp((noise_for_plain / 2) + 0.5, 0, 1)
112
 
113
  return transforms.ToPILImage()(sample[0].cpu())