Spaces:
Runtime error
Runtime error
reformat
Browse files
app.py
CHANGED
@@ -152,58 +152,37 @@ def combine_components_slice(model, gd, im1, im2, indices=None, sample_method='d
|
|
152 |
|
153 |
def decompose_image_demo(im, model):
|
154 |
sample_method = 'ddim'
|
155 |
-
result = gen_image_and_components(MODELS[model], GD[sample_method], im, sample_method=sample_method, num_images=1)
|
156 |
return result.permute(1, 2, 0).numpy()
|
157 |
|
158 |
|
159 |
def combine_images_demo(im1, im2, model):
|
160 |
sample_method = 'ddim'
|
161 |
-
result = combine_components_slice(MODELS[model], GD[sample_method], im1, im2, indices='1,0,1,0', sample_method=sample_method, num_images=1)
|
162 |
return result.permute(1, 2, 0).numpy()
|
163 |
|
164 |
|
|
|
|
|
165 |
|
|
|
|
|
|
|
|
|
|
|
|
|
166 |
|
167 |
-
|
|
|
168 |
|
169 |
-
|
170 |
-
|
171 |
-
model_kwargs.update(dict(
|
172 |
-
emb_dim=64,
|
173 |
-
enc_channels=128
|
174 |
-
))
|
175 |
-
clevr_model = create_diffusion_model(**model_kwargs)
|
176 |
-
clevr_model.eval()
|
177 |
|
178 |
-
device = 'cuda' if th.cuda.is_available() else 'cpu'
|
179 |
-
clevr_model.to(device)
|
180 |
-
|
181 |
-
print(f'loading from {ckpt_path}')
|
182 |
-
checkpoint = th.load(ckpt_path, map_location='cpu')
|
183 |
-
|
184 |
-
clevr_model.load_state_dict(checkpoint)
|
185 |
-
|
186 |
-
|
187 |
-
|
188 |
-
ckpt_path = download_model('celebahq') # 'celeb_model.pt'
|
189 |
-
|
190 |
-
model_kwargs = unet_model_defaults()
|
191 |
-
# model parameters
|
192 |
-
model_kwargs.update(dict(
|
193 |
-
enc_channels=128
|
194 |
-
))
|
195 |
-
celeb_model = create_diffusion_model(**model_kwargs)
|
196 |
-
celeb_model.eval()
|
197 |
|
198 |
device = 'cuda' if th.cuda.is_available() else 'cpu'
|
199 |
-
celeb_model.to(device)
|
200 |
-
|
201 |
-
print(f'loading from {ckpt_path}')
|
202 |
-
checkpoint = th.load(ckpt_path, map_location='cpu')
|
203 |
-
|
204 |
-
celeb_model.load_state_dict(checkpoint)
|
205 |
-
|
206 |
|
|
|
|
|
207 |
|
208 |
MODELS = {
|
209 |
'CLEVR': clevr_model,
|
@@ -222,7 +201,7 @@ with gr.Blocks() as demo:
|
|
222 |
|
223 |
gr.Markdown(
|
224 |
"""<h4>Decomposition and reconstruction of images</h4>""")
|
225 |
-
with gr.Row()
|
226 |
with gr.Column():
|
227 |
with gr.Row():
|
228 |
decomp_input = gr.Image(type='numpy', label='Input')
|
@@ -230,19 +209,21 @@ with gr.Blocks() as demo:
|
|
230 |
decomp_model = gr.Radio(
|
231 |
['CLEVR', 'CelebA-HQ'], type="value", label='Model',
|
232 |
value='CLEVR')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
233 |
|
234 |
with gr.Column():
|
235 |
decomp_output = gr.Image(type='numpy')
|
236 |
decomp_button = gr.Button("Generate")
|
237 |
-
|
238 |
-
|
239 |
-
# image_examples = [os.path.join(os.path.dirname(__file__), 'sample_images/clevr_im_10.png'), 'CLEVR']
|
240 |
-
decomp_examples = [['sample_images/clevr_im_10.png', 'CLEVR'],
|
241 |
-
['sample_images/celebahq_im_15.jpg', 'CelebA-HQ']]
|
242 |
-
decomp_img_examples = gr.Examples(
|
243 |
-
examples=decomp_examples,
|
244 |
-
inputs=[decomp_input, decomp_model]
|
245 |
-
)
|
246 |
|
247 |
|
248 |
gr.Markdown(
|
@@ -260,20 +241,21 @@ with gr.Blocks() as demo:
|
|
260 |
comb_model = gr.Radio(
|
261 |
['CLEVR', 'CelebA-HQ'], type="value", label='Model',
|
262 |
value='CLEVR')
|
263 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
264 |
|
265 |
|
266 |
with gr.Column(scale=1):
|
267 |
comb_output = gr.Image(type='numpy')
|
268 |
comb_button = gr.Button("Generate")
|
269 |
-
|
270 |
-
|
271 |
-
comb_examples = [['sample_images/clevr_im_10.png', 'sample_images/clevr_im_25.png', 'CLEVR'],
|
272 |
-
['sample_images/celebahq_im_15.jpg', 'sample_images/celebahq_im_21.jpg', 'CelebA-HQ']]
|
273 |
-
comb_img_examples = gr.Examples(
|
274 |
-
examples=comb_examples,
|
275 |
-
inputs=[comb_input1, comb_input2, comb_model]
|
276 |
-
)
|
277 |
|
278 |
decomp_button.click(decompose_image_demo,
|
279 |
inputs=[decomp_input, decomp_model],
|
|
|
152 |
|
153 |
def decompose_image_demo(im, model):
|
154 |
sample_method = 'ddim'
|
155 |
+
result = gen_image_and_components(MODELS[model], GD[sample_method], im, sample_method=sample_method, num_images=1, device=device)
|
156 |
return result.permute(1, 2, 0).numpy()
|
157 |
|
158 |
|
159 |
def combine_images_demo(im1, im2, model):
|
160 |
sample_method = 'ddim'
|
161 |
+
result = combine_components_slice(MODELS[model], GD[sample_method], im1, im2, indices='1,0,1,0', sample_method=sample_method, num_images=1, device=device)
|
162 |
return result.permute(1, 2, 0).numpy()
|
163 |
|
164 |
|
165 |
+
def load_model(dataset, extra_kwargs={}, device='cuda'):
|
166 |
+
ckpt_path = download_model(dataset)
|
167 |
|
168 |
+
model_kwargs = unet_model_defaults()
|
169 |
+
# model parameters
|
170 |
+
model_kwargs.update(extra_kwargs)
|
171 |
+
model = create_diffusion_model(**model_kwargs)
|
172 |
+
model.eval()
|
173 |
+
model.to(device)
|
174 |
|
175 |
+
print(f'loading from {ckpt_path}')
|
176 |
+
checkpoint = th.load(ckpt_path, map_location='cpu')
|
177 |
|
178 |
+
model.load_state_dict(checkpoint)
|
179 |
+
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
180 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
181 |
|
182 |
device = 'cuda' if th.cuda.is_available() else 'cpu'
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
183 |
|
184 |
+
clevr_model = load_model('clevr', extra_kwargs=dict(embed_dim=64, enc_channels=128), device=device)
|
185 |
+
celeb_model = load_model('celebahq', extra_kwargs=dict(enc_channels=128), device=device)
|
186 |
|
187 |
MODELS = {
|
188 |
'CLEVR': clevr_model,
|
|
|
201 |
|
202 |
gr.Markdown(
|
203 |
"""<h4>Decomposition and reconstruction of images</h4>""")
|
204 |
+
with gr.Row():
|
205 |
with gr.Column():
|
206 |
with gr.Row():
|
207 |
decomp_input = gr.Image(type='numpy', label='Input')
|
|
|
209 |
decomp_model = gr.Radio(
|
210 |
['CLEVR', 'CelebA-HQ'], type="value", label='Model',
|
211 |
value='CLEVR')
|
212 |
+
|
213 |
+
with gr.Row():
|
214 |
+
|
215 |
+
# image_examples = [os.path.join(os.path.dirname(__file__), 'sample_images/clevr_im_10.png'), 'CLEVR']
|
216 |
+
decomp_examples = [['sample_images/clevr_im_10.png', 'CLEVR'],
|
217 |
+
['sample_images/celebahq_im_15.jpg', 'CelebA-HQ']]
|
218 |
+
decomp_img_examples = gr.Examples(
|
219 |
+
examples=decomp_examples,
|
220 |
+
inputs=[decomp_input, decomp_model]
|
221 |
+
)
|
222 |
|
223 |
with gr.Column():
|
224 |
decomp_output = gr.Image(type='numpy')
|
225 |
decomp_button = gr.Button("Generate")
|
226 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
227 |
|
228 |
|
229 |
gr.Markdown(
|
|
|
241 |
comb_model = gr.Radio(
|
242 |
['CLEVR', 'CelebA-HQ'], type="value", label='Model',
|
243 |
value='CLEVR')
|
244 |
+
|
245 |
+
with gr.Row():
|
246 |
+
|
247 |
+
comb_examples = [['sample_images/clevr_im_10.png', 'sample_images/clevr_im_25.png', 'CLEVR'],
|
248 |
+
['sample_images/celebahq_im_15.jpg', 'sample_images/celebahq_im_21.jpg', 'CelebA-HQ']]
|
249 |
+
comb_img_examples = gr.Examples(
|
250 |
+
examples=comb_examples,
|
251 |
+
inputs=[comb_input1, comb_input2, comb_model]
|
252 |
+
)
|
253 |
|
254 |
|
255 |
with gr.Column(scale=1):
|
256 |
comb_output = gr.Image(type='numpy')
|
257 |
comb_button = gr.Button("Generate")
|
258 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
259 |
|
260 |
decomp_button.click(decompose_image_demo,
|
261 |
inputs=[decomp_input, decomp_model],
|