jsu27 committed on
Commit
4c1d330
·
1 Parent(s): c0e61b8
app.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import torch as th
4
+ from imageio import imread
5
+ from skimage.transform import resize as imresize
6
+
7
+ from ema_pytorch import EMA
8
+ from decomp_diffusion.model_and_diffusion_util import *
9
+ from decomp_diffusion.diffusion.respace import SpacedDiffusion
10
+ from decomp_diffusion.gen_image import *
11
+ from download import download_model
12
+
13
# Seed NumPy's and PyTorch's global RNGs with a fixed value so repeated
# runs start from the same random state.
np.random.seed(0)
th.manual_seed(0)
16
+
17
+ import gradio as gr
18
+
19
+
20
def get_pil_im(im, resolution=64, device='cuda'):
    """Resize an image array and convert it to a batched, channel-first tensor.

    Despite its name, this returns a ``torch`` tensor rather than a PIL image.

    Args:
        im: Image array indexed as (height, width, channels); it must be
            three-dimensional because the channel axis is sliced.
        resolution: Side length of the square the image is resized to.
        device: Device the resulting tensor is moved to. Defaults to
            ``'cuda'``, which matches the previous hard-coded behaviour.

    Returns:
        A tensor of shape ``(1, C, resolution, resolution)`` on ``device``,
        where ``C`` is at most 3 (extra channels such as alpha are dropped).
    """
    # Keep only the first three channels (drops e.g. an alpha channel).
    resized = imresize(im, (resolution, resolution))[:, :, :3]
    # HWC -> CHW, then prepend a batch axis of size 1.
    chw = th.Tensor(resized).permute(2, 0, 1)
    return chw[None, :, :, :].contiguous().to(device)
24
+
25
+
26
+ # generate image components and reconstruction
27
def gen_image_and_components(model, gd, im, num_components=4, sample_method='ddim', batch_size=1, image_size=64, device='cuda', num_images=1):
    """Generate a grid of the original image, its components and its reconstruction.

    Each grid row holds the original image, one sample per latent component
    (``latent_index`` 0 .. ``num_components - 1``) and one reconstruction
    sample (``latent_index=None``). One row is produced per requested image;
    with the default ``num_images=1`` the result is a single row.

    Args:
        model: Model exposing ``encode_latent``; it is also passed to the
            sampling loop.
        gd: Diffusion object exposing ``p_sample_loop``, ``ddim_sample_loop``
            and ``_wrap_model``.
        im: Input image array, forwarded to ``get_pil_im``.
        num_components: Number of individual latent components to sample.
        sample_method: Either ``'ddpm'`` or ``'ddim'``.
        batch_size: Number of samples drawn per sampling call.
        image_size: Side length used for resizing and sampling.
        device: Device passed to the sampling loop.
        num_images: Number of grid rows to generate; must be at least 1.

    Returns:
        The image grid built by ``make_grid`` from the CPU copies of all
        samples.

    Raises:
        ValueError: If ``sample_method`` is not supported or ``num_images``
            is smaller than 1.
    """
    # Validate up front with real exceptions: asserts vanish under ``python -O``.
    if sample_method not in ('ddpm', 'ddim'):
        raise ValueError(
            f"sample_method must be 'ddpm' or 'ddim', got {sample_method!r}"
        )
    if num_images < 1:
        raise ValueError(f"num_images must be at least 1, got {num_images}")

    orig_img = get_pil_im(im, resolution=image_size)
    # Inference only: do not record an autograd graph for the encoder pass.
    with th.no_grad():
        latent = model.encode_latent(orig_img)
    model_kwargs = {'latent': latent}

    if sample_method == 'ddpm':
        sample_loop_func = gd.p_sample_loop
    else:
        sample_loop_func = gd.ddim_sample_loop
        # The encoder above needs the raw model; only the sampler gets the
        # wrapped one.
        model = gd._wrap_model(model)

    # Individual components first, then ``None`` for the reconstruction.
    latent_indices = [*range(num_components), None]

    all_samples = []
    for _ in range(num_images):
        all_samples.append(orig_img)
        for latent_index in latent_indices:
            model_kwargs['latent_index'] = latent_index
            sample = sample_loop_func(
                model,
                (batch_size, 3, image_size, image_size),
                device=device,
                clip_denoised=True,
                progress=True,
                model_kwargs=model_kwargs,
                cond_fn=None,
            )[:batch_size]
            all_samples.append(sample)

    samples = th.cat(all_samples, dim=0).cpu()
    # Every image contributes the same number of tiles, so this is exact and
    # equals ``samples.shape[0]`` when ``num_images == 1``.
    tiles_per_row = samples.shape[0] // num_images
    grid = make_grid(samples, nrow=tiles_per_row, padding=0)
    return grid
73
+
74
+
75
def decompose_image(im):
    """Decompose ``im`` with the CLEVR model and return the result grid.

    Runs ``gen_image_and_components`` with DDIM sampling for a single image
    and returns the grid as a NumPy array with the first axis moved last.
    """
    method = 'ddim'
    grid = gen_image_and_components(
        clevr_model,
        GD[method],
        im,
        sample_method=method,
        num_images=1,
    )
    # Move axis 0 to the end (channel-first -> channel-last) before
    # handing the array back.
    channel_last = grid.permute(1, 2, 0)
    return channel_last.numpy()
79
+
80
+
81
# Build the diffusion objects, keyed by sampling method ('ddpm' / 'ddim').
diffusion_kwargs = diffusion_defaults()
GD = {'ddpm': create_gaussian_diffusion(**diffusion_kwargs)}

# DDIM setup: sample on an evenly strided subset of the full schedule.
desired_timesteps = 50
# 'steps' and 'noise_schedule' are consumed here and removed from the kwargs;
# the spaced diffusion below is given explicit betas instead.
num_timesteps = diffusion_kwargs.pop('steps')
betas = get_named_beta_schedule(diffusion_kwargs.pop('noise_schedule'), num_timesteps)
diffusion_kwargs['betas'] = betas

spacing = num_timesteps // desired_timesteps
# NOTE(review): the `+ 1` puts `num_timesteps` itself in the list whenever
# `spacing` divides it evenly -- confirm SpacedDiffusion tolerates that index.
spaced_ts = list(range(0, num_timesteps + 1, spacing))

gd = SpacedDiffusion(
    spaced_ts,
    rescale_timesteps=True,
    original_num_steps=num_timesteps,
    **diffusion_kwargs,
)
GD['ddim'] = gd
99
+
100
+
101
# Checkpoint source: https://www.dropbox.com/s/bqpc3ymstz9m05z/clevr_model.pt
# download_model returns the local checkpoint path, fetching it if needed.
ckpt_path = download_model('clevr')  # 'clevr_model.pt'

# Start from the default U-Net arguments and override the model parameters.
model_kwargs = unet_model_defaults()
model_kwargs.update(emb_dim=64, enc_channels=128)

clevr_model = create_diffusion_model(**model_kwargs)
clevr_model.eval()

device = 'cuda'
clevr_model.to(device)

print(f'loading from {ckpt_path}')
# NOTE(review): th.load unpickles the file, and this checkpoint comes from the
# network -- consider restricting it to tensors (e.g. weights_only) if the
# installed torch version supports that.
checkpoint = th.load(ckpt_path, map_location='cpu')
clevr_model.load_state_dict(checkpoint)
122
+
123
+
124
+
125
# Gradio UI: one NumPy image in, one NumPy image (the decomposition grid) out.
img_input = gr.inputs.Image(type="numpy", label="Input")
img_output = gr.outputs.Image(type="numpy", label="Output")

# Example inputs are resolved relative to this file's directory.
app_dir = os.path.dirname(__file__)
example_images = [
    os.path.join(app_dir, relative_path)
    for relative_path in (
        "sample_images/clevr_im_10.png",
        "sample_images/clevr_im_25.png",
    )
]

demo = gr.Interface(
    decompose_image,
    inputs=img_input,
    outputs=img_output,
    examples=example_images,
)
demo.launch()
download.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import os
3
+ from typing import Optional
4
+
5
+ import requests
6
+ from tqdm.auto import tqdm
7
+
8
# Download URL of the pretrained checkpoint for each dataset.
# An empty string means no checkpoint URL is configured.
MODEL_PATHS = {
    "clevr": "https://www.dropbox.com/s/bqpc3ymstz9m05z/clevr_model.pt",
    "celebahq": ""
}

# Download URL of each dataset (none configured here).
DATA_PATHS = {
    "clevr": "",
    "clevr_toy": ""
}


def download_model(
    dataset: str,
    cache_dir: Optional[str] = None,
    chunk_size: int = 4096,
) -> str:
    """Return the local path of a dataset's checkpoint, downloading it if needed.

    The file name is the last path segment of the configured URL with any
    query string (such as ``?dl=0``) removed. If that file already exists in
    ``cache_dir`` it is reused without touching the network.

    Args:
        dataset: Key into ``MODEL_PATHS``.
        cache_dir: Directory the checkpoint is stored in; defaults to ``'./'``.
            It is created if it does not exist.
        chunk_size: Number of bytes requested per streamed chunk.

    Returns:
        Path of the checkpoint file inside ``cache_dir``.

    Raises:
        ValueError: If ``dataset`` is unknown, or its URL is empty or has no
            file name.
        requests.HTTPError: If the server answers with an error status.
    """
    if dataset not in MODEL_PATHS:
        raise ValueError(
            f"Unknown dataset name {dataset}. Known names are: {MODEL_PATHS.keys()}."
        )
    url = MODEL_PATHS[dataset]
    # With an empty URL the file name would be empty too, and the "cached
    # file" check below would match the cache directory itself.
    if not url:
        raise ValueError(f"No checkpoint URL is configured for dataset {dataset}.")
    filename = url.split("/")[-1].split("?", 1)[0]
    if not filename:
        raise ValueError(f"Cannot derive a file name from URL {url}.")

    if cache_dir is None:
        cache_dir = './'
    os.makedirs(cache_dir, exist_ok=True)
    local_path = os.path.join(cache_dir, filename)
    if os.path.exists(local_path):
        return local_path

    # Stream into a side file and move it into place only once complete, so an
    # interrupted download is never mistaken for a cached checkpoint.
    partial_path = local_path + '.part'
    # Presumably the wget user agent makes the host serve the raw file rather
    # than an HTML preview page -- TODO confirm.
    headers = {'user-agent': 'Wget/1.16 (linux-gnu)'}
    with requests.get(url, stream=True, headers=headers) as r:
        # Fail loudly instead of saving an error page as the checkpoint.
        r.raise_for_status()
        size = int(r.headers.get("content-length", "0"))
        with open(partial_path, 'wb') as f, \
                tqdm(total=size, unit="iB", unit_scale=True) as pbar:
            for chunk in r.iter_content(chunk_size=chunk_size):
                if chunk:
                    pbar.update(len(chunk))
                    f.write(chunk)
    os.replace(partial_path, local_path)
    return local_path
requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ git+https://github.com/jsu27/decomp_diffusion.git
2
+ scikit-image==0.19.2
sample_images/clevr_im_10.png ADDED
sample_images/clevr_im_25.png ADDED