Spaces: Running on Zero
Update app.py
app.py CHANGED
@@ -55,12 +55,11 @@ vae.eval()
 print("Model and checkpoints loaded successfully!")
 
 
-def sample_ddpm_inference(text_prompt):
+def sample_ddpm_inference(text_prompt, mask_image_pil):
     """
     Given a text prompt and (optionally) an image condition (as a PIL image),
     sample from the diffusion model and return a generated image (PIL image).
     """
-    mask_image_pil = None
     guidance_scale = 1.0
 
     # Create noise scheduler
@@ -138,30 +137,6 @@ def sample_ddpm_inference(text_prompt):
     uncond_input["image"] = torch.zeros_like(mask_tensor)
     cond_input["image"] = mask_tensor
 
-    # Load the diffusion UNet (and assume it has been pretrained and saved)
-    # unet = UNet(
-    #     image_channels=autoencoder_params["z_channels"], model_config=ldm_params
-    # ).to(device)
-    # ldm_checkpoint_path = os.path.join(
-    #     train_params["task_name"], train_params["ldm_ckpt_name"]
-    # )
-    # if os.path.exists(ldm_checkpoint_path):
-    #     checkpoint = torch.load(ldm_checkpoint_path, map_location=device)
-    #     unet.load_state_dict(checkpoint["model_state_dict"])
-    #     unet.eval()
-
-    # Load VQVAE (assume pretrained and saved)
-    # vae = VQVAE(
-    #     image_channels=dataset_params["image_channels"], model_config=autoencoder_params
-    # ).to(device)
-    # vae_checkpoint_path = os.path.join(
-    #     train_params["task_name"], train_params["vqvae_autoencoder_ckpt_name"]
-    # )
-    # if os.path.exists(vae_checkpoint_path):
-    #     checkpoint = torch.load(vae_checkpoint_path, map_location=device)
-    #     vae.load_state_dict(checkpoint["model_state_dict"])
-    #     vae.eval()
-
     # Determine latent shape from VQVAE: (batch, z_channels, H_lat, W_lat)
     # For example, if image_size is 256 and there are 3 downsamplings, H_lat = 256 // 8 = 32.
     latent_size = dataset_params["image_size"] // (
@@ -212,37 +187,28 @@ css_str = """
 }
 """
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-demo = gr.Interface(
-    sample_ddpm_inference,
-    inputs=gr.Textbox(
-        label="Text Prompt",
-        lines=2,
-        placeholder="E.g., 'He is a man with brown hair.'",
-    ),
-    outputs="image",
-)
+with gr.Blocks(css=css_str) as demo:
+    gr.Markdown("<div class='title'>Conditioned Latent Diffusion with CelebA</div>")
+    gr.Markdown(
+        "<div class='description'>Enter a text prompt and (optionally) upload a mask image for conditioning; the generated image will update as the reverse diffusion progresses.</div>"
+    )
+
+    with gr.Row():
+        text_input = gr.Textbox(
+            label="Text Prompt",
+            lines=2,
+            placeholder="E.g., 'He is a man with brown hair.'",
+        )
+        mask_input = gr.Image(type="pil", label="Optional Mask for Conditioning")
+
+    generate_button = gr.Button("Generate Image")
+    output_image = gr.Image(label="Generated Image", type="pil")
+
+    generate_button.click(
+        fn=sample_ddpm_inference,
+        inputs=[text_input, mask_input],
+        outputs=[output_image],
+    )
 
 if __name__ == "__main__":
-    demo.launch(share=True)
+    demo.launch(share=True)
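The updated handler now receives the uploaded mask directly as its second argument, but the hunks shown here do not include the step that turns `mask_image_pil` into the `mask_tensor` assigned to `cond_input["image"]`. Below is a minimal sketch of one way that conversion could look; `prepare_mask_tensor`, the RGB channel count, and the use of `dataset_params["image_size"]` and `device` are illustrative assumptions, not code from this commit.

```python
# Hypothetical helper (not part of this commit): convert the optional PIL mask
# into a (1, C, H, W) tensor; fall back to zeros when no mask is uploaded,
# mirroring the unconditional branch uncond_input["image"] = torch.zeros_like(...).
import torch
import torchvision.transforms as T

def prepare_mask_tensor(mask_image_pil, image_size, device):
    if mask_image_pil is None:
        # No mask uploaded: use an all-zero conditioning image.
        return torch.zeros(1, 3, image_size, image_size, device=device)
    transform = T.Compose([
        T.Resize((image_size, image_size)),
        T.ToTensor(),  # PIL image -> float tensor in [0, 1], shape (C, H, W)
    ])
    mask = transform(mask_image_pil.convert("RGB"))
    return mask.unsqueeze(0).to(device)  # add the batch dimension
```

Guarding against `None` matters because the Gradio mask input is optional, so the handler has to tolerate a missing upload.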
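The new description says the generated image "will update as the reverse diffusion progresses." In Gradio, a click handler can stream to its output component by being a generator that yields intermediate images. The self-contained toy below illustrates that wiring pattern with the same component layout as the commit; the fake sampler is a stand-in assumption, not the Space's actual diffusion loop.

```python
# Illustration only (not from this commit): Gradio refreshes the gr.Image output
# each time a generator handler yields, which is how progressive previews work.
import time
import numpy as np
import gradio as gr
from PIL import Image

def fake_progressive_sampler(text_prompt, mask_image_pil):
    # Stand-in for the reverse-diffusion loop: each yield is a progressively
    # cleaner frame, blending random noise toward a flat gray "result".
    rng = np.random.default_rng(0)
    target = np.full((256, 256, 3), 180, dtype=np.float32)
    for step in range(5):
        alpha = step / 4  # 0 = pure noise, 1 = fully "denoised"
        frame = (1 - alpha) * rng.random((256, 256, 3)) * 255 + alpha * target
        yield Image.fromarray(frame.astype(np.uint8))
        time.sleep(0.3)

with gr.Blocks() as toy_demo:
    prompt = gr.Textbox(label="Text Prompt", lines=2)
    mask = gr.Image(type="pil", label="Optional Mask for Conditioning")
    out = gr.Image(label="Generated Image", type="pil")
    gr.Button("Generate Image").click(fake_progressive_sampler, [prompt, mask], [out])

if __name__ == "__main__":
    toy_demo.launch()
```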