jsu27 commited on
Commit
1760b4e
·
1 Parent(s): 32fd213

updated for both decomp and combine

Browse files
Files changed (1) hide show
  1. app.py +109 -56
app.py CHANGED
@@ -98,46 +98,6 @@ gd = SpacedDiffusion(spaced_ts, rescale_timesteps=True, original_num_steps=num_t
98
  GD['ddim'] = gd
99
 
100
 
101
-
102
- # ckpt_path = download_model('clevr') # 'clevr_model.pt'
103
-
104
- # model_kwargs = unet_model_defaults()
105
- # # model parameters
106
- # model_kwargs.update(dict(
107
- # emb_dim=64,
108
- # enc_channels=128
109
- # ))
110
- # clevr_model = create_diffusion_model(**model_kwargs)
111
- # clevr_model.eval()
112
-
113
- # device = 'cuda' if th.cuda.is_available() else 'cpu'
114
- # clevr_model.to(device)
115
-
116
- # print(f'loading from {ckpt_path}')
117
- # checkpoint = th.load(ckpt_path, map_location='cpu')
118
-
119
- # clevr_model.load_state_dict(checkpoint)
120
-
121
-
122
-
123
- # img_input = gr.inputs.Image(type="numpy", label="Input")
124
- # img_output = gr.outputs.Image(type="numpy", label="Output")
125
-
126
- # gr.Interface(
127
- # decompose_image,
128
- # inputs=img_input,
129
- # outputs=img_output,
130
- # examples=[
131
- # "sample_images/clevr_im_10.png",
132
- # "sample_images/clevr_im_25.png",
133
- # ],
134
-
135
- # ).launch()
136
-
137
-
138
-
139
-
140
-
141
  def combine_components_slice(model, gd, im1, im2, indices=None, sample_method='ddim', device='cuda', num_images=4, model_kwargs={}, desc='', save_dir='', dataset='clevr', image_size=64):
142
  """Combine by adding components together
143
  """
@@ -188,13 +148,43 @@ def combine_components_slice(model, gd, im1, im2, indices=None, sample_method='d
188
 
189
  return sample[0].cpu()
190
 
191
- def combine_images(im1, im2):
 
 
 
 
 
 
 
 
192
  sample_method = 'ddim'
193
- result = combine_components_slice(celeb_model, GD[sample_method], im1, im2, indices='1,0,1,0', sample_method=sample_method, num_images=1, device=device)
194
  return result.permute(1, 2, 0).numpy()
195
 
196
 
197
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
198
  ckpt_path = download_model('celebahq') # 'celeb_model.pt'
199
 
200
  model_kwargs = unet_model_defaults()
@@ -213,21 +203,84 @@ checkpoint = th.load(ckpt_path, map_location='cpu')
213
 
214
  celeb_model.load_state_dict(checkpoint)
215
 
216
- # Recombination
217
-
218
 
219
- img_input = gr.inputs.Image(type="numpy", label="Input")
220
- img_input2 = gr.inputs.Image(type="numpy", label="Input")
221
 
222
- img_output = gr.outputs.Image(type="numpy", label="Output")
 
 
 
223
 
224
- gr.Interface(
225
- combine_images,
226
- inputs=[img_input, img_input2],
227
- outputs=img_output,
228
- examples=[
229
- ["sample_images/celebahq_im_15.jpg",
230
- "sample_images/celebahq_im_21.jpg"]
231
- ]
232
- ).launch()
233
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
  GD['ddim'] = gd
99
 
100
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
  def combine_components_slice(model, gd, im1, im2, indices=None, sample_method='ddim', device='cuda', num_images=4, model_kwargs={}, desc='', save_dir='', dataset='clevr', image_size=64):
102
  """Combine by adding components together
103
  """
 
148
 
149
  return sample[0].cpu()
150
 
151
+
152
+
153
+ def decompose_image_demo(im, model):
154
+ sample_method = 'ddim'
155
+ result = gen_image_and_components(MODELS[model], GD[sample_method], im, sample_method=sample_method, num_images=1)
156
+ return result.permute(1, 2, 0).numpy()
157
+
158
+
159
+ def combine_images_demo(im1, im2, model):
160
  sample_method = 'ddim'
161
+ result = combine_components_slice(MODELS[model], GD[sample_method], im1, im2, indices='1,0,1,0', sample_method=sample_method, num_images=1)
162
  return result.permute(1, 2, 0).numpy()
163
 
164
 
165
 
166
+
167
+ ckpt_path = download_model('clevr') # 'clevr_model.pt'
168
+
169
+ model_kwargs = unet_model_defaults()
170
+ # model parameters
171
+ model_kwargs.update(dict(
172
+ emb_dim=64,
173
+ enc_channels=128
174
+ ))
175
+ clevr_model = create_diffusion_model(**model_kwargs)
176
+ clevr_model.eval()
177
+
178
+ device = 'cuda' if th.cuda.is_available() else 'cpu'
179
+ clevr_model.to(device)
180
+
181
+ print(f'loading from {ckpt_path}')
182
+ checkpoint = th.load(ckpt_path, map_location='cpu')
183
+
184
+ clevr_model.load_state_dict(checkpoint)
185
+
186
+
187
+
188
  ckpt_path = download_model('celebahq') # 'celeb_model.pt'
189
 
190
  model_kwargs = unet_model_defaults()
 
203
 
204
  celeb_model.load_state_dict(checkpoint)
205
 
 
 
206
 
 
 
207
 
208
+ MODELS = {
209
+ 'CLEVR': clevr_model,
210
+ 'CelebA-HQ': celeb_model
211
+ }
212
 
 
 
 
 
 
 
 
 
 
213
 
214
+ with gr.Blocks() as demo:
215
+ gr.Markdown(
216
+ """<h1 style="text-align: center;"><b>Unsupervised Compositional Image Decomposition with Diffusion Models
217
+ </b> - <a href="https://jsu27.github.io/decomp-diffusion-web/">Project Page</a></h1>""")
218
+
219
+ gr.Markdown(
220
+ """<p style="font-size: 18px;">We introduce Decomp Diffusion, an unsupervised approach that discovers compositional concepts from images, represented by diffusion models.
221
+ </p>""")
222
+
223
+ gr.Markdown(
224
+ """<h4>Decomposition and reconstruction of images</h4>""")
225
+ with gr.Row().style(equal_height=True):
226
+ with gr.Column():
227
+ with gr.Row():
228
+ decomp_input = gr.Image(type='numpy', label='Input')
229
+ with gr.Row():
230
+ decomp_model = gr.Radio(
231
+ ['CLEVR', 'CelebA-HQ'], type="value", label='Model',
232
+ value='CLEVR')
233
+
234
+ with gr.Column():
235
+ decomp_output = gr.Image(type='numpy')
236
+ decomp_button = gr.Button("Generate")
237
+ with gr.Row():
238
+
239
+ # image_examples = [os.path.join(os.path.dirname(__file__), 'sample_images/clevr_im_10.png'), 'CLEVR']
240
+ decomp_examples = [['sample_images/clevr_im_10.png', 'CLEVR'],
241
+ ['sample_images/celebahq_im_15.jpg', 'CelebA-HQ']]
242
+ decomp_img_examples = gr.Examples(
243
+ examples=decomp_examples,
244
+ inputs=[decomp_input, decomp_model]
245
+ )
246
+
247
+
248
+ gr.Markdown(
249
+ """<h4>Combination of images</h4>""")
250
+ with gr.Row().style(equal_height=True):
251
+ with gr.Column(scale=2):
252
+
253
+ with gr.Row():
254
+ with gr.Column():
255
+ comb_input1 = gr.Image(type='numpy', label='Input 1')
256
+ with gr.Column():
257
+ comb_input2 = gr.Image(type='numpy', label='Input 2')
258
+
259
+ with gr.Row():
260
+ comb_model = gr.Radio(
261
+ ['CLEVR', 'CelebA-HQ'], type="value", label='Model',
262
+ value='CLEVR')
263
+
264
+
265
+
266
+ with gr.Column(scale=1):
267
+ comb_output = gr.Image(type='numpy')
268
+ comb_button = gr.Button("Generate")
269
+ with gr.Row():
270
+
271
+ comb_examples = [['sample_images/clevr_im_10.png', 'sample_images/clevr_im_25.png', 'CLEVR'],
272
+ ['sample_images/celebahq_im_15.jpg', 'sample_images/celebahq_im_21.jpg', 'CelebA-HQ']]
273
+ comb_img_examples = gr.Examples(
274
+ examples=comb_examples,
275
+ inputs=[comb_input1, comb_input2, comb_model]
276
+ )
277
+
278
+ decomp_button.click(decompose_image_demo,
279
+ inputs=[decomp_input, decomp_model],
280
+ outputs=decomp_output)
281
+ comb_button.click(combine_images_demo,
282
+ inputs=[comb_input1, comb_input2, comb_model],
283
+ outputs=comb_output)
284
+
285
+
286
+ demo.launch()