Spaces: Running on Zero
Update app.py
app.py CHANGED
@@ -1,18 +1,22 @@
 import torch
 import gradio as gr
-from model import
+from model import (
+    UNet,
+    VQVAE,
+    sample_ddpm_inference,
+)
 from huggingface_hub import hf_hub_download
 import json
 
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
+# Download config and checkpoint files from HF Hub
 config_path = hf_hub_download(
     repo_id="RishabA/celeba-cond-ddpm", filename="config.json"
 )
 with open(config_path, "r") as f:
     config = json.load(f)
 
-# Download checkpoint files. Adjust file paths if needed.
 ldm_ckpt_path = hf_hub_download(
     repo_id="RishabA/celeba-cond-ddpm", filename="celebhq/ddpm_ckpt_class_cond.pth"
 )
@@ -20,13 +24,12 @@ vae_ckpt_path = hf_hub_download(
     repo_id="RishabA/celeba-cond-ddpm", filename="celebhq/vqvae_autoencoder_ckpt.pth"
 )
 
-
+# Instantiate and load the models
 unet = UNet(config["autoencoder_params"]["z_channels"], config["ldm_params"]).to(device)
 vae = VQVAE(
     config["dataset_params"]["image_channels"], config["autoencoder_params"]
 ).to(device)
 
-# Load the pretrained weights
 unet_state = torch.load(ldm_ckpt_path, map_location=device)
 unet.load_state_dict(unet_state["model_state_dict"])
 
@@ -37,23 +40,21 @@ unet.eval()
 vae.eval()
 
 print("Model and checkpoints loaded successfully!")
-print(unet)
-print(vae)
 
 
 def generate_image(text_prompt, mask_upload):
     """
-    text_prompt:
+    text_prompt: Text prompt provided by the user.
     mask_upload: Either a PIL image (uploaded) or None.
-
+
+    This function returns a generator that yields an intermediate
+    decoded image at every timestep from the diffusion process.
     """
+    # sample_ddpm_inference is assumed to be a generator function (using yield)
     return sample_ddpm_inference(unet, vae, text_prompt, mask_upload, device)
 
 
 css_str = """
-body {
-    background-color: #f7f7f7;
-}
 .title {
     font-size: 48px;
     text-align: center;
@@ -69,7 +70,7 @@ body {
 with gr.Blocks(css=css_str) as demo:
     gr.Markdown("<div class='title'>Conditioned Latent Diffusion with CelebA</div>")
     gr.Markdown(
-        "<div class='description'>Enter a text prompt and (optionally) upload a mask image for conditioning; the
+        "<div class='description'>Enter a text prompt and (optionally) upload a mask image for conditioning; the generated image will update as the reverse diffusion progresses.</div>"
     )
     with gr.Row():
         text_input = gr.Textbox(
@@ -81,10 +82,12 @@ with gr.Blocks(css=css_str) as demo:
     generate_button = gr.Button("Generate Image")
     output_image = gr.Image(label="Generated Image", type="pil")
 
+    # Adding stream=True allows Gradio to process the generator output
    generate_button.click(
         fn=generate_image,
         inputs=[text_input, mask_input],
         outputs=output_image,
+        stream=True,
     )
 
 if __name__ == "__main__":
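
The core idea of this commit is that generate_image returns a generator, so the output image can refresh while the reverse diffusion runs. The snippet below is a minimal, self-contained sketch of that generator-based streaming pattern in Gradio; fake_sampler is a hypothetical stand-in for sample_ddpm_inference (assumed here to yield one PIL image per denoising step), and it omits the stream=True argument from the commit, since in recent Gradio versions yielding from the event handler is typically what drives the progressive updates.

import time

import gradio as gr
from PIL import Image


def fake_sampler(prompt, mask):
    # Hypothetical stand-in for sample_ddpm_inference: yield a few
    # solid-color frames so the streaming behavior can be tested locally.
    for shade in (64, 128, 192, 255):
        time.sleep(0.5)  # simulate one reverse-diffusion step
        yield Image.new("RGB", (256, 256), (shade, shade, shade))


with gr.Blocks() as demo:
    prompt = gr.Textbox(label="Prompt")
    mask = gr.Image(label="Mask (optional)", type="pil")
    button = gr.Button("Generate")
    output = gr.Image(label="Generated Image", type="pil")

    # Because fake_sampler is a generator, Gradio updates the output
    # component each time it yields, producing a progressive preview.
    button.click(fn=fake_sampler, inputs=[prompt, mask], outputs=output)

if __name__ == "__main__":
    demo.launch()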