Varshitha2317 committed on
Commit
a6131d4
·
verified ·
1 Parent(s): 9bb90a4

Create tdd_svd_scheduler.py

Browse files
Files changed (1) hide show
  1. tdd_svd_scheduler.py +472 -0
tdd_svd_scheduler.py ADDED
@@ -0,0 +1,472 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from dataclasses import dataclass
16
+ from typing import List, Optional, Tuple, Union
17
+
18
+ import numpy as np
19
+ import torch
20
+
21
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
22
+ from diffusers.utils import BaseOutput, logging
23
+ from diffusers.utils.torch_utils import randn_tensor
24
+ from diffusers.schedulers.scheduling_utils import SchedulerMixin
25
+
26
+
27
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
28
+
29
+
30
@dataclass
class TDDSVDStochasticIterativeSchedulerOutput(BaseOutput):
    """
    Container returned by the scheduler's `step` method when `return_dict=True`.

    Args:
        prev_sample (`torch.FloatTensor`):
            The sample produced by the step (the `x_{t-1}` of the denoising loop). It has the same shape as the
            `sample` passed to `step` and is meant to be fed back as the next model input.
    """

    prev_sample: torch.FloatTensor
41
+
42
+
43
class TDDSVDStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
    """
    Multistep and onestep stochastic sampling on a Karras sigma schedule, with an `eta`-controlled intermediate
    target noise level.

    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
    methods the library implements for all schedulers such as loading and saving.

    Args:
        num_train_timesteps (`int`, defaults to 40):
            The number of diffusion steps to train the model.
        sigma_min (`float`, defaults to 0.002):
            Minimum noise magnitude in the sigma schedule.
        sigma_max (`float`, defaults to 80.0):
            Maximum noise magnitude in the sigma schedule.
        sigma_data (`float`, defaults to 0.5):
            The standard deviation of the data distribution from the EDM
            [paper](https://huggingface.co/papers/2206.00364).
        s_noise (`float`, defaults to 1.0):
            Multiplier applied to the noise that `step` injects.
        rho (`float`, defaults to 7.0):
            The parameter for calculating the Karras sigma schedule from the EDM
            [paper](https://huggingface.co/papers/2206.00364).
        clip_denoised (`bool`, defaults to `True`):
            Whether to clip the denoised outputs to `(-1, 1)`.
        eta (`float`, defaults to 0.3):
            Value in `[0, 1]` used by `step` to pick the intermediate sigma the sample is first moved to. `0`
            selects the sigma of the next inference step, `1` selects the sigma of the last training index, and
            values in between interpolate the training index linearly. Can be changed later with `set_eta`.
    """

    order = 1

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 40,
        sigma_min: float = 0.002,
        sigma_max: float = 80.0,
        sigma_data: float = 0.5,
        s_noise: float = 1.0,
        rho: float = 7.0,
        clip_denoised: bool = True,
        eta: float = 0.3,
    ):
        # standard deviation of the initial noise distribution
        self.init_noise_sigma = (sigma_max**2 + 1) ** 0.5

        # Full training schedule: one Karras sigma per training index plus a terminal sigma of 0.
        ramp = np.linspace(0, 1, num_train_timesteps)
        sigmas = self._convert_to_karras(ramp)
        sigmas = np.concatenate([sigmas, np.array([0])])
        timesteps = self.sigma_to_t(sigmas)

        # setable values
        self.num_inference_steps = None
        # `torch.from_numpy` already yields CPU tensors, so no explicit move is needed here.
        self.sigmas = torch.from_numpy(sigmas)
        self.timesteps = torch.from_numpy(timesteps)
        self.custom_timesteps = False
        self.is_scale_input_called = False
        self._step_index = None
        # Training-schedule index of every entry of the current schedule. `step` reads it, so it has to exist
        # even when `set_timesteps` has not been called yet.
        self.original_indices = np.arange(num_train_timesteps, dtype=np.int64)

        self.set_eta(eta)
        # Snapshot of the full training schedule; `step` looks up the intermediate sigma in `original_sigmas`.
        self.original_timesteps = self.timesteps.clone()
        self.original_sigmas = self.sigmas.clone()

    def index_for_timestep(self, timestep, schedule_timesteps=None):
        """
        Returns the position of `timestep` inside `schedule_timesteps` (defaults to `self.timesteps`).

        The lookup uses exact equality and `.item()`, so it raises if `timestep` does not occur exactly once.
        """
        if schedule_timesteps is None:
            schedule_timesteps = self.timesteps

        indices = (schedule_timesteps == timestep).nonzero()
        return indices.item()

    @property
    def step_index(self):
        """
        The index counter for the current timestep. It increases by 1 after each scheduler step and is `None` until
        the first call to `scale_model_input` or `step` after `set_timesteps`.
        """
        return self._step_index

    def scale_model_input(
        self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
    ) -> torch.FloatTensor:
        """
        Scales the model input by dividing it by `(sigma**2 + sigma_data**2) ** 0.5`.

        Args:
            sample (`torch.FloatTensor`):
                The input sample.
            timestep (`float` or `torch.FloatTensor`):
                The current timestep in the diffusion chain. Only used to initialise the step index on the first
                call; afterwards the internal step index selects the sigma.

        Returns:
            `torch.FloatTensor`:
                A scaled input sample.
        """
        if self.step_index is None:
            self._init_step_index(timestep)

        sigma = self.sigmas[self.step_index]
        sample = sample / ((sigma**2 + self.config.sigma_data**2) ** 0.5)

        self.is_scale_input_called = True
        return sample

    def sigma_to_t(self, sigmas: Union[float, np.ndarray]):
        """
        Gets scaled timesteps from the Karras sigmas for input to the model.

        Args:
            sigmas (`float` or `np.ndarray`):
                A single Karras sigma or an array of Karras sigmas.

        Returns:
            `np.ndarray`:
                `0.25 * log(sigma)` for every input sigma.
        """
        if not isinstance(sigmas, np.ndarray):
            sigmas = np.array(sigmas, dtype=np.float64)

        # The tiny offset keeps the logarithm finite for the terminal sigma of 0.
        timesteps = 0.25 * np.log(sigmas + 1e-44)

        return timesteps

    def set_timesteps(
        self,
        num_inference_steps: Optional[int] = None,
        device: Optional[Union[str, torch.device]] = None,
        timesteps: Optional[List[int]] = None,
    ):
        """
        Sets the timesteps used for the diffusion chain (to be run before inference).

        Args:
            num_inference_steps (`int`, *optional*):
                The number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
            timesteps (`List[int]`, *optional*):
                Custom training-schedule indices, given in strictly descending order, used to support arbitrary
                spacing between timesteps. If `None`, equally spaced indices are used. If `timesteps` is passed,
                `num_inference_steps` must be `None`.

        Raises:
            `ValueError`: If both or neither of `num_inference_steps` and `timesteps` are given, or if the given
            value is out of range.
        """
        if num_inference_steps is None and timesteps is None:
            raise ValueError(
                "Exactly one of `num_inference_steps` or `timesteps` must be supplied."
            )

        if num_inference_steps is not None and timesteps is not None:
            raise ValueError(
                "Can only pass one of `num_inference_steps` or `timesteps`."
            )

        num_train_timesteps = self.config.num_train_timesteps

        # Follow DDPMScheduler custom timesteps logic
        if timesteps is not None:
            if len(timesteps) == 0:
                raise ValueError("`timesteps` must contain at least one entry.")

            if any(later >= earlier for earlier, later in zip(timesteps, timesteps[1:])):
                raise ValueError("`timesteps` must be in descending order.")

            if timesteps[0] >= num_train_timesteps:
                raise ValueError(
                    f"`timesteps` must start before `self.config.train_timesteps`:"
                    f" {num_train_timesteps}."
                )

            if timesteps[-1] < 0:
                raise ValueError("`timesteps` must be non-negative.")

            # Index 0 maps to `sigma_max`, so the descending user input is flipped to ascending indices. This
            # yields a decreasing sigma schedule and matches the ordering the default branch produces for
            # `original_indices`.
            timesteps = np.array(timesteps, dtype=np.int64)[::-1].copy()
            self.num_inference_steps = len(timesteps)
            self.custom_timesteps = True
        else:
            if num_inference_steps <= 0:
                raise ValueError(
                    f"`num_inference_steps` must be a positive integer, got {num_inference_steps}."
                )

            if num_inference_steps > num_train_timesteps:
                raise ValueError(
                    f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
                    f" {num_train_timesteps} as the unet model trained with this scheduler can only handle"
                    f" maximal {num_train_timesteps} timesteps."
                )

            self.num_inference_steps = num_inference_steps

            # Equally spaced training indices in ascending order: 0, step_ratio, 2 * step_ratio, ...
            step_ratio = num_train_timesteps // num_inference_steps
            timesteps = np.arange(num_inference_steps, dtype=np.int64) * step_ratio
            self.custom_timesteps = False

        self.original_indices = timesteps

        # Map timesteps to Karras sigmas directly for multistep sampling
        # See https://github.com/openai/consistency_models/blob/main/cm/karras_diffusion.py#L675
        # `max(..., 1)` avoids a division by zero for a single-step training schedule.
        ramp = timesteps / max(num_train_timesteps - 1, 1)
        sigmas = self._convert_to_karras(ramp)
        timesteps = self.sigma_to_t(sigmas)

        sigmas = np.concatenate([sigmas, [0]]).astype(np.float32)
        # `Tensor.to` is not in-place, so the result must be assigned. The sigmas are kept on the CPU to avoid
        # too much CPU/GPU communication; only the timesteps are moved to `device`.
        self.sigmas = torch.from_numpy(sigmas).to("cpu")

        if str(device).startswith("mps"):
            # mps does not support float64
            self.timesteps = torch.from_numpy(timesteps).to(device, dtype=torch.float32)
        else:
            self.timesteps = torch.from_numpy(timesteps).to(device=device)

        self._step_index = None

    # Modified _convert_to_karras implementation that takes in ramp as argument
    def _convert_to_karras(self, ramp):
        """
        Constructs the noise schedule of Karras et al. (2022).

        `ramp` holds values where `0` maps to `sigma_max` and `1` maps to `sigma_min`; the interpolation happens in
        `sigma ** (1 / rho)` space.
        """
        sigma_min: float = self.config.sigma_min
        sigma_max: float = self.config.sigma_max

        rho = self.config.rho
        min_inv_rho = sigma_min ** (1 / rho)
        max_inv_rho = sigma_max ** (1 / rho)
        sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
        return sigmas

    def get_scalings(self, sigma):
        """
        Returns the pair `(c_skip, c_out)` for noise level `sigma`.

        `c_skip = sigma_data**2 / (sigma**2 + sigma_data**2)` weights the current sample and
        `c_out = -sigma * sigma_data / (sigma**2 + sigma_data**2) ** 0.5` weights the model output.
        """
        sigma_data = self.config.sigma_data

        c_skip = sigma_data**2 / (sigma**2 + sigma_data**2)
        c_out = -sigma * sigma_data / (sigma**2 + sigma_data**2) ** 0.5
        return c_skip, c_out

    def get_scalings_for_boundary_condition(self, sigma):
        """
        Gets the scalings `step` uses to turn the model output into a denoised sample.

        In this scheduler the values are identical to those of `get_scalings`: no `sigma_min` offset is applied.

        Args:
            sigma (`torch.FloatTensor`):
                The current sigma in the Karras sigma schedule.

        Returns:
            `tuple`:
                A two-element tuple where `c_skip` (which weights the current sample) is the first element and `c_out`
                (which weights the model output) is the second element.
        """
        return self.get_scalings(sigma)

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
    def _init_step_index(self, timestep):
        """Initialises the internal step index from the position of `timestep` in `self.timesteps`."""
        if isinstance(timestep, torch.Tensor):
            timestep = timestep.to(self.timesteps.device)

        index_candidates = (self.timesteps == timestep).nonzero()

        # The sigma index that is taken for the **very** first `step`
        # is always the second index (or the last index if there is only 1)
        # This way we can ensure we don't accidentally skip a sigma in
        # case we start in the middle of the denoising schedule (e.g. for image-to-image)
        if len(index_candidates) > 1:
            step_index = index_candidates[1]
        else:
            step_index = index_candidates[0]

        self._step_index = step_index.item()

    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: Union[float, torch.FloatTensor],
        sample: torch.FloatTensor,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
    ) -> Union[TDDSVDStochasticIterativeSchedulerOutput, Tuple]:
        """
        Advances the sample by one step of the sigma schedule.

        The sample is denoised, moved along the resulting direction to an intermediate sigma chosen by `eta`, and
        then noise scaled by the gap between the next sigma and that intermediate sigma is added.

        Args:
            model_output (`torch.FloatTensor`):
                The direct output from the learned diffusion model.
            timestep (`float` or `torch.FloatTensor`):
                The current timestep in the diffusion chain; must be one of `scheduler.timesteps`.
            sample (`torch.FloatTensor`):
                A current instance of a sample created by the diffusion process.
            generator (`torch.Generator`, *optional*):
                A random number generator.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`TDDSVDStochasticIterativeSchedulerOutput`] or `tuple`.

        Returns:
            [`TDDSVDStochasticIterativeSchedulerOutput`] or `tuple`:
                If return_dict is `True`, [`TDDSVDStochasticIterativeSchedulerOutput`] is returned, otherwise a
                tuple is returned where the first element is the sample tensor.

        Raises:
            `ValueError`: If `timestep` is an integer index instead of a value from `scheduler.timesteps`.
        """
        if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)):
            raise ValueError(
                (
                    "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
                    f" `{self.__class__}.step()` is not supported. Make sure to pass"
                    " one of the `scheduler.timesteps` as a timestep."
                ),
            )

        if not self.is_scale_input_called:
            logger.warning(
                "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
                "See `StableDiffusionPipeline` for a usage example."
            )

        if self.step_index is None:
            self._init_step_index(timestep)

        # sigma_next corresponds to next_t in original implementation
        next_step_index = self.step_index + 1

        sigma = self.sigmas[self.step_index]
        if next_step_index < len(self.sigmas):
            sigma_next = self.sigmas[next_step_index]
        else:
            # Past the end of the schedule: fall back to the terminal sigma.
            sigma_next = self.sigmas[-1]

        # Get scalings for boundary conditions
        c_skip, c_out = self.get_scalings_for_boundary_condition(sigma)

        # Intermediate sigma: move the training index of the next step a fraction `eta` of the way towards the
        # last training index and read the sigma of the full training schedule at that index.
        if next_step_index < len(self.original_indices):
            next_original_index = self.original_indices[next_step_index]
            last_train_index = self.config.num_train_timesteps - 1
            target_index = int(next_original_index + self.eta * (last_train_index - next_original_index))
            sigma_s = self.original_sigmas[target_index]
        else:
            # Final step: go all the way to the terminal sigma.
            sigma_s = self.sigmas[-1]

        # 1. Denoise model output using boundary conditions
        denoised = c_out * model_output + c_skip * sample
        if self.config.clip_denoised:
            denoised = denoised.clamp(-1, 1)

        # Step from `sigma` to `sigma_s` along the direction pointing away from the denoised estimate.
        d = (sample - denoised) / sigma
        sample_s = sample + d * (sigma_s - sigma)

        # 2. Sample z ~ N(0, s_noise^2 * I)
        # Noise is not used for onestep sampling.
        if len(self.timesteps) > 1:
            noise = randn_tensor(
                model_output.shape,
                dtype=model_output.dtype,
                device=model_output.device,
                generator=generator,
            )
        else:
            noise = torch.zeros_like(model_output)
        z = noise * self.config.s_noise

        sigma_hat = sigma_next.clamp(min=0, max=self.config.sigma_max)

        # 3. Re-noise from `sigma_s` up to `sigma_hat`. The sampler this was adapted from used
        # `prev_sample = denoised + z * sigma_hat` instead.
        prev_sample = sample_s + z * (sigma_hat - sigma_s)

        # upon completion increase step index by one
        self._step_index += 1

        if not return_dict:
            return (prev_sample,)

        return TDDSVDStochasticIterativeSchedulerOutput(prev_sample=prev_sample)

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
    def add_noise(
        self,
        original_samples: torch.FloatTensor,
        noise: torch.FloatTensor,
        timesteps: torch.FloatTensor,
    ) -> torch.FloatTensor:
        """
        Returns `original_samples + noise * sigma`, where `sigma` is looked up per entry of `timesteps`.

        Every entry of `timesteps` must be a value of the current `self.timesteps`.
        """
        # Make sure sigmas and timesteps have the same device and dtype as original_samples
        sigmas = self.sigmas.to(
            device=original_samples.device, dtype=original_samples.dtype
        )
        if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
            # mps does not support float64
            schedule_timesteps = self.timesteps.to(
                original_samples.device, dtype=torch.float32
            )
            timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
        else:
            schedule_timesteps = self.timesteps.to(original_samples.device)
            timesteps = timesteps.to(original_samples.device)

        step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]

        # Append trailing singleton dimensions so sigma broadcasts against the samples.
        sigma = sigmas[step_indices].flatten()
        while len(sigma.shape) < len(original_samples.shape):
            sigma = sigma.unsqueeze(-1)

        noisy_samples = original_samples + noise * sigma
        return noisy_samples

    def __len__(self):
        return self.config.num_train_timesteps

    def set_eta(self, eta: float):
        """
        Sets the `eta` value used by `step` to choose the intermediate sigma.

        Raises:
            `ValueError`: If `eta` is outside `[0, 1]`.
        """
        # An explicit exception is used because `assert` statements are stripped when Python runs with `-O`.
        if not 0.0 <= eta <= 1.0:
            raise ValueError(f"`eta` must be in the range [0, 1], got {eta}.")
        self.eta = eta