added progress bar
- app.py +23 -50
- load_model.py +7 -4
app.py
CHANGED
@@ -10,8 +10,6 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
device = "mps" if torch.backends.mps.is_available() else device

image_size = 128
-upscale = False
-clicked = False


transform = transforms.Compose(
@@ -23,49 +21,26 @@ transform = transforms.Compose(
)


-def
-
-
-
-
-
-
-
-
-    sketch = transforms.ToTensor()(sketch).to(device)
-    scribbles = transforms.ToTensor()(scribbles).to(device)
-
-    scribble_where_grey_mask = torch.eq(scribbles, grey_tensor)
-
-    merged = torch.where(scribble_where_grey_mask, sketch, scribbles)
-
-    return transforms.Lambda(lambda t: (t * 2) - 1)(sketch), transforms.Lambda(
-        lambda t: (t * 2) - 1
-    )(merged)
-
-
-def process_images(sketch, scribbles, sampling_steps, is_scribbles, seed_nr, upscale):
-    global clicked
-    clicked = True
+def process_images(
+    sketch,
+    scribbles,
+    sampling_steps,
+    seed_nr,
+    upscale,
+    progress=gr.Progress(),
+):
    w, h = sketch.size

-
-
-
-    else:
-        sketch = transform(sketch.convert("RGB"))
-        scribbles = transform(scribbles.convert("RGB"))
+    sketch = transform(sketch.convert("RGB"))
+    scribbles = transform(scribbles.convert("RGB"))

    if upscale:
-
-            sample(sketch, scribbles, sampling_steps, seed_nr)
+        return transforms.Resize((h, w))(
+            sample(sketch, scribbles, sampling_steps, seed_nr, progress)
        )
-
-        return output
+
    else:
-
-        clicked = False
-        return output
+        return sample(sketch, scribbles, sampling_steps, seed_nr, progress)


theme = gr.themes.Monochrome()
@@ -87,7 +62,7 @@ with gr.Blocks(theme=theme) as demo:
                "By default the scribbles are assumed to be merged with the sketch, if they appear on a grey background check the box below. "
                "</p>"
            )
-
+
        with gr.Column():
            output = gr.Image(type="pil", label="Output")
            upscale_info = gr.Markdown(
@@ -96,14 +71,12 @@ with gr.Blocks(theme=theme) as demo:
                "</p>"
            )
            upscale_button = gr.Checkbox(label="Stretch", value=False)
+
    with gr.Row():
        with gr.Column():
            seed_slider = gr.Number(
-                label="Random Seed π²",
-                value=random.randint(
-                    1,
-                    1000,
-                ),
+                label="Random Seed π² (if the image generated is not to your liking, simply use another seed)",
+                value=random.randint(0, 10000),
            )

        with gr.Column():
@@ -111,12 +84,12 @@ with gr.Blocks(theme=theme) as demo:
                minimum=1,
                maximum=250,
                step=1,
-                label="DDPM Sampling Steps π",
+                label="DDPM Sampling Steps π (the higher the number of steps the higher the quality of the images)",
                value=50,
            )

    with gr.Row():
-        generate_button = gr.Button(value="Generate"
+        generate_button = gr.Button(value="Generate")
    with gr.Row():
        generate_info = gr.Markdown(
            "<p style='text-align: center; font-size: 16px;'>"
@@ -130,13 +103,13 @@ with gr.Blocks(theme=theme) as demo:
            sketch_input,
            scribbles_input,
            sampling_slider,
-            is_scribbles,
            seed_slider,
            upscale_button,
        ],
        outputs=output,
-
+        concurrency_limit=1,
+        trigger_mode="once",
    )

if __name__ == "__main__":
-    demo.launch(max_threads=1)
+    demo.queue().launch(max_threads=1)
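For reference, here is a minimal, self-contained sketch of the gr.Progress pattern this commit adopts (the names slow_task and n_steps and the 0.05 s delay are illustrative, not from this repo): Gradio injects a progress tracker into any event-handler parameter whose default value is gr.Progress(), and progress.tqdm() wraps an iterable so each iteration advances a bar in the UI. Progress events travel through the queue, which is why the launch call switches to demo.queue().

import time

import gradio as gr


def slow_task(n_steps, progress=gr.Progress()):
    # Gradio fills in `progress`; progress.tqdm advances the bar per iteration
    for _ in progress.tqdm(range(int(n_steps)), desc="Working"):
        time.sleep(0.05)  # stand-in for one expensive denoising step
    return f"done after {int(n_steps)} steps"


with gr.Blocks() as demo:
    steps = gr.Number(value=50, label="Steps")
    out = gr.Textbox(label="Result")
    gr.Button("Run").click(slow_task, inputs=steps, outputs=out)

if __name__ == "__main__":
    demo.queue().launch()  # progress updates reach the browser via the queue

The concurrency_limit=1 and trigger_mode="once" arguments added to the click event work with max_threads=1: one sampling job runs at a time, and clicks fired while a job is still pending are not queued up behind it, which appears to be what the removed global clicked flag approximated by hand.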
load_model.py
CHANGED
@@ -8,6 +8,7 @@ from torchvision import transforms
import pathlib
from torchvision.utils import save_image
from safetensors.torch import load_model, save_model
+import time as tm


denoising_timesteps = 4000
@@ -61,7 +62,7 @@ else:
    raise Exception("No model files found in the folder.")


-def sample(sketch, scribbles, sampling_steps, seed_nr):
+def sample(sketch, scribbles, sampling_steps, seed_nr, progress):
    torch.manual_seed(seed_nr)

    noise_scheduler = DDPMScheduler(
@@ -80,9 +81,9 @@ def sample(sketch, scribbles, sampling_steps, seed_nr):

    noise_for_plain = torch.randn_like(sketch, device=device)

-    for
-
-
+    for t in progress.tqdm(
+        noise_scheduler.timesteps,
+        desc="Sampling",
    ):
        noise_for_plain = noise_scheduler.scale_model_input(noise_for_plain, t).to(
            device
@@ -105,6 +106,8 @@ def sample(sketch, scribbles, sampling_steps, seed_nr):
            noise_for_plain,
        ).prev_sample

+        tm.sleep(0.01)
+
    sample = torch.clamp((noise_for_plain / 2) + 0.5, 0, 1)

    return transforms.ToPILImage()(sample[0].cpu())
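The sample() function now takes the injected tracker as a plain argument and wraps the scheduler's timesteps in progress.tqdm, so the bar tracks the denoising loop itself. Below is a hedged, runnable sketch of that loop's shape using the diffusers API; the zeroed noise_pred stands in for the real UNet call conditioned on the sketch and scribbles, and the exact DDPMScheduler arguments are assumptions.

import torch
from diffusers import DDPMScheduler

scheduler = DDPMScheduler(num_train_timesteps=4000)  # cf. denoising_timesteps above
scheduler.set_timesteps(50)  # e.g. the value of the sampling-steps slider

x = torch.randn(1, 3, 128, 128)  # start from pure noise at image_size resolution
for t in scheduler.timesteps:
    x_in = scheduler.scale_model_input(x, t)  # identity for DDPM, kept for generality
    noise_pred = torch.zeros_like(x_in)  # stand-in for the conditioned model call
    x = scheduler.step(noise_pred, t, x).prev_sample  # one reverse-diffusion step

image = torch.clamp(x / 2 + 0.5, 0, 1)  # map [-1, 1] back to [0, 1], as in the diff

The tm.sleep(0.01) inserted after each step looks like a small throttle that gives the queued progress events time to render between iterations; at the default 50 sampling steps it adds roughly half a second to a run.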