RishabA committed
Commit c39621e (verified) · Parent: 036212e

Update app.py

Files changed (1): app.py (+24 −58)
app.py CHANGED
@@ -55,12 +55,11 @@ vae.eval()
 print("Model and checkpoints loaded successfully!")
 
 
-def sample_ddpm_inference(text_prompt):
+def sample_ddpm_inference(text_prompt, mask_image_pil):
     """
     Given a text prompt and (optionally) an image condition (as a PIL image),
     sample from the diffusion model and return a generated image (PIL image).
     """
-    mask_image_pil = None
    guidance_scale = 1.0
 
     # Create noise scheduler
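For orientation: a minimal sketch of how a guidance_scale is conventionally combined with a conditional/unconditional input pair during classifier-free guidance sampling. The unet call signature and variable names below are assumptions for illustration, not code from this commit; with guidance_scale = 1.0, as set above, the blend reduces to the conditional prediction.

# Hypothetical classifier-free guidance step; names and signature are assumed.
noise_uncond = unet(xt, t, cond_input=uncond_input)  # prediction without conditioning
noise_cond = unet(xt, t, cond_input=cond_input)      # prediction with text/mask conditioning
noise_pred = noise_uncond + guidance_scale * (noise_cond - noise_uncond)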
@@ -138,30 +137,6 @@ def sample_ddpm_inference(text_prompt):
     uncond_input["image"] = torch.zeros_like(mask_tensor)
     cond_input["image"] = mask_tensor
 
-    # Load the diffusion UNet (and assume it has been pretrained and saved)
-    # unet = UNet(
-    #     image_channels=autoencoder_params["z_channels"], model_config=ldm_params
-    # ).to(device)
-    # ldm_checkpoint_path = os.path.join(
-    #     train_params["task_name"], train_params["ldm_ckpt_name"]
-    # )
-    # if os.path.exists(ldm_checkpoint_path):
-    #     checkpoint = torch.load(ldm_checkpoint_path, map_location=device)
-    #     unet.load_state_dict(checkpoint["model_state_dict"])
-    # unet.eval()
-
-    # Load VQVAE (assume pretrained and saved)
-    # vae = VQVAE(
-    #     image_channels=dataset_params["image_channels"], model_config=autoencoder_params
-    # ).to(device)
-    # vae_checkpoint_path = os.path.join(
-    #     train_params["task_name"], train_params["vqvae_autoencoder_ckpt_name"]
-    # )
-    # if os.path.exists(vae_checkpoint_path):
-    #     checkpoint = torch.load(vae_checkpoint_path, map_location=device)
-    #     vae.load_state_dict(checkpoint["model_state_dict"])
-    # vae.eval()
-
     # Determine latent shape from VQVAE: (batch, z_channels, H_lat, W_lat)
     # For example, if image_size is 256 and there are 3 downsamplings, H_lat = 256 // 8 = 32.
     latent_size = dataset_params["image_size"] // (
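The worked example in the comment above can be sanity-checked in isolation. This is a self-contained sketch of the arithmetic only; the number of downsampling stages and the continuation of the truncated expression are assumptions.

# Hypothetical check of the latent-size arithmetic; not code from this commit.
image_size = 256
num_downsamplings = 3  # assumed count of 2x downsampling stages in the VQVAE
latent_size = image_size // (2 ** num_downsamplings)
assert latent_size == 32  # H_lat = W_lat = 32, matching the comment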
@@ -212,37 +187,28 @@ css_str = """
 }
 """
 
-# with gr.Blocks(css=css_str) as demo:
-#     gr.Markdown("<div class='title'>Conditioned Latent Diffusion with CelebA</div>")
-#     gr.Markdown(
-#         "<div class='description'>Enter a text prompt and (optionally) upload a mask image for conditioning; the generated image will update as the reverse diffusion progresses.</div>"
-#     )
-#     with gr.Row():
-#         text_input = gr.Textbox(
-#             label="Text Prompt",
-#             lines=2,
-#             placeholder="E.g., 'He is a man with brown hair.'",
-#         )
-#         mask_input = gr.Image(type="pil", label="Optional Mask for Conditioning")
-#     generate_button = gr.Button("Generate Image")
-#     output_image = gr.Image(label="Generated Image", type="pil")
-
-#     # Adding stream=True allows Gradio to process the generator output
-#     generate_button.click(
-#         fn=generate_image,
-#         inputs=[text_input, mask_input],
-#         outputs=[output_image],
-#     )
-
-demo = gr.Interface(
-    sample_ddpm_inference,
-    inputs=gr.Textbox(
-        label="Text Prompt",
-        lines=2,
-        placeholder="E.g., 'He is a man with brown hair.'",
-    ),
-    outputs="image",
-)
+with gr.Blocks(css=css_str) as demo:
+    gr.Markdown("<div class='title'>Conditioned Latent Diffusion with CelebA</div>")
+    gr.Markdown(
+        "<div class='description'>Enter a text prompt and (optionally) upload a mask image for conditioning; the generated image will update as the reverse diffusion progresses.</div>"
+    )
+
+    with gr.Row():
+        text_input = gr.Textbox(
+            label="Text Prompt",
+            lines=2,
+            placeholder="E.g., 'He is a man with brown hair.'",
+        )
+        mask_input = gr.Image(type="pil", label="Optional Mask for Conditioning")
+
+    generate_button = gr.Button("Generate Image")
+    output_image = gr.Image(label="Generated Image", type="pil")
+
+    generate_button.click(
+        fn=sample_ddpm_inference,
+        inputs=[text_input, mask_input],
+        outputs=[output_image],
+    )
 
 if __name__ == "__main__":
-    demo.launch(share=True)
+    demo.launch(share=True)
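Since the click handler now receives both components, the mask slot arrives as None when nothing is uploaded, matching the new two-argument signature of sample_ddpm_inference. A hedged smoke test, mirroring what the Blocks UI passes; the prompt is taken from the placeholder text and the save path is illustrative:

# Hypothetical direct call; assumes the function returns a PIL image per its docstring.
image = sample_ddpm_inference("He is a man with brown hair.", None)  # no mask uploaded
image.save("sample.png")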