RishabA committed on
Commit c74d396 · verified
1 Parent(s): ddc10aa

Update app.py

Files changed (1)
app.py +15 -12
app.py CHANGED

@@ -1,18 +1,22 @@
 import torch
 import gradio as gr
-from model import UNet, VQVAE, sample_ddpm_inference
+from model import (
+    UNet,
+    VQVAE,
+    sample_ddpm_inference,
+)
 from huggingface_hub import hf_hub_download
 import json
 
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
+# Download config and checkpoint files from HF Hub
 config_path = hf_hub_download(
     repo_id="RishabA/celeba-cond-ddpm", filename="config.json"
 )
 with open(config_path, "r") as f:
     config = json.load(f)
 
-# Download checkpoint files. Adjust file paths if needed.
 ldm_ckpt_path = hf_hub_download(
     repo_id="RishabA/celeba-cond-ddpm", filename="celebhq/ddpm_ckpt_class_cond.pth"
 )
@@ -20,13 +24,12 @@ vae_ckpt_path = hf_hub_download(
     repo_id="RishabA/celeba-cond-ddpm", filename="celebhq/vqvae_autoencoder_ckpt.pth"
 )
 
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+# Instantiate and load the models
 unet = UNet(config["autoencoder_params"]["z_channels"], config["ldm_params"]).to(device)
 vae = VQVAE(
     config["dataset_params"]["image_channels"], config["autoencoder_params"]
 ).to(device)
 
-# Load the pretrained weights
 unet_state = torch.load(ldm_ckpt_path, map_location=device)
 unet.load_state_dict(unet_state["model_state_dict"])
 
@@ -37,23 +40,21 @@ unet.eval()
 vae.eval()
 
 print("Model and checkpoints loaded successfully!")
-print(unet)
-print(vae)
 
 
 def generate_image(text_prompt, mask_upload):
     """
-    text_prompt: A text prompt provided by the user.
+    text_prompt: Text prompt provided by the user.
     mask_upload: Either a PIL image (uploaded) or None.
-    guidance_scale: Float slider setting for classifier-free guidance.
+
+    This function returns a generator that yields an intermediate
+    decoded image at every timestep from the diffusion process.
     """
+    # sample_ddpm_inference is assumed to be a generator function (using yield)
     return sample_ddpm_inference(unet, vae, text_prompt, mask_upload, device)
 
 
 css_str = """
-body {
-    background-color: #f7f7f7;
-}
 .title {
     font-size: 48px;
     text-align: center;
@@ -69,7 +70,7 @@ body {
 with gr.Blocks(css=css_str) as demo:
     gr.Markdown("<div class='title'>Conditioned Latent Diffusion with CelebA</div>")
     gr.Markdown(
-        "<div class='description'>Enter a text prompt and (optionally) upload a mask image for conditioning; the model will generate an image accordingly.</div>"
+        "<div class='description'>Enter a text prompt and (optionally) upload a mask image for conditioning; the generated image will update as the reverse diffusion progresses.</div>"
     )
     with gr.Row():
         text_input = gr.Textbox(
@@ -81,10 +82,12 @@ with gr.Blocks(css=css_str) as demo:
     generate_button = gr.Button("Generate Image")
     output_image = gr.Image(label="Generated Image", type="pil")
 
+    # Adding stream=True allows Gradio to process the generator output
    generate_button.click(
         fn=generate_image,
         inputs=[text_input, mask_input],
         outputs=output_image,
+        stream=True,
     )
 
 if __name__ == "__main__":
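
The substantive change in this commit is the new contract documented in generate_image: sample_ddpm_inference is now treated as a generator that yields a decoded image at every denoising step, so the UI can show the reverse diffusion in progress. A minimal sketch of that contract, where the latent shape, the UNet/VQVAE call signatures, and the update rule are all illustrative assumptions (the real implementation lives in model.py):

import torch
from PIL import Image


@torch.no_grad()
def sample_sketch(unet, vae, text_prompt, mask_upload, device, steps=50):
    # Conditioning on text_prompt/mask_upload is omitted for brevity.
    z = torch.randn(1, 4, 32, 32, device=device)  # hypothetical latent shape
    for t in reversed(range(steps)):
        t_batch = torch.full((1,), t, device=device, dtype=torch.long)
        noise_pred = unet(z, t_batch)   # assumed UNet call signature
        z = z - noise_pred / steps      # stand-in for the true DDPM update rule
        img = vae.decode(z)             # assumed decoder entry point
        img = ((img.clamp(-1, 1) + 1) / 2 * 255).to(torch.uint8)
        yield Image.fromarray(img[0].permute(1, 2, 0).cpu().numpy())

Because the function yields rather than returns, generate_image hands Gradio a generator object, which is what makes the incremental display possible.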
 
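
On the UI side, what drives the incremental display in recent Gradio releases is the combination of a generator event handler and queuing: each yielded value is pushed to the output component as it arrives. A self-contained toy of the pattern (fake_diffusion is a stand-in generator, and the exact set of keyword arguments click() accepts varies across Gradio versions):

import time

import gradio as gr
import numpy as np
from PIL import Image


def fake_diffusion(prompt):
    # Yield ten progressively "cleaner" random frames to mimic denoising.
    for step in range(10):
        high = 255 // (step + 1) + 1
        frame = np.random.randint(0, high, (64, 64, 3), dtype=np.uint8)
        yield Image.fromarray(frame)
        time.sleep(0.1)


with gr.Blocks() as demo:
    prompt = gr.Textbox(label="Prompt")
    out = gr.Image(label="Generated Image", type="pil")
    btn = gr.Button("Generate")
    # In this toy, a generator handler plus queuing is enough to stream.
    btn.click(fn=fake_diffusion, inputs=prompt, outputs=out)

if __name__ == "__main__":
    demo.queue().launch()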