Mariam-Elz committed
Commit dce5fa6 · verified · 1 Parent(s): 472c525

Upload imagedream/ldm/models/diffusion/ddim.py with huggingface_hub

imagedream/ldm/models/diffusion/ddim.py ADDED
@@ -0,0 +1,430 @@
"""SAMPLING ONLY."""

import torch
import numpy as np
from tqdm import tqdm
from functools import partial

from ...modules.diffusionmodules.util import (
    make_ddim_sampling_parameters,
    make_ddim_timesteps,
    noise_like,
    extract_into_tensor,
)


class DDIMSampler(object):
    def __init__(self, model, schedule="linear", **kwargs):
        super().__init__()
        self.model = model
        self.ddpm_num_timesteps = model.num_timesteps
        self.schedule = schedule

    def register_buffer(self, name, attr):
        # Keep tensor buffers on the GPU when one is available; guard the
        # transfer so CPU-only machines do not crash.
        if isinstance(attr, torch.Tensor):
            if torch.cuda.is_available() and attr.device != torch.device("cuda"):
                attr = attr.to(torch.device("cuda"))
        setattr(self, name, attr)

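    # make_schedule precomputes the DDIM timestep subsequence and the
    # alpha/sigma lookup tables that p_sample_ddim indexes below.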
    def make_schedule(
        self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0, verbose=True
    ):
        self.ddim_timesteps = make_ddim_timesteps(
            ddim_discr_method=ddim_discretize,
            num_ddim_timesteps=ddim_num_steps,
            num_ddpm_timesteps=self.ddpm_num_timesteps,
            verbose=verbose,
        )
        alphas_cumprod = self.model.alphas_cumprod
        assert (
            alphas_cumprod.shape[0] == self.ddpm_num_timesteps
        ), "alphas have to be defined for each timestep"
        to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)

        self.register_buffer("betas", to_torch(self.model.betas))
        self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
        self.register_buffer(
            "alphas_cumprod_prev", to_torch(self.model.alphas_cumprod_prev)
        )

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer(
            "sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod.cpu()))
        )
        self.register_buffer(
            "sqrt_one_minus_alphas_cumprod",
            to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())),
        )
        self.register_buffer(
            "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod.cpu()))
        )
        self.register_buffer(
            "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod.cpu()))
        )
        self.register_buffer(
            "sqrt_recipm1_alphas_cumprod",
            to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)),
        )

        # ddim sampling parameters
        ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(
            alphacums=alphas_cumprod.cpu(),
            ddim_timesteps=self.ddim_timesteps,
            eta=ddim_eta,
            verbose=verbose,
        )
        self.register_buffer("ddim_sigmas", ddim_sigmas)
        self.register_buffer("ddim_alphas", ddim_alphas)
        self.register_buffer("ddim_alphas_prev", ddim_alphas_prev)
        self.register_buffer("ddim_sqrt_one_minus_alphas", np.sqrt(1.0 - ddim_alphas))
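        # DDIM sigma schedule (Song et al. 2020, Eq. 16):
        #   sigma_t = eta * sqrt((1 - a_prev) / (1 - a_t)) * sqrt(1 - a_t / a_prev)
        # eta = 0 gives deterministic DDIM; eta = 1 recovers DDPM-like noise.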
        sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
            (1 - self.alphas_cumprod_prev)
            / (1 - self.alphas_cumprod)
            * (1 - self.alphas_cumprod / self.alphas_cumprod_prev)
        )
        self.register_buffer(
            "ddim_sigmas_for_original_num_steps", sigmas_for_original_sampling_steps
        )

    @torch.no_grad()
    def sample(
        self,
        S,
        batch_size,
        shape,
        conditioning=None,
        callback=None,
        normals_sequence=None,
        img_callback=None,
        quantize_x0=False,
        eta=0.0,
        mask=None,
        x0=None,
        temperature=1.0,
        noise_dropout=0.0,
        score_corrector=None,
        corrector_kwargs=None,
        verbose=True,
        x_T=None,
        log_every_t=100,
        unconditional_guidance_scale=1.0,
        unconditional_conditioning=None,
        # must come in the same format as the conditioning, e.g. as encoded tokens
        **kwargs,
    ):
        if conditioning is not None:
            if isinstance(conditioning, dict):
                cbs = conditioning[list(conditioning.keys())[0]].shape[0]
                if cbs != batch_size:
                    print(
                        f"Warning: Got {cbs} conditionings but batch-size is {batch_size}"
                    )
            else:
                if conditioning.shape[0] != batch_size:
                    print(
                        f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}"
                    )

        self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
        # sampling
        C, H, W = shape
        size = (batch_size, C, H, W)

        samples, intermediates = self.ddim_sampling(
            conditioning,
            size,
            callback=callback,
            img_callback=img_callback,
            quantize_denoised=quantize_x0,
            mask=mask,
            x0=x0,
            ddim_use_original_steps=False,
            noise_dropout=noise_dropout,
            temperature=temperature,
            score_corrector=score_corrector,
            corrector_kwargs=corrector_kwargs,
            x_T=x_T,
            log_every_t=log_every_t,
            unconditional_guidance_scale=unconditional_guidance_scale,
            unconditional_conditioning=unconditional_conditioning,
            **kwargs,
        )
        return samples, intermediates

    @torch.no_grad()
    def ddim_sampling(
        self,
        cond,
        shape,
        x_T=None,
        ddim_use_original_steps=False,
        callback=None,
        timesteps=None,
        quantize_denoised=False,
        mask=None,
        x0=None,
        img_callback=None,
        log_every_t=100,
        temperature=1.0,
        noise_dropout=0.0,
        score_corrector=None,
        corrector_kwargs=None,
        unconditional_guidance_scale=1.0,
        unconditional_conditioning=None,
        **kwargs,
    ):
        """
        Example argument values observed at inference time:
        cond.keys(): dict_keys(['context', 'camera', 'num_frames', 'ip', 'ip_img'])
        shape: (5, 4, 32, 32)
        x_T: None
        ddim_use_original_steps: False
        timesteps: None
        callback: None
        quantize_denoised: False
        mask: None
        img_callback: None
        log_every_t: 100
        temperature: 1.0
        noise_dropout: 0.0
        score_corrector: None
        corrector_kwargs: None
        unconditional_guidance_scale: 5
        unconditional_conditioning.keys(): dict_keys(['context', 'camera', 'num_frames', 'ip', 'ip_img'])
        kwargs: {}
        """
        device = self.model.betas.device
        b = shape[0]
        if x_T is None:
            img = torch.randn(shape, device=device)  # e.g. torch.Size([5, 4, 32, 32])
        else:
            img = x_T

        if timesteps is None:  # the usual path; analogous to set_timesteps in HF diffusers
            timesteps = (
                self.ddpm_num_timesteps
                if ddim_use_original_steps
                else self.ddim_timesteps
            )
        elif not ddim_use_original_steps:
            # an explicit step count was passed: truncate the DDIM schedule
            subset_end = (
                int(
                    min(timesteps / self.ddim_timesteps.shape[0], 1)
                    * self.ddim_timesteps.shape[0]
                )
                - 1
            )
            timesteps = self.ddim_timesteps[:subset_end]

        intermediates = {"x_inter": [img], "pred_x0": [img]}
        time_range = (  # timesteps, reversed
            reversed(range(0, timesteps))
            if ddim_use_original_steps
            else np.flip(timesteps)
        )
        total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
        iterator = tqdm(time_range, desc="DDIM Sampler", total=total_steps)
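        # Walk the schedule from t = T down to t = 0; `index` addresses the
        # DDIM alpha/sigma tables, which are stored in increasing-t order.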
        for i, step in enumerate(iterator):
            index = total_steps - i - 1
            ts = torch.full((b,), step, device=device, dtype=torch.long)

            if mask is not None:
                assert x0 is not None
                img_orig = self.model.q_sample(
                    x0, ts
                )  # TODO: deterministic forward pass?
                img = img_orig * mask + (1.0 - mask) * img

            outs = self.p_sample_ddim(
                img,
                cond,
                ts,
                index=index,
                use_original_steps=ddim_use_original_steps,
                quantize_denoised=quantize_denoised,
                temperature=temperature,
                noise_dropout=noise_dropout,
                score_corrector=score_corrector,
                corrector_kwargs=corrector_kwargs,
                unconditional_guidance_scale=unconditional_guidance_scale,
                unconditional_conditioning=unconditional_conditioning,
                **kwargs,
            )
            img, pred_x0 = outs
            if callback:
                callback(i)
            if img_callback:
                img_callback(pred_x0, i)

            if index % log_every_t == 0 or index == total_steps - 1:
                intermediates["x_inter"].append(img)
                intermediates["pred_x0"].append(pred_x0)

        return img, intermediates

    @torch.no_grad()
    def p_sample_ddim(
        self,
        x,
        c,
        t,
        index,
        repeat_noise=False,
        use_original_steps=False,
        quantize_denoised=False,
        temperature=1.0,
        noise_dropout=0.0,
        score_corrector=None,
        corrector_kwargs=None,
        unconditional_guidance_scale=1.0,
        unconditional_conditioning=None,
        dynamic_threshold=None,
        **kwargs,
    ):
        b, *_, device = *x.shape, x.device

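        # Classifier-free guidance: run the unconditional and conditional
        # batches through the model in one forward pass, then extrapolate
        # the prediction: out = out_uncond + scale * (out_cond - out_uncond).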
        if unconditional_conditioning is None or unconditional_guidance_scale == 1.0:
            model_output = self.model.apply_model(x, t, c)
        else:
            x_in = torch.cat([x] * 2)
            t_in = torch.cat([t] * 2)
            if isinstance(c, dict):
                assert isinstance(unconditional_conditioning, dict)
                c_in = dict()
                for k in c:
                    if isinstance(c[k], list):
                        c_in[k] = [
                            torch.cat([unconditional_conditioning[k][i], c[k][i]])
                            for i in range(len(c[k]))
                        ]
                    elif isinstance(c[k], torch.Tensor):
                        c_in[k] = torch.cat([unconditional_conditioning[k], c[k]])
                    else:
                        assert c[k] == unconditional_conditioning[k]
                        c_in[k] = c[k]
            elif isinstance(c, list):
                c_in = list()
                assert isinstance(unconditional_conditioning, list)
                for i in range(len(c)):
                    c_in.append(torch.cat([unconditional_conditioning[i], c[i]]))
            else:
                c_in = torch.cat([unconditional_conditioning, c])
            model_uncond, model_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
            model_output = model_uncond + unconditional_guidance_scale * (
                model_t - model_uncond
            )

        if self.model.parameterization == "v":
            # v-prediction: convert the model output back to an eps estimate
            e_t = self.model.predict_eps_from_z_and_v(x, t, model_output)
        else:
            e_t = model_output

        if score_corrector is not None:
            assert self.model.parameterization == "eps", "not implemented"
            e_t = score_corrector.modify_score(
                self.model, e_t, x, t, c, **corrector_kwargs
            )

        alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
        alphas_prev = (
            self.model.alphas_cumprod_prev
            if use_original_steps
            else self.ddim_alphas_prev
        )
        sqrt_one_minus_alphas = (
            self.model.sqrt_one_minus_alphas_cumprod
            if use_original_steps
            else self.ddim_sqrt_one_minus_alphas
        )
        sigmas = (
            self.model.ddim_sigmas_for_original_num_steps
            if use_original_steps
            else self.ddim_sigmas
        )
        # select parameters corresponding to the currently considered timestep
        a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
        a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
        sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
        sqrt_one_minus_at = torch.full(
            (b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device
        )

        # current prediction for x_0
        if self.model.parameterization != "v":
            pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
        else:
            pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)

        if quantize_denoised:
            pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)

        if dynamic_threshold is not None:
            raise NotImplementedError()

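        # DDIM update (Song et al. 2020, Eq. 12):
        #   x_{t-1} = sqrt(a_prev) * pred_x0
        #             + sqrt(1 - a_prev - sigma_t**2) * e_t
        #             + sigma_t * z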
        # direction pointing to x_t
        dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t
        noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
        if noise_dropout > 0.0:
            noise = torch.nn.functional.dropout(noise, p=noise_dropout)
        x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
        return x_prev, pred_x0

    @torch.no_grad()
    def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
        # fast, but does not allow for exact reconstruction
        # t serves as an index to gather the correct alphas
        if use_original_steps:
            sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
            sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
        else:
            sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
            sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas

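        # Forward diffusion q(x_t | x_0):
        #   x_t = sqrt(alphas_cumprod_t) * x_0 + sqrt(1 - alphas_cumprod_t) * noise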
        if noise is None:
            noise = torch.randn_like(x0)
        return (
            extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0
            + extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise
        )

    @torch.no_grad()
    def decode(
        self,
        x_latent,
        cond,
        t_start,
        unconditional_guidance_scale=1.0,
        unconditional_conditioning=None,
        use_original_steps=False,
        **kwargs,
    ):
        timesteps = (
            np.arange(self.ddpm_num_timesteps)
            if use_original_steps
            else self.ddim_timesteps
        )
        timesteps = timesteps[:t_start]

        time_range = np.flip(timesteps)
        total_steps = timesteps.shape[0]

        iterator = tqdm(time_range, desc="Decoding image", total=total_steps)
        x_dec = x_latent
        for i, step in enumerate(iterator):
            index = total_steps - i - 1
            ts = torch.full(
                (x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long
            )
            x_dec, _ = self.p_sample_ddim(
                x_dec,
                cond,
                ts,
                index=index,
                use_original_steps=use_original_steps,
                unconditional_guidance_scale=unconditional_guidance_scale,
                unconditional_conditioning=unconditional_conditioning,
                **kwargs,
            )
        return x_dec
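
A minimal usage sketch, assuming `model` is an already-loaded ImageDream latent-diffusion model (exposing num_timesteps, alphas_cumprod, and apply_model) and `cond` / `uncond` are conditioning dicts in the format shown in the ddim_sampling docstring:

sampler = DDIMSampler(model)
samples, intermediates = sampler.sample(
    S=50,                              # number of DDIM steps
    batch_size=5,                      # e.g. 4 camera views + 1 reference image frame
    shape=(4, 32, 32),                 # latent (C, H, W)
    conditioning=cond,
    unconditional_guidance_scale=5.0,
    unconditional_conditioning=uncond,
    eta=0.0,                           # eta = 0 -> deterministic DDIM
)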