RishabA committed
Commit 036212e · verified · 1 Parent(s): e90848d

Update app.py

Files changed (1)
  1. app.py +151 -10
app.py CHANGED
@@ -1,9 +1,18 @@
 import torch
+import torchvision.transforms as transforms
+from torchvision.utils import make_grid
 import gradio as gr
 from model import (
     UNet,
     VQVAE,
-    sample_ddpm_inference,
+    LinearNoiseScheduler,
+    get_tokenizer_and_model,
+    get_text_representation,
+    dataset_params,
+    diffusion_params,
+    ldm_params,
+    autoencoder_params,
+    train_params,
 )
 from huggingface_hub import hf_hub_download
 import json
@@ -46,16 +55,148 @@ vae.eval()
 print("Model and checkpoints loaded successfully!")


-def generate_image(text_prompt):
+def sample_ddpm_inference(text_prompt):
     """
-    text_prompt: Text prompt provided by the user.
-    mask_upload: Either a PIL image (uploaded) or None.
-
-    This function returns a generator that yields an intermediate
-    decoded image at every timestep from the diffusion process.
+    Given a text prompt and (optionally) an image condition (as a PIL image),
+    sample from the diffusion model and return a generated image (PIL image).
     """
-    # sample_ddpm_inference is assumed to be a generator function (using yield)
-    yield sample_ddpm_inference(unet, vae, text_prompt, None, device)
+    mask_image_pil = None
+    guidance_scale = 1.0
+
+    # Create noise scheduler
+    scheduler = LinearNoiseScheduler(
+        num_timesteps=diffusion_params["num_timesteps"],
+        beta_start=diffusion_params["beta_start"],
+        beta_end=diffusion_params["beta_end"],
+    )
+    # Get conditioning config from ldm_params
+    condition_config = ldm_params.get("condition_config", None)
+    condition_types = (
+        condition_config.get("condition_types", [])
+        if condition_config is not None
+        else []
+    )
+
+    # Load text tokenizer/model for conditioning
+    text_model_type = condition_config["text_condition_config"]["text_embed_model"]
+    text_tokenizer, text_model = get_tokenizer_and_model(text_model_type, device=device)
+
+    # Get empty text representation for classifier-free guidance
+    empty_text_embed = get_text_representation([""], text_tokenizer, text_model, device)
+
+    # Get text representation of the input prompt
+    text_prompt_embed = get_text_representation(
+        [text_prompt], text_tokenizer, text_model, device
+    )
+
+    # Prepare image conditioning:
+    # If the user uploaded a mask image (should be a PIL image), convert it; otherwise, use zeros.
+    if "image" in condition_types:
+        if mask_image_pil is not None:
+            mask_transform = transforms.Compose(
+                [
+                    transforms.Resize(
+                        (
+                            ldm_params["condition_config"]["image_condition_config"][
+                                "image_condition_h"
+                            ],
+                            ldm_params["condition_config"]["image_condition_config"][
+                                "image_condition_w"
+                            ],
+                        )
+                    ),
+                    transforms.ToTensor(),
+                ]
+            )
+            mask_tensor = (
+                mask_transform(mask_image_pil).unsqueeze(0).to(device)
+            )  # (1, channels, H, W)
+        else:
+            # Create a zero mask with the required number of channels (e.g. 18)
+            ic = ldm_params["condition_config"]["image_condition_config"][
+                "image_condition_input_channels"
+            ]
+            H = ldm_params["condition_config"]["image_condition_config"][
+                "image_condition_h"
+            ]
+            W = ldm_params["condition_config"]["image_condition_config"][
+                "image_condition_w"
+            ]
+            mask_tensor = torch.zeros((1, ic, H, W), device=device)
+    else:
+        mask_tensor = None
+
+    # Build conditioning dictionaries for classifier-free guidance:
+    # For unconditional, we use empty text and zero mask.
+    uncond_input = {}
+    cond_input = {}
+    if "text" in condition_types:
+        uncond_input["text"] = empty_text_embed
+        cond_input["text"] = text_prompt_embed
+    if "image" in condition_types:
+        # Use zeros for unconditioning, and the provided mask for conditioning.
+        uncond_input["image"] = torch.zeros_like(mask_tensor)
+        cond_input["image"] = mask_tensor
+
+    # Load the diffusion UNet (and assume it has been pretrained and saved)
+    # unet = UNet(
+    #     image_channels=autoencoder_params["z_channels"], model_config=ldm_params
+    # ).to(device)
+    # ldm_checkpoint_path = os.path.join(
+    #     train_params["task_name"], train_params["ldm_ckpt_name"]
+    # )
+    # if os.path.exists(ldm_checkpoint_path):
+    #     checkpoint = torch.load(ldm_checkpoint_path, map_location=device)
+    #     unet.load_state_dict(checkpoint["model_state_dict"])
+    # unet.eval()
+
+    # Load VQVAE (assume pretrained and saved)
+    # vae = VQVAE(
+    #     image_channels=dataset_params["image_channels"], model_config=autoencoder_params
+    # ).to(device)
+    # vae_checkpoint_path = os.path.join(
+    #     train_params["task_name"], train_params["vqvae_autoencoder_ckpt_name"]
+    # )
+    # if os.path.exists(vae_checkpoint_path):
+    #     checkpoint = torch.load(vae_checkpoint_path, map_location=device)
+    #     vae.load_state_dict(checkpoint["model_state_dict"])
+    # vae.eval()
+
+    # Determine latent shape from VQVAE: (batch, z_channels, H_lat, W_lat)
+    # For example, if image_size is 256 and there are 3 downsamplings, H_lat = 256 // 8 = 32.
+    latent_size = dataset_params["image_size"] // (
+        2 ** sum(autoencoder_params["down_sample"])
+    )
+    batch = train_params["num_samples"]
+    z_channels = autoencoder_params["z_channels"]
+
+    # Sample initial latent noise
+    xt = torch.randn((batch, z_channels, latent_size, latent_size), device=device)
+
+    # Sampling loop (reverse diffusion)
+    T = diffusion_params["num_timesteps"]
+    for i in reversed(range(T)):
+        t = torch.full((batch,), i, dtype=torch.long, device=device)
+        # Get conditional noise prediction
+        noise_pred_cond = unet(xt, t, cond_input)
+        if guidance_scale > 1:
+            noise_pred_uncond = unet(xt, t, uncond_input)
+            noise_pred = noise_pred_uncond + guidance_scale * (
+                noise_pred_cond - noise_pred_uncond
+            )
+        else:
+            noise_pred = noise_pred_cond
+        xt, _ = scheduler.sample_prev_timestep(xt, noise_pred, t)
+
+    with torch.no_grad():
+        generated = vae.decode(xt)
+
+    generated = torch.clamp(generated, -1, 1)
+    generated = (generated + 1) / 2  # scale to [0,1]
+    grid = make_grid(generated, nrow=1)
+    pil_img = transforms.ToPILImage()(grid.cpu())
+
+    yield pil_img


 css_str = """
@@ -94,7 +235,7 @@ css_str = """
 # )

 demo = gr.Interface(
-    generate_image,
+    sample_ddpm_inference,
     inputs=gr.Textbox(
         label="Text Prompt",
         lines=2,
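
The sampling loop in the new sample_ddpm_inference delegates the actual reverse-diffusion update to LinearNoiseScheduler.sample_prev_timestep, which lives in model.py and is not part of this diff. As a rough orientation only, here is a minimal sketch of what such a step typically computes for a linear-beta DDPM; the class name, buffer layout, and return convention below are assumptions, not the repository's actual code.

import torch

class AssumedLinearNoiseScheduler:
    # Hypothetical stand-in for model.LinearNoiseScheduler (not shown in this commit).
    def __init__(self, num_timesteps, beta_start, beta_end):
        self.betas = torch.linspace(beta_start, beta_end, num_timesteps)
        self.alphas = 1.0 - self.betas
        self.alpha_cum_prod = torch.cumprod(self.alphas, dim=0)

    def sample_prev_timestep(self, xt, noise_pred, t):
        # t is a (batch,) tensor holding one repeated timestep index.
        i = int(t[0])
        beta = float(self.betas[i])
        alpha = float(self.alphas[i])
        alpha_bar = float(self.alpha_cum_prod[i])
        # Predicted clean latent x0, returned alongside x_{t-1} (matching `xt, _ = ...` in app.py).
        x0 = (xt - (1 - alpha_bar) ** 0.5 * noise_pred) / alpha_bar**0.5
        # Posterior mean of q(x_{t-1} | x_t, x0) under the DDPM parameterization.
        mean = (xt - beta * noise_pred / (1 - alpha_bar) ** 0.5) / alpha**0.5
        if i == 0:
            return mean, x0
        alpha_bar_prev = float(self.alpha_cum_prod[i - 1])
        variance = beta * (1 - alpha_bar_prev) / (1 - alpha_bar)
        return mean + variance**0.5 * torch.randn_like(xt), x0

Whatever the real scheduler does internally, the committed loop only needs it to take the current latent, a noise prediction, and a timestep, and to return the previous latent; classifier-free guidance is applied beforehand by mixing the conditional and unconditional noise predictions when guidance_scale > 1.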
 
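Because the new sample_ddpm_inference ends in yield, it is a generator, and gr.Interface streams each yielded image to the output component; it can also be exercised directly without Gradio. A short usage sketch follows: the prompt string and output filename are arbitrary examples, and demo.launch() is assumed to appear further down in app.py, outside this diff.

# Pull the single yielded image out of the generator without going through Gradio.
image = next(sample_ddpm_inference("a photo of a smiling person with blond hair"))
image.save("generated_sample.png")

# Inside the app itself, the Interface built above serves the same generator:
# demo.launch()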