ProgramerSalar committed
Commit e661967 · 1 Parent(s): 32b7c72
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2024 Yang Jin
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
diffusion_schedulers/__init__.py ADDED
@@ -0,0 +1,2 @@
+ from .scheduling_cosine_ddpm import DDPMCosineScheduler
+ from .scheduling_flow_matching import PyramidFlowMatchEulerDiscreteScheduler
diffusion_schedulers/scheduling_cosine_ddpm.py ADDED
@@ -0,0 +1,137 @@
1
+ import math
2
+ from dataclasses import dataclass
3
+ from typing import List, Optional, Tuple, Union
4
+
5
+ import torch
6
+
7
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
8
+ from diffusers.utils import BaseOutput
9
+ from diffusers.utils.torch_utils import randn_tensor
10
+ from diffusers.schedulers.scheduling_utils import SchedulerMixin
11
+
12
+
13
+ @dataclass
14
+ class DDPMSchedulerOutput(BaseOutput):
15
+ """
16
+ Output class for the scheduler's step function output.
17
+
18
+ Args:
19
+ prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
20
+ Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
21
+ denoising loop.
22
+ """
23
+
24
+ prev_sample: torch.Tensor
25
+
26
+
27
+ class DDPMCosineScheduler(SchedulerMixin, ConfigMixin):
28
+
29
+ @register_to_config
30
+ def __init__(
31
+ self,
32
+ scaler: float = 1.0,
33
+ s: float = 0.008,
34
+ ):
35
+ self.scaler = scaler
36
+ self.s = torch.tensor([s])
37
+ self._init_alpha_cumprod = torch.cos(self.s / (1 + self.s) * torch.pi * 0.5) ** 2
38
+
39
+ # standard deviation of the initial noise distribution
40
+ self.init_noise_sigma = 1.0
41
+
42
+ def _alpha_cumprod(self, t, device):
43
+ if self.scaler > 1:
44
+ t = 1 - (1 - t) ** self.scaler
45
+ elif self.scaler < 1:
46
+ t = t**self.scaler
47
+ alpha_cumprod = torch.cos(
48
+ (t + self.s.to(device)) / (1 + self.s.to(device)) * torch.pi * 0.5
49
+ ) ** 2 / self._init_alpha_cumprod.to(device)
50
+ return alpha_cumprod.clamp(0.0001, 0.9999)
51
+
52
+ def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor:
53
+ """
54
+ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
55
+ current timestep.
56
+
57
+ Args:
58
+ sample (`torch.Tensor`): input sample
59
+ timestep (`int`, optional): current timestep
60
+
61
+ Returns:
62
+ `torch.Tensor`: scaled input sample
63
+ """
64
+ return sample
65
+
66
+ def set_timesteps(
67
+ self,
68
+ num_inference_steps: int = None,
69
+ timesteps: Optional[List[int]] = None,
70
+ device: Union[str, torch.device] = None,
71
+ ):
72
+ """
73
+ Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
74
+
75
+ Args:
76
+ num_inference_steps (`Dict[float, int]`):
77
+ the number of diffusion steps used when generating samples with a pre-trained model. If passed, then
78
+ `timesteps` must be `None`.
79
+ device (`str` or `torch.device`, optional):
80
+ the device to which the timesteps are moved to. {2 / 3: 20, 0.0: 10}
81
+ """
82
+ if timesteps is None:
83
+ timesteps = torch.linspace(1.0, 0.0, num_inference_steps + 1, device=device)
84
+ if not isinstance(timesteps, torch.Tensor):
85
+ timesteps = torch.Tensor(timesteps).to(device)
86
+ self.timesteps = timesteps
87
+
88
+ def step(
89
+ self,
90
+ model_output: torch.Tensor,
91
+ timestep: int,
92
+ sample: torch.Tensor,
93
+ generator=None,
94
+ return_dict: bool = True,
95
+ ) -> Union[DDPMSchedulerOutput, Tuple]:
96
+ dtype = model_output.dtype
97
+ device = model_output.device
98
+ t = timestep
99
+
100
+ prev_t = self.previous_timestep(t)
101
+
102
+ alpha_cumprod = self._alpha_cumprod(t, device).view(t.size(0), *[1 for _ in sample.shape[1:]])
103
+ alpha_cumprod_prev = self._alpha_cumprod(prev_t, device).view(prev_t.size(0), *[1 for _ in sample.shape[1:]])
104
+ alpha = alpha_cumprod / alpha_cumprod_prev
105
+
106
+ mu = (1.0 / alpha).sqrt() * (sample - (1 - alpha) * model_output / (1 - alpha_cumprod).sqrt())
107
+
108
+ std_noise = randn_tensor(mu.shape, generator=generator, device=model_output.device, dtype=model_output.dtype)
109
+ std = ((1 - alpha) * (1.0 - alpha_cumprod_prev) / (1.0 - alpha_cumprod)).sqrt() * std_noise
110
+ pred = mu + std * (prev_t != 0).float().view(prev_t.size(0), *[1 for _ in sample.shape[1:]])
111
+
112
+ if not return_dict:
113
+ return (pred.to(dtype),)
114
+
115
+ return DDPMSchedulerOutput(prev_sample=pred.to(dtype))
116
+
117
+ def add_noise(
118
+ self,
119
+ original_samples: torch.Tensor,
120
+ noise: torch.Tensor,
121
+ timesteps: torch.Tensor,
122
+ ) -> torch.Tensor:
123
+ device = original_samples.device
124
+ dtype = original_samples.dtype
125
+ alpha_cumprod = self._alpha_cumprod(timesteps, device=device).view(
126
+ timesteps.size(0), *[1 for _ in original_samples.shape[1:]]
127
+ )
128
+ noisy_samples = alpha_cumprod.sqrt() * original_samples + (1 - alpha_cumprod).sqrt() * noise
129
+ return noisy_samples.to(dtype=dtype)
130
+
131
+ def __len__(self):
132
+ return self.config.num_train_timesteps
133
+
134
+ def previous_timestep(self, timestep):
135
+ index = (self.timesteps - timestep[0]).abs().argmin().item()
136
+ prev_t = self.timesteps[index + 1][None].expand(timestep.shape[0])
137
+ return prev_t
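Below is a minimal, hedged usage sketch for the cosine scheduler added above. The import path follows the package `__init__.py` in this commit; `model`, the latent shape, and the step count are placeholders for illustration only, not part of the repository.

```python
import torch
from diffusion_schedulers import DDPMCosineScheduler


def model(x, t):
    # Placeholder noise predictor; a real denoising network goes here.
    return torch.zeros_like(x)


scheduler = DDPMCosineScheduler(scaler=1.0, s=0.008)
scheduler.set_timesteps(num_inference_steps=20, device="cpu")

sample = torch.randn(1, 4, 32, 32)  # start from pure Gaussian noise
for i in range(len(scheduler.timesteps) - 1):
    # `step` expects a batched float timestep in [0, 1]; it looks up the
    # previous timestep from the schedule set above.
    t = scheduler.timesteps[i].expand(sample.shape[0])
    noise_pred = model(sample, t)
    sample = scheduler.step(noise_pred, t, sample).prev_sample
```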
diffusion_schedulers/scheduling_flow_matching.py ADDED
@@ -0,0 +1,298 @@
1
+ from dataclasses import dataclass
2
+ from typing import Optional, Tuple, Union, List
3
+ import math
4
+ import numpy as np
5
+ import torch
6
+
7
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
8
+ from diffusers.utils import BaseOutput, logging
9
+ from diffusers.utils.torch_utils import randn_tensor
10
+ from diffusers.schedulers.scheduling_utils import SchedulerMixin
12
+
13
+
14
+ @dataclass
15
+ class FlowMatchEulerDiscreteSchedulerOutput(BaseOutput):
16
+ """
17
+ Output class for the scheduler's `step` function output.
18
+
19
+ Args:
20
+ prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
21
+ Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
22
+ denoising loop.
23
+ """
24
+
25
+ prev_sample: torch.FloatTensor
26
+
27
+
28
+ class PyramidFlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
29
+ """
30
+ Pyramid flow-matching Euler discrete scheduler (a separate Euler schedule is kept for each pyramid stage).
31
+
32
+ This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
33
+ methods the library implements for all schedulers such as loading and saving.
34
+
35
+ Args:
36
+ num_train_timesteps (`int`, defaults to 1000):
37
+ The number of diffusion steps to train the model.
38
+ shift (`float`, defaults to 1.0):
+ The shift value for the timestep schedule, following Stable Diffusion 3.
+ stages (`int`, defaults to 3):
+ The number of pyramid stages.
+ stage_range (`List`, defaults to `[0, 1/3, 2/3, 1]`):
+ The boundaries of the stages, expressed as fractions of the training timesteps.
+ gamma (`float`, defaults to 1/3):
+ The hyperparameter controlling the corrected start sigma of every stage after the first.
43
+ """
44
+
45
+ _compatibles = []
46
+ order = 1
47
+
48
+ @register_to_config
49
+ def __init__(
50
+ self,
51
+ num_train_timesteps: int = 1000,
52
+ shift: float = 1.0,  # Following Stable Diffusion 3
53
+ stages: int = 3,
54
+ stage_range: List = [0, 1/3, 2/3, 1],
55
+ gamma: float = 1/3,
56
+ ):
57
+
58
+ self.timestep_ratios = {} # The timestep ratio for each stage
59
+ self.timesteps_per_stage = {} # The detailed timesteps per stage
60
+ self.sigmas_per_stage = {}
61
+ self.start_sigmas = {}
62
+ self.end_sigmas = {}
63
+ self.ori_start_sigmas = {}
64
+
65
+ # self.init_sigmas()
66
+ self.init_sigmas_for_each_stage()
67
+ self.sigma_min = self.sigmas[-1].item()
68
+ self.sigma_max = self.sigmas[0].item()
69
+ self.gamma = gamma
70
+
71
+ def init_sigmas(self):
72
+ """
73
+ initialize the global timesteps and sigmas
74
+ """
75
+ num_train_timesteps = self.config.num_train_timesteps
76
+ shift = self.config.shift
77
+
78
+ timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy()
79
+ timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)
80
+
81
+ sigmas = timesteps / num_train_timesteps
82
+ sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
83
+
84
+ self.timesteps = sigmas * num_train_timesteps
85
+
86
+ self._step_index = None
87
+ self._begin_index = None
88
+
89
+ self.sigmas = sigmas.to("cpu") # to avoid too much CPU/GPU communication
90
+
91
+ def init_sigmas_for_each_stage(self):
92
+ """
93
+ Init the timesteps for each stage
94
+ """
95
+ self.init_sigmas()
96
+
97
+ stage_distance = []
98
+ stages = self.config.stages
99
+ training_steps = self.config.num_train_timesteps
100
+ stage_range = self.config.stage_range
101
+
102
+ # Init the start and end point of each stage
103
+ for i_s in range(stages):
104
+ # Decide the start and end points of this stage
105
+ start_indice = int(stage_range[i_s] * training_steps)
106
+ start_indice = max(start_indice, 0)
107
+ end_indice = int(stage_range[i_s+1] * training_steps)
108
+ end_indice = min(end_indice, training_steps)
109
+ start_sigma = self.sigmas[start_indice].item()
110
+ end_sigma = self.sigmas[end_indice].item() if end_indice < training_steps else 0.0
111
+ self.ori_start_sigmas[i_s] = start_sigma
112
+
113
+ if i_s != 0:
114
+ ori_sigma = 1 - start_sigma
115
+ gamma = self.config.gamma
116
+ corrected_sigma = (1 / (math.sqrt(1 + (1 / gamma)) * (1 - ori_sigma) + ori_sigma)) * ori_sigma
117
+ # corrected_sigma = 1 / (2 - ori_sigma) * ori_sigma
118
+ start_sigma = 1 - corrected_sigma
119
+
120
+ stage_distance.append(start_sigma - end_sigma)
121
+ self.start_sigmas[i_s] = start_sigma
122
+ self.end_sigmas[i_s] = end_sigma
123
+
124
+ # Determine the ratio of each stage according to flow length
125
+ tot_distance = sum(stage_distance)
126
+ for i_s in range(stages):
127
+ if i_s == 0:
128
+ start_ratio = 0.0
129
+ else:
130
+ start_ratio = sum(stage_distance[:i_s]) / tot_distance
131
+ if i_s == stages - 1:
132
+ end_ratio = 1.0
133
+ else:
134
+ end_ratio = sum(stage_distance[:i_s+1]) / tot_distance
135
+
136
+ self.timestep_ratios[i_s] = (start_ratio, end_ratio)
137
+
138
+ # Determine the timesteps and sigmas for each stage
139
+ for i_s in range(stages):
140
+ timestep_ratio = self.timestep_ratios[i_s]
141
+ timestep_max = self.timesteps[int(timestep_ratio[0] * training_steps)]
142
+ timestep_min = self.timesteps[min(int(timestep_ratio[1] * training_steps), training_steps - 1)]
143
+ timesteps = np.linspace(
144
+ timestep_max, timestep_min, training_steps + 1,
145
+ )
146
+ self.timesteps_per_stage[i_s] = torch.from_numpy(timesteps[:-1])
147
+ stage_sigmas = np.linspace(
148
+ 1, 0, training_steps + 1,
149
+ )
150
+ self.sigmas_per_stage[i_s] = torch.from_numpy(stage_sigmas[:-1])
151
+
152
+ @property
153
+ def step_index(self):
154
+ """
155
+ The index counter for the current timestep. It increases by 1 after each scheduler step.
156
+ """
157
+ return self._step_index
158
+
159
+ @property
160
+ def begin_index(self):
161
+ """
162
+ The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
163
+ """
164
+ return self._begin_index
165
+
166
+ # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
167
+ def set_begin_index(self, begin_index: int = 0):
168
+ """
169
+ Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
170
+
171
+ Args:
172
+ begin_index (`int`):
173
+ The begin index for the scheduler.
174
+ """
175
+ self._begin_index = begin_index
176
+
177
+ def _sigma_to_t(self, sigma):
178
+ return sigma * self.config.num_train_timesteps
179
+
180
+ def set_timesteps(self, num_inference_steps: int, stage_index: int, device: Union[str, torch.device] = None):
181
+ """
182
+ Set the timesteps and sigmas for the given stage.
183
+ """
184
+ self.num_inference_steps = num_inference_steps
185
+ training_steps = self.config.num_train_timesteps
186
+ self.init_sigmas()
187
+
188
+ stage_timesteps = self.timesteps_per_stage[stage_index]
189
+ timestep_max = stage_timesteps[0].item()
190
+ timestep_min = stage_timesteps[-1].item()
191
+
192
+ timesteps = np.linspace(
193
+ timestep_max, timestep_min, num_inference_steps,
194
+ )
195
+ self.timesteps = torch.from_numpy(timesteps).to(device=device)
196
+
197
+ stage_sigmas = self.sigmas_per_stage[stage_index]
198
+ sigma_max = stage_sigmas[0].item()
199
+ sigma_min = stage_sigmas[-1].item()
200
+
201
+ ratios = np.linspace(
202
+ sigma_max, sigma_min, num_inference_steps
203
+ )
204
+ sigmas = torch.from_numpy(ratios).to(device=device)
205
+ self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
206
+
207
+ self._step_index = None
208
+
209
+ def index_for_timestep(self, timestep, schedule_timesteps=None):
210
+ if schedule_timesteps is None:
211
+ schedule_timesteps = self.timesteps
212
+
213
+ indices = (schedule_timesteps == timestep).nonzero()
214
+
215
+ # The sigma index that is taken for the **very** first `step`
216
+ # is always the second index (or the last index if there is only 1)
217
+ # This way we can ensure we don't accidentally skip a sigma in
218
+ # case we start in the middle of the denoising schedule (e.g. for image-to-image)
219
+ pos = 1 if len(indices) > 1 else 0
220
+
221
+ return indices[pos].item()
222
+
223
+ def _init_step_index(self, timestep):
224
+ if self.begin_index is None:
225
+ if isinstance(timestep, torch.Tensor):
226
+ timestep = timestep.to(self.timesteps.device)
227
+ self._step_index = self.index_for_timestep(timestep)
228
+ else:
229
+ self._step_index = self._begin_index
230
+
231
+ def step(
232
+ self,
233
+ model_output: torch.FloatTensor,
234
+ timestep: Union[float, torch.FloatTensor],
235
+ sample: torch.FloatTensor,
236
+ generator: Optional[torch.Generator] = None,
237
+ return_dict: bool = True,
238
+ ) -> Union[FlowMatchEulerDiscreteSchedulerOutput, Tuple]:
239
+ """
240
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
241
+ process from the learned model outputs (most often the predicted noise).
242
+
243
+ Args:
244
+ model_output (`torch.FloatTensor`):
245
+ The direct output from learned diffusion model.
246
+ timestep (`float`):
247
+ The current discrete timestep in the diffusion chain.
248
+ sample (`torch.FloatTensor`):
249
+ A current instance of a sample created by the diffusion process.
250
+ generator (`torch.Generator`, *optional*):
251
+ A random number generator.
252
+ return_dict (`bool`):
253
+ Whether or not to return a [`FlowMatchEulerDiscreteSchedulerOutput`] or
254
+ tuple.
255
+
256
+ Returns:
257
+ [`FlowMatchEulerDiscreteSchedulerOutput`] or `tuple`:
258
+ If return_dict is `True`, [`FlowMatchEulerDiscreteSchedulerOutput`] is
259
+ returned, otherwise a tuple is returned where the first element is the sample tensor.
260
+ """
261
+
262
+ if (
263
+ isinstance(timestep, int)
264
+ or isinstance(timestep, torch.IntTensor)
265
+ or isinstance(timestep, torch.LongTensor)
266
+ ):
267
+ raise ValueError(
268
+ (
269
+ "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
270
+ " `PyramidFlowMatchEulerDiscreteScheduler.step()` is not supported. Make sure to pass"
271
+ " one of the `scheduler.timesteps` as a timestep."
272
+ ),
273
+ )
274
+
275
+ if self.step_index is None:
276
+ self._step_index = 0
277
+
278
+ # Upcast to avoid precision issues when computing prev_sample
279
+ sample = sample.to(torch.float32)
280
+
281
+ sigma = self.sigmas[self.step_index]
282
+ sigma_next = self.sigmas[self.step_index + 1]
283
+
284
+ prev_sample = sample + (sigma_next - sigma) * model_output
285
+
286
+ # Cast sample back to model compatible dtype
287
+ prev_sample = prev_sample.to(model_output.dtype)
288
+
289
+ # upon completion increase step index by one
290
+ self._step_index += 1
291
+
292
+ if not return_dict:
293
+ return (prev_sample,)
294
+
295
+ return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample)
296
+
297
+ def __len__(self):
298
+ return self.config.num_train_timesteps
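A hedged sketch of driving the pyramid flow-matching scheduler stage by stage. The import path follows the package `__init__.py` above; `model` is a placeholder velocity predictor, the latent shape and step count are illustrative, and the renoising/upsampling the actual video pipeline performs between stages is omitted.

```python
import torch
from diffusion_schedulers import PyramidFlowMatchEulerDiscreteScheduler


def model(x, t):
    # Placeholder velocity predictor; the real MMDiT goes here.
    return torch.zeros_like(x)


scheduler = PyramidFlowMatchEulerDiscreteScheduler(shift=1.0, stages=3, gamma=1 / 3)
sample = torch.randn(1, 16, 8, 32, 32)  # (b, c, t, h, w) latent, illustrative shape

for stage in range(scheduler.config.stages):
    # Each stage gets its own Euler schedule between its start and end sigmas.
    scheduler.set_timesteps(num_inference_steps=10, stage_index=stage, device="cpu")
    for t in scheduler.timesteps:
        velocity = model(sample, t)
        sample = scheduler.step(velocity, t, sample).prev_sample
    # The real pipeline renoises and upsamples `sample` between stages; omitted here.
```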
pre-requirements.txt ADDED
@@ -0,0 +1,2 @@
+ wheel
+ torch
pyramid_dit/__init__.py ADDED
@@ -0,0 +1,3 @@
+ from .modeling_pyramid_mmdit import PyramidDiffusionMMDiT
+ from .pyramid_dit_for_video_gen_pipeline import PyramidDiTForVideoGeneration
+ from .modeling_text_encoder import SD3TextEncoderWithMask
pyramid_dit/modeling_embedding.py ADDED
@@ -0,0 +1,390 @@
1
+ from typing import Any, Dict, Optional, Union
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import numpy as np
6
+ import math
7
+
8
+ from diffusers.models.activations import get_activation
9
+ from einops import rearrange
10
+
11
+
12
+ def get_1d_sincos_pos_embed(
13
+ embed_dim, num_frames, cls_token=False, extra_tokens=0,
14
+ ):
15
+ t = np.arange(num_frames, dtype=np.float32)
16
+ pos_embed = get_1d_sincos_pos_embed_from_grid(embed_dim, t) # (T, D)
17
+ if cls_token and extra_tokens > 0:
18
+ pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
19
+ return pos_embed
20
+
21
+
22
+ def get_2d_sincos_pos_embed(
23
+ embed_dim, grid_size, cls_token=False, extra_tokens=0, interpolation_scale=1.0, base_size=16
24
+ ):
25
+ """
26
+ grid_size: int of the grid height and width
+ return: pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
28
+ """
29
+ if isinstance(grid_size, int):
30
+ grid_size = (grid_size, grid_size)
31
+
32
+ grid_h = np.arange(grid_size[0], dtype=np.float32) / (grid_size[0] / base_size) / interpolation_scale
33
+ grid_w = np.arange(grid_size[1], dtype=np.float32) / (grid_size[1] / base_size) / interpolation_scale
34
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
35
+ grid = np.stack(grid, axis=0)
36
+
37
+ grid = grid.reshape([2, 1, grid_size[1], grid_size[0]])
38
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
39
+ if cls_token and extra_tokens > 0:
40
+ pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
41
+ return pos_embed
42
+
43
+
44
+ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
45
+ if embed_dim % 2 != 0:
46
+ raise ValueError("embed_dim must be divisible by 2")
47
+
48
+ # use half of dimensions to encode grid_h
49
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
50
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
51
+
52
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
53
+ return emb
54
+
55
+
56
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
57
+ """
58
+ embed_dim: output dimension for each position
+ pos: a list of positions to be encoded, of size (M,)
+ out: (M, D)
59
+ """
60
+ if embed_dim % 2 != 0:
61
+ raise ValueError("embed_dim must be divisible by 2")
62
+
63
+ omega = np.arange(embed_dim // 2, dtype=np.float64)
64
+ omega /= embed_dim / 2.0
65
+ omega = 1.0 / 10000**omega # (D/2,)
66
+
67
+ pos = pos.reshape(-1) # (M,)
68
+ out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
69
+
70
+ emb_sin = np.sin(out) # (M, D/2)
71
+ emb_cos = np.cos(out) # (M, D/2)
72
+
73
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
74
+ return emb
75
+
76
+
77
+ def get_timestep_embedding(
78
+ timesteps: torch.Tensor,
79
+ embedding_dim: int,
80
+ flip_sin_to_cos: bool = False,
81
+ downscale_freq_shift: float = 1,
82
+ scale: float = 1,
83
+ max_period: int = 10000,
84
+ ):
85
+ """
86
+ This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
87
+ :param timesteps: a 1-D Tensor of N indices, one per batch element. These may be fractional.
88
+ :param embedding_dim: the dimension of the output.
+ :param max_period: controls the minimum frequency of the embeddings.
+ :return: an [N x dim] Tensor of positional embeddings.
90
+ """
91
+ assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
92
+
93
+ half_dim = embedding_dim // 2
94
+ exponent = -math.log(max_period) * torch.arange(
95
+ start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
96
+ )
97
+ exponent = exponent / (half_dim - downscale_freq_shift)
98
+
99
+ emb = torch.exp(exponent)
100
+ emb = timesteps[:, None].float() * emb[None, :]
101
+
102
+ # scale embeddings
103
+ emb = scale * emb
104
+
105
+ # concat sine and cosine embeddings
106
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
107
+
108
+ # flip sine and cosine embeddings
109
+ if flip_sin_to_cos:
110
+ emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
111
+
112
+ # zero pad
113
+ if embedding_dim % 2 == 1:
114
+ emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
115
+ return emb
116
+
117
+
118
+ class Timesteps(nn.Module):
119
+ def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float):
120
+ super().__init__()
121
+ self.num_channels = num_channels
122
+ self.flip_sin_to_cos = flip_sin_to_cos
123
+ self.downscale_freq_shift = downscale_freq_shift
124
+
125
+ def forward(self, timesteps):
126
+ t_emb = get_timestep_embedding(
127
+ timesteps,
128
+ self.num_channels,
129
+ flip_sin_to_cos=self.flip_sin_to_cos,
130
+ downscale_freq_shift=self.downscale_freq_shift,
131
+ )
132
+ return t_emb
133
+
134
+
135
+ class TimestepEmbedding(nn.Module):
136
+ def __init__(
137
+ self,
138
+ in_channels: int,
139
+ time_embed_dim: int,
140
+ act_fn: str = "silu",
141
+ out_dim: int = None,
142
+ post_act_fn: Optional[str] = None,
143
+ sample_proj_bias=True,
144
+ ):
145
+ super().__init__()
146
+ self.linear_1 = nn.Linear(in_channels, time_embed_dim, sample_proj_bias)
147
+ self.act = get_activation(act_fn)
148
+ self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim, sample_proj_bias)
149
+
150
+ def forward(self, sample):
151
+ sample = self.linear_1(sample)
152
+ sample = self.act(sample)
153
+ sample = self.linear_2(sample)
154
+ return sample
155
+
156
+
157
+ class TextProjection(nn.Module):
158
+ def __init__(self, in_features, hidden_size, act_fn="silu"):
159
+ super().__init__()
160
+ self.linear_1 = nn.Linear(in_features=in_features, out_features=hidden_size, bias=True)
161
+ self.act_1 = get_activation(act_fn)
162
+ self.linear_2 = nn.Linear(in_features=hidden_size, out_features=hidden_size, bias=True)
163
+
164
+ def forward(self, caption):
165
+ hidden_states = self.linear_1(caption)
166
+ hidden_states = self.act_1(hidden_states)
167
+ hidden_states = self.linear_2(hidden_states)
168
+ return hidden_states
169
+
170
+
171
+ class CombinedTimestepConditionEmbeddings(nn.Module):
172
+ def __init__(self, embedding_dim, pooled_projection_dim):
173
+ super().__init__()
174
+
175
+ self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
176
+ self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
177
+ self.text_embedder = TextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")
178
+
179
+ def forward(self, timestep, pooled_projection):
180
+ timesteps_proj = self.time_proj(timestep)
181
+ timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype)) # (N, D)
182
+ pooled_projections = self.text_embedder(pooled_projection)
183
+ conditioning = timesteps_emb + pooled_projections
184
+ return conditioning
185
+
186
+
187
+ class CombinedTimestepEmbeddings(nn.Module):
188
+ def __init__(self, embedding_dim):
189
+ super().__init__()
190
+ self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
191
+ self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
192
+
193
+ def forward(self, timestep):
194
+ timesteps_proj = self.time_proj(timestep)
195
+ timesteps_emb = self.timestep_embedder(timesteps_proj) # (N, D)
196
+ return timesteps_emb
197
+
198
+
199
+ class PatchEmbed3D(nn.Module):
200
+ """Support the 3D Tensor input"""
201
+
202
+ def __init__(
203
+ self,
204
+ height=128,
205
+ width=128,
206
+ patch_size=2,
207
+ in_channels=16,
208
+ embed_dim=1536,
209
+ layer_norm=False,
210
+ bias=True,
211
+ interpolation_scale=1,
212
+ pos_embed_type="sincos",
213
+ temp_pos_embed_type='rope',
214
+ pos_embed_max_size=192, # For SD3 cropping
215
+ max_num_frames=64,
216
+ add_temp_pos_embed=False,
217
+ interp_condition_pos=False,
218
+ ):
219
+ super().__init__()
220
+
221
+ num_patches = (height // patch_size) * (width // patch_size)
222
+ self.layer_norm = layer_norm
223
+ self.pos_embed_max_size = pos_embed_max_size
224
+
225
+ self.proj = nn.Conv2d(
226
+ in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
227
+ )
228
+ if layer_norm:
229
+ self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6)
230
+ else:
231
+ self.norm = None
232
+
233
+ self.patch_size = patch_size
234
+ self.height, self.width = height // patch_size, width // patch_size
235
+ self.base_size = height // patch_size
236
+ self.interpolation_scale = interpolation_scale
237
+ self.add_temp_pos_embed = add_temp_pos_embed
238
+
239
+ # Calculate positional embeddings based on max size or default
240
+ if pos_embed_max_size:
241
+ grid_size = pos_embed_max_size
242
+ else:
243
+ grid_size = int(num_patches**0.5)
244
+
245
+ if pos_embed_type is None:
246
+ self.pos_embed = None
247
+
248
+ elif pos_embed_type == "sincos":
249
+ pos_embed = get_2d_sincos_pos_embed(
250
+ embed_dim, grid_size, base_size=self.base_size, interpolation_scale=self.interpolation_scale
251
+ )
252
+ persistent = True if pos_embed_max_size else False
253
+ self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=persistent)
254
+
255
+ if add_temp_pos_embed and temp_pos_embed_type == 'sincos':
256
+ time_pos_embed = get_1d_sincos_pos_embed(embed_dim, max_num_frames)
257
+ self.register_buffer("temp_pos_embed", torch.from_numpy(time_pos_embed).float().unsqueeze(0), persistent=True)
258
+
259
+ elif pos_embed_type == "rope":
260
+ print("Using the rotary position embedding")
261
+
262
+ else:
263
+ raise ValueError(f"Unsupported pos_embed_type: {pos_embed_type}")
264
+
265
+ self.pos_embed_type = pos_embed_type
266
+ self.temp_pos_embed_type = temp_pos_embed_type
267
+ self.interp_condition_pos = interp_condition_pos
268
+
269
+ def cropped_pos_embed(self, height, width, ori_height, ori_width):
270
+ """Crops positional embeddings for SD3 compatibility."""
271
+ if self.pos_embed_max_size is None:
272
+ raise ValueError("`pos_embed_max_size` must be set for cropping.")
273
+
274
+ height = height // self.patch_size
275
+ width = width // self.patch_size
276
+ ori_height = ori_height // self.patch_size
277
+ ori_width = ori_width // self.patch_size
278
+
279
+ assert ori_height >= height, "ori_height must be >= height"
+ assert ori_width >= width, "ori_width must be >= width"
281
+
282
+ if height > self.pos_embed_max_size:
283
+ raise ValueError(
284
+ f"Height ({height}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}."
285
+ )
286
+ if width > self.pos_embed_max_size:
287
+ raise ValueError(
288
+ f"Width ({width}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}."
289
+ )
290
+
291
+ if self.interp_condition_pos:
292
+ top = (self.pos_embed_max_size - ori_height) // 2
293
+ left = (self.pos_embed_max_size - ori_width) // 2
294
+ spatial_pos_embed = self.pos_embed.reshape(1, self.pos_embed_max_size, self.pos_embed_max_size, -1)
295
+ spatial_pos_embed = spatial_pos_embed[:, top : top + ori_height, left : left + ori_width, :] # [b h w c]
296
+ if ori_height != height or ori_width != width:
297
+ spatial_pos_embed = spatial_pos_embed.permute(0, 3, 1, 2)
298
+ spatial_pos_embed = torch.nn.functional.interpolate(spatial_pos_embed, size=(height, width), mode='bilinear')
299
+ spatial_pos_embed = spatial_pos_embed.permute(0, 2, 3, 1)
300
+ else:
301
+ top = (self.pos_embed_max_size - height) // 2
302
+ left = (self.pos_embed_max_size - width) // 2
303
+ spatial_pos_embed = self.pos_embed.reshape(1, self.pos_embed_max_size, self.pos_embed_max_size, -1)
304
+ spatial_pos_embed = spatial_pos_embed[:, top : top + height, left : left + width, :]
305
+
306
+ spatial_pos_embed = spatial_pos_embed.reshape(1, -1, spatial_pos_embed.shape[-1])
307
+
308
+ return spatial_pos_embed
309
+
310
+ def forward_func(self, latent, time_index=0, ori_height=None, ori_width=None):
311
+ if self.pos_embed_max_size is not None:
312
+ height, width = latent.shape[-2:]
313
+ else:
314
+ height, width = latent.shape[-2] // self.patch_size, latent.shape[-1] // self.patch_size
315
+
316
+ bs = latent.shape[0]
317
+ temp = latent.shape[2]
318
+
319
+ latent = rearrange(latent, 'b c t h w -> (b t) c h w')
320
+ latent = self.proj(latent)
321
+ latent = latent.flatten(2).transpose(1, 2) # (BT)CHW -> (BT)NC
322
+
323
+ if self.layer_norm:
324
+ latent = self.norm(latent)
325
+
326
+ if self.pos_embed_type == 'sincos':
327
+ # Spatial position embedding, Interpolate or crop positional embeddings as needed
328
+ if self.pos_embed_max_size:
329
+ pos_embed = self.cropped_pos_embed(height, width, ori_height, ori_width)
330
+ else:
331
+ raise NotImplementedError("Not implemented sincos pos embed without sd3 max pos crop")
332
+ if self.height != height or self.width != width:
333
+ pos_embed = get_2d_sincos_pos_embed(
334
+ embed_dim=self.pos_embed.shape[-1],
335
+ grid_size=(height, width),
336
+ base_size=self.base_size,
337
+ interpolation_scale=self.interpolation_scale,
338
+ )
339
+ pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).to(latent.device)
340
+ else:
341
+ pos_embed = self.pos_embed
342
+
343
+ if self.add_temp_pos_embed and self.temp_pos_embed_type == 'sincos':
344
+ latent_dtype = latent.dtype
345
+ latent = latent + pos_embed
346
+ latent = rearrange(latent, '(b t) n c -> (b n) t c', t=temp)
347
+ latent = latent + self.temp_pos_embed[:, time_index:time_index + temp, :]
348
+ latent = latent.to(latent_dtype)
349
+ latent = rearrange(latent, '(b n) t c -> b t n c', b=bs)
350
+ else:
351
+ latent = (latent + pos_embed).to(latent.dtype)
352
+ latent = rearrange(latent, '(b t) n c -> b t n c', b=bs, t=temp)
353
+
354
+ else:
355
+ assert self.pos_embed_type == "rope", "Only supporting the sincos and rope embedding"
356
+ latent = rearrange(latent, '(b t) n c -> b t n c', b=bs, t=temp)
357
+
358
+ return latent
359
+
360
+ def forward(self, latent):
361
+ """
362
+ Arguments:
363
+ latent (`torch.FloatTensor` or `List`): a single video latent of shape (b, c, t, h, w), or a nested list of
+ past-condition latents; each entry is patchified and flattened into a 1-D token sequence
365
+ """
366
+
367
+ if isinstance(latent, list):
368
+ output_list = []
369
+
370
+ for latent_ in latent:
371
+ if not isinstance(latent_, list):
372
+ latent_ = [latent_]
373
+
374
+ output_latent = []
375
+ time_index = 0
376
+ ori_height, ori_width = latent_[-1].shape[-2:]
377
+ for each_latent in latent_:
378
+ hidden_state = self.forward_func(each_latent, time_index=time_index, ori_height=ori_height, ori_width=ori_width)
379
+ time_index += each_latent.shape[2]
380
+ hidden_state = rearrange(hidden_state, "b t n c -> b (t n) c")
381
+ output_latent.append(hidden_state)
382
+
383
+ output_latent = torch.cat(output_latent, dim=1)
384
+ output_list.append(output_latent)
385
+
386
+ return output_list
387
+ else:
388
+ hidden_states = self.forward_func(latent)
389
+ hidden_states = rearrange(hidden_states, "b t n c -> b (t n) c")
390
+ return hidden_states
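A small shape-check sketch for the sinusoidal helpers defined above. This is a hedged example rather than part of the commit; it only assumes the repository root is on `PYTHONPATH` so the file imports as `pyramid_dit.modeling_embedding`, and the grid size and dimensions are illustrative.

```python
import torch
from pyramid_dit.modeling_embedding import (
    get_2d_sincos_pos_embed,
    get_timestep_embedding,
)

# 2-D table for a 16x16 patch grid with 1536-dim tokens: one row per patch,
# half of the channels encoding one spatial axis and half the other.
pos = get_2d_sincos_pos_embed(embed_dim=1536, grid_size=16)
assert pos.shape == (16 * 16, 1536)

# DDPM-style timestep embedding for a batch of four (possibly fractional) timesteps.
t_emb = get_timestep_embedding(torch.tensor([0.0, 0.25, 0.5, 1.0]), embedding_dim=256)
assert t_emb.shape == (4, 256)
```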
pyramid_dit/modeling_mmdit_block.py ADDED
@@ -0,0 +1,672 @@
1
+ from typing import Dict, Optional, Tuple, List
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from einops import rearrange
6
+ from diffusers.models.activations import GEGLU, GELU, ApproximateGELU
+ from diffusers.utils import deprecate
7
+
8
+ try:
9
+ from flash_attn import flash_attn_qkvpacked_func, flash_attn_func
10
+ from flash_attn.bert_padding import pad_input, unpad_input, index_first_axis
11
+ from flash_attn.flash_attn_interface import flash_attn_varlen_func
12
+ except ImportError:
13
+ flash_attn_func = None
14
+ flash_attn_qkvpacked_func = None
15
+ flash_attn_varlen_func = None
16
+ print("Please install flash attention")
17
+
18
+ from trainer_misc import (
19
+ is_sequence_parallel_initialized,
20
+ get_sequence_parallel_group,
21
+ get_sequence_parallel_world_size,
22
+ all_to_all,
23
+ )
24
+
25
+ from .modeling_normalization import AdaLayerNormZero, AdaLayerNormContinuous, RMSNorm
26
+
27
+
28
+ class FeedForward(nn.Module):
29
+ r"""
30
+ A feed-forward layer.
31
+
32
+ Parameters:
33
+ dim (`int`): The number of channels in the input.
34
+ dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
35
+ mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
36
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
37
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
38
+ final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
39
+ bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
40
+ """
41
+ def __init__(
42
+ self,
43
+ dim: int,
44
+ dim_out: Optional[int] = None,
45
+ mult: int = 4,
46
+ dropout: float = 0.0,
47
+ activation_fn: str = "geglu",
48
+ final_dropout: bool = False,
49
+ inner_dim=None,
50
+ bias: bool = True,
51
+ ):
52
+ super().__init__()
53
+ if inner_dim is None:
54
+ inner_dim = int(dim * mult)
55
+ dim_out = dim_out if dim_out is not None else dim
56
+
57
+ if activation_fn == "gelu":
58
+ act_fn = GELU(dim, inner_dim, bias=bias)
59
+ if activation_fn == "gelu-approximate":
60
+ act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias)
61
+ elif activation_fn == "geglu":
62
+ act_fn = GEGLU(dim, inner_dim, bias=bias)
63
+ elif activation_fn == "geglu-approximate":
64
+ act_fn = ApproximateGELU(dim, inner_dim, bias=bias)
65
+
66
+ self.net = nn.ModuleList([])
67
+ # project in
68
+ self.net.append(act_fn)
69
+ # project dropout
70
+ self.net.append(nn.Dropout(dropout))
71
+ # project out
72
+ self.net.append(nn.Linear(inner_dim, dim_out, bias=bias))
73
+ # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
74
+ if final_dropout:
75
+ self.net.append(nn.Dropout(dropout))
76
+
77
+ def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
78
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
79
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
80
+ deprecate("scale", "1.0.0", deprecation_message)
81
+ for module in self.net:
82
+ hidden_states = module(hidden_states)
83
+ return hidden_states
84
+
85
+
86
+ class VarlenFlashSelfAttentionWithT5Mask:
87
+
88
+ def __init__(self):
89
+ pass
90
+
91
+ def apply_rope(self, xq, xk, freqs_cis):
92
+ xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
93
+ xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
94
+ xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
95
+ xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
96
+ return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
97
+
98
+ def __call__(
99
+ self, query, key, value, encoder_query, encoder_key, encoder_value,
100
+ heads, scale, hidden_length=None, image_rotary_emb=None, encoder_attention_mask=None,
101
+ ):
102
+ assert encoder_attention_mask is not None, "The encoder attention mask needs to be set"
103
+
104
+ batch_size = query.shape[0]
105
+ output_hidden = torch.zeros_like(query)
106
+ output_encoder_hidden = torch.zeros_like(encoder_query)
107
+ encoder_length = encoder_query.shape[1]
108
+
109
+ qkv_list = []
110
+ num_stages = len(hidden_length)
111
+
112
+ encoder_qkv = torch.stack([encoder_query, encoder_key, encoder_value], dim=2) # [bs, sub_seq, 3, head, head_dim]
113
+ qkv = torch.stack([query, key, value], dim=2) # [bs, sub_seq, 3, head, head_dim]
114
+
115
+ i_sum = 0
116
+ for i_p, length in enumerate(hidden_length):
117
+ encoder_qkv_tokens = encoder_qkv[i_p::num_stages]
118
+ qkv_tokens = qkv[:, i_sum:i_sum+length]
119
+ concat_qkv_tokens = torch.cat([encoder_qkv_tokens, qkv_tokens], dim=1) # [bs, tot_seq, 3, nhead, dim]
120
+
121
+ if image_rotary_emb is not None:
122
+ concat_qkv_tokens[:,:,0], concat_qkv_tokens[:,:,1] = self.apply_rope(concat_qkv_tokens[:,:,0], concat_qkv_tokens[:,:,1], image_rotary_emb[i_p])
123
+
124
+ indices = encoder_attention_mask[i_p]['indices']
125
+ qkv_list.append(index_first_axis(rearrange(concat_qkv_tokens, "b s ... -> (b s) ..."), indices))
126
+ i_sum += length
127
+
128
+ token_lengths = [x_.shape[0] for x_ in qkv_list]
129
+ qkv = torch.cat(qkv_list, dim=0)
130
+ query, key, value = qkv.unbind(1)
131
+
132
+ cu_seqlens = torch.cat([x_['seqlens_in_batch'] for x_ in encoder_attention_mask], dim=0)
133
+ max_seqlen_q = cu_seqlens.max().item()
134
+ max_seqlen_k = max_seqlen_q
135
+ cu_seqlens_q = F.pad(torch.cumsum(cu_seqlens, dim=0, dtype=torch.int32), (1, 0))
136
+ cu_seqlens_k = cu_seqlens_q.clone()
137
+
138
+ output = flash_attn_varlen_func(
139
+ query,
140
+ key,
141
+ value,
142
+ cu_seqlens_q=cu_seqlens_q,
143
+ cu_seqlens_k=cu_seqlens_k,
144
+ max_seqlen_q=max_seqlen_q,
145
+ max_seqlen_k=max_seqlen_k,
146
+ dropout_p=0.0,
147
+ causal=False,
148
+ softmax_scale=scale,
149
+ )
150
+
151
+ # To merge the tokens
152
+ i_sum, token_sum = 0, 0
153
+ for i_p, length in enumerate(hidden_length):
154
+ tot_token_num = token_lengths[i_p]
155
+ stage_output = output[token_sum : token_sum + tot_token_num]
156
+ stage_output = pad_input(stage_output, encoder_attention_mask[i_p]['indices'], batch_size, encoder_length + length)
157
+ stage_encoder_hidden_output = stage_output[:, :encoder_length]
158
+ stage_hidden_output = stage_output[:, encoder_length:]
159
+ output_hidden[:, i_sum:i_sum+length] = stage_hidden_output
160
+ output_encoder_hidden[i_p::num_stages] = stage_encoder_hidden_output
161
+ token_sum += tot_token_num
162
+ i_sum += length
163
+
164
+ output_hidden = output_hidden.flatten(2, 3)
165
+ output_encoder_hidden = output_encoder_hidden.flatten(2, 3)
166
+
167
+ return output_hidden, output_encoder_hidden
168
+
169
+
170
+ class SequenceParallelVarlenFlashSelfAttentionWithT5Mask:
171
+
172
+ def __init__(self):
173
+ pass
174
+
175
+ def apply_rope(self, xq, xk, freqs_cis):
176
+ xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
177
+ xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
178
+ xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
179
+ xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
180
+ return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
181
+
182
+ def __call__(
183
+ self, query, key, value, encoder_query, encoder_key, encoder_value,
184
+ heads, scale, hidden_length=None, image_rotary_emb=None, encoder_attention_mask=None,
185
+ ):
186
+ assert encoder_attention_mask is not None, "The encoder attention mask needs to be set"
187
+
188
+ batch_size = query.shape[0]
189
+ qkv_list = []
190
+ num_stages = len(hidden_length)
191
+
192
+ encoder_qkv = torch.stack([encoder_query, encoder_key, encoder_value], dim=2) # [bs, sub_seq, 3, head, head_dim]
193
+ qkv = torch.stack([query, key, value], dim=2) # [bs, sub_seq, 3, head, head_dim]
194
+
195
+ # To sync the encoder query, key and values
196
+ sp_group = get_sequence_parallel_group()
197
+ sp_group_size = get_sequence_parallel_world_size()
198
+ encoder_qkv = all_to_all(encoder_qkv, sp_group, sp_group_size, scatter_dim=3, gather_dim=1) # [bs, seq, 3, sub_head, head_dim]
199
+
200
+ output_hidden = torch.zeros_like(qkv[:,:,0])
201
+ output_encoder_hidden = torch.zeros_like(encoder_qkv[:,:,0])
202
+ encoder_length = encoder_qkv.shape[1]
203
+
204
+ i_sum = 0
205
+ for i_p, length in enumerate(hidden_length):
206
+ # get the query, key, value from padding sequence
207
+ encoder_qkv_tokens = encoder_qkv[i_p::num_stages]
208
+ qkv_tokens = qkv[:, i_sum:i_sum+length]
209
+ qkv_tokens = all_to_all(qkv_tokens, sp_group, sp_group_size, scatter_dim=3, gather_dim=1) # [bs, seq, 3, sub_head, head_dim]
210
+ concat_qkv_tokens = torch.cat([encoder_qkv_tokens, qkv_tokens], dim=1) # [bs, pad_seq, 3, nhead, dim]
211
+
212
+ if image_rotary_emb is not None:
213
+ concat_qkv_tokens[:,:,0], concat_qkv_tokens[:,:,1] = self.apply_rope(concat_qkv_tokens[:,:,0], concat_qkv_tokens[:,:,1], image_rotary_emb[i_p])
214
+
215
+ indices = encoder_attention_mask[i_p]['indices']
216
+ qkv_list.append(index_first_axis(rearrange(concat_qkv_tokens, "b s ... -> (b s) ..."), indices))
217
+ i_sum += length
218
+
219
+ token_lengths = [x_.shape[0] for x_ in qkv_list]
220
+ qkv = torch.cat(qkv_list, dim=0)
221
+ query, key, value = qkv.unbind(1)
222
+
223
+ cu_seqlens = torch.cat([x_['seqlens_in_batch'] for x_ in encoder_attention_mask], dim=0)
224
+ max_seqlen_q = cu_seqlens.max().item()
225
+ max_seqlen_k = max_seqlen_q
226
+ cu_seqlens_q = F.pad(torch.cumsum(cu_seqlens, dim=0, dtype=torch.int32), (1, 0))
227
+ cu_seqlens_k = cu_seqlens_q.clone()
228
+
229
+ output = flash_attn_varlen_func(
230
+ query,
231
+ key,
232
+ value,
233
+ cu_seqlens_q=cu_seqlens_q,
234
+ cu_seqlens_k=cu_seqlens_k,
235
+ max_seqlen_q=max_seqlen_q,
236
+ max_seqlen_k=max_seqlen_k,
237
+ dropout_p=0.0,
238
+ causal=False,
239
+ softmax_scale=scale,
240
+ )
241
+
242
+ # To merge the tokens
243
+ i_sum, token_sum = 0, 0
244
+ for i_p, length in enumerate(hidden_length):
245
+ tot_token_num = token_lengths[i_p]
246
+ stage_output = output[token_sum : token_sum + tot_token_num]
247
+ stage_output = pad_input(stage_output, encoder_attention_mask[i_p]['indices'], batch_size, encoder_length + length * sp_group_size)
248
+ stage_encoder_hidden_output = stage_output[:, :encoder_length]
249
+ stage_hidden_output = stage_output[:, encoder_length:]
250
+ stage_hidden_output = all_to_all(stage_hidden_output, sp_group, sp_group_size, scatter_dim=1, gather_dim=2)
251
+ output_hidden[:, i_sum:i_sum+length] = stage_hidden_output
252
+ output_encoder_hidden[i_p::num_stages] = stage_encoder_hidden_output
253
+ token_sum += tot_token_num
254
+ i_sum += length
255
+
256
+ output_encoder_hidden = all_to_all(output_encoder_hidden, sp_group, sp_group_size, scatter_dim=1, gather_dim=2)
257
+ output_hidden = output_hidden.flatten(2, 3)
258
+ output_encoder_hidden = output_encoder_hidden.flatten(2, 3)
259
+
260
+ return output_hidden, output_encoder_hidden
261
+
262
+
263
+ class VarlenSelfAttentionWithT5Mask:
264
+
265
+ """
266
+ For chunk stage attention without using flash attention
267
+ """
268
+
269
+ def __init__(self):
270
+ pass
271
+
272
+ def apply_rope(self, xq, xk, freqs_cis):
273
+ xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
274
+ xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
275
+ xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
276
+ xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
277
+ return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
278
+
279
+ def __call__(
280
+ self, query, key, value, encoder_query, encoder_key, encoder_value,
281
+ heads, scale, hidden_length=None, image_rotary_emb=None, attention_mask=None,
282
+ ):
283
+ assert attention_mask is not None, "The attention mask needs to be set"
284
+
285
+ encoder_length = encoder_query.shape[1]
286
+ num_stages = len(hidden_length)
287
+
288
+ encoder_qkv = torch.stack([encoder_query, encoder_key, encoder_value], dim=2) # [bs, sub_seq, 3, head, head_dim]
289
+ qkv = torch.stack([query, key, value], dim=2) # [bs, sub_seq, 3, head, head_dim]
290
+
291
+ i_sum = 0
292
+ output_encoder_hidden_list = []
293
+ output_hidden_list = []
294
+
295
+ for i_p, length in enumerate(hidden_length):
296
+ encoder_qkv_tokens = encoder_qkv[i_p::num_stages]
297
+ qkv_tokens = qkv[:, i_sum:i_sum+length]
298
+ concat_qkv_tokens = torch.cat([encoder_qkv_tokens, qkv_tokens], dim=1) # [bs, tot_seq, 3, nhead, dim]
299
+
300
+ if image_rotary_emb is not None:
301
+ concat_qkv_tokens[:,:,0], concat_qkv_tokens[:,:,1] = self.apply_rope(concat_qkv_tokens[:,:,0], concat_qkv_tokens[:,:,1], image_rotary_emb[i_p])
302
+
303
+ query, key, value = concat_qkv_tokens.unbind(2) # [bs, tot_seq, nhead, dim]
304
+ query = query.transpose(1, 2)
305
+ key = key.transpose(1, 2)
306
+ value = value.transpose(1, 2)
307
+
308
+ # with torch.backends.cuda.sdp_kernel(enable_math=False, enable_flash=False, enable_mem_efficient=True):
309
+ stage_hidden_states = F.scaled_dot_product_attention(
310
+ query, key, value, dropout_p=0.0, is_causal=False, attn_mask=attention_mask[i_p],
311
+ )
312
+ stage_hidden_states = stage_hidden_states.transpose(1, 2).flatten(2, 3) # [bs, tot_seq, dim]
313
+
314
+ output_encoder_hidden_list.append(stage_hidden_states[:, :encoder_length])
315
+ output_hidden_list.append(stage_hidden_states[:, encoder_length:])
316
+ i_sum += length
317
+
318
+ output_encoder_hidden = torch.stack(output_encoder_hidden_list, dim=1) # [b n s d]
319
+ output_encoder_hidden = rearrange(output_encoder_hidden, 'b n s d -> (b n) s d')
320
+ output_hidden = torch.cat(output_hidden_list, dim=1)
321
+
322
+ return output_hidden, output_encoder_hidden
323
+
324
+
325
+ class SequenceParallelVarlenSelfAttentionWithT5Mask:
326
+ """
327
+ For chunk stage attention without using flash attention
328
+ """
329
+
330
+ def __init__(self):
331
+ pass
332
+
333
+ def apply_rope(self, xq, xk, freqs_cis):
334
+ xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
335
+ xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
336
+ xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
337
+ xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
338
+ return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
339
+
340
+ def __call__(
341
+ self, query, key, value, encoder_query, encoder_key, encoder_value,
342
+ heads, scale, hidden_length=None, image_rotary_emb=None, attention_mask=None,
343
+ ):
344
+ assert attention_mask is not None, "The attention mask needs to be set"
345
+
346
+ num_stages = len(hidden_length)
347
+
348
+ encoder_qkv = torch.stack([encoder_query, encoder_key, encoder_value], dim=2) # [bs, sub_seq, 3, head, head_dim]
349
+ qkv = torch.stack([query, key, value], dim=2) # [bs, sub_seq, 3, head, head_dim]
350
+
351
+ # To sync the encoder query, key and values
352
+ sp_group = get_sequence_parallel_group()
353
+ sp_group_size = get_sequence_parallel_world_size()
354
+ encoder_qkv = all_to_all(encoder_qkv, sp_group, sp_group_size, scatter_dim=3, gather_dim=1) # [bs, seq, 3, sub_head, head_dim]
355
+ encoder_length = encoder_qkv.shape[1]
356
+
357
+ i_sum = 0
358
+ output_encoder_hidden_list = []
359
+ output_hidden_list = []
360
+
361
+ for i_p, length in enumerate(hidden_length):
362
+ encoder_qkv_tokens = encoder_qkv[i_p::num_stages]
363
+ qkv_tokens = qkv[:, i_sum:i_sum+length]
364
+ qkv_tokens = all_to_all(qkv_tokens, sp_group, sp_group_size, scatter_dim=3, gather_dim=1) # [bs, seq, 3, sub_head, head_dim]
365
+ concat_qkv_tokens = torch.cat([encoder_qkv_tokens, qkv_tokens], dim=1) # [bs, tot_seq, 3, nhead, dim]
366
+
367
+ if image_rotary_emb is not None:
368
+ concat_qkv_tokens[:,:,0], concat_qkv_tokens[:,:,1] = self.apply_rope(concat_qkv_tokens[:,:,0], concat_qkv_tokens[:,:,1], image_rotary_emb[i_p])
369
+
370
+ query, key, value = concat_qkv_tokens.unbind(2) # [bs, tot_seq, nhead, dim]
371
+ query = query.transpose(1, 2)
372
+ key = key.transpose(1, 2)
373
+ value = value.transpose(1, 2)
374
+
375
+ stage_hidden_states = F.scaled_dot_product_attention(
376
+ query, key, value, dropout_p=0.0, is_causal=False, attn_mask=attention_mask[i_p],
377
+ )
378
+ stage_hidden_states = stage_hidden_states.transpose(1, 2) # [bs, tot_seq, nhead, dim]
379
+
380
+ output_encoder_hidden_list.append(stage_hidden_states[:, :encoder_length])
381
+
382
+ output_hidden = stage_hidden_states[:, encoder_length:]
383
+ output_hidden = all_to_all(output_hidden, sp_group, sp_group_size, scatter_dim=1, gather_dim=2)
384
+ output_hidden_list.append(output_hidden)
385
+
386
+ i_sum += length
387
+
388
+ output_encoder_hidden = torch.stack(output_encoder_hidden_list, dim=1) # [b n s nhead d]
389
+ output_encoder_hidden = rearrange(output_encoder_hidden, 'b n s h d -> (b n) s h d')
390
+ output_encoder_hidden = all_to_all(output_encoder_hidden, sp_group, sp_group_size, scatter_dim=1, gather_dim=2)
391
+ output_encoder_hidden = output_encoder_hidden.flatten(2, 3)
392
+ output_hidden = torch.cat(output_hidden_list, dim=1).flatten(2, 3)
393
+
394
+ return output_hidden, output_encoder_hidden
395
+
396
+
397
+ class JointAttention(nn.Module):
398
+
399
+ def __init__(
400
+ self,
401
+ query_dim: int,
402
+ cross_attention_dim: Optional[int] = None,
403
+ heads: int = 8,
404
+ dim_head: int = 64,
405
+ dropout: float = 0.0,
406
+ bias: bool = False,
407
+ qk_norm: Optional[str] = None,
408
+ added_kv_proj_dim: Optional[int] = None,
409
+ out_bias: bool = True,
410
+ eps: float = 1e-5,
411
+ out_dim: int = None,
412
+ context_pre_only=None,
413
+ use_flash_attn=True,
414
+ ):
415
+ """
416
+ Fixing the QKNorm, following the flux, norm the head dimension
417
+ """
418
+ super().__init__()
419
+ self.inner_dim = out_dim if out_dim is not None else dim_head * heads
420
+ self.query_dim = query_dim
421
+ self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
422
+ self.use_bias = bias
423
+ self.dropout = dropout
424
+
425
+ self.out_dim = out_dim if out_dim is not None else query_dim
426
+ self.context_pre_only = context_pre_only
427
+
428
+ self.scale = dim_head**-0.5
429
+ self.heads = out_dim // dim_head if out_dim is not None else heads
430
+ self.added_kv_proj_dim = added_kv_proj_dim
431
+
432
+ if qk_norm is None:
433
+ self.norm_q = None
434
+ self.norm_k = None
435
+ elif qk_norm == "layer_norm":
436
+ self.norm_q = nn.LayerNorm(dim_head, eps=eps)
437
+ self.norm_k = nn.LayerNorm(dim_head, eps=eps)
438
+ elif qk_norm == 'rms_norm':
439
+ self.norm_q = RMSNorm(dim_head, eps=eps)
440
+ self.norm_k = RMSNorm(dim_head, eps=eps)
441
+ else:
442
+ raise ValueError(f"unknown qk_norm: {qk_norm}. Should be None or 'layer_norm'")
443
+
444
+ self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias)
445
+ self.to_k = nn.Linear(self.cross_attention_dim, self.inner_dim, bias=bias)
446
+ self.to_v = nn.Linear(self.cross_attention_dim, self.inner_dim, bias=bias)
447
+
448
+ if self.added_kv_proj_dim is not None:
449
+ self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_dim)
450
+ self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_dim)
451
+ self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim)
452
+
453
+ if qk_norm is None:
454
+ self.norm_add_q = None
455
+ self.norm_add_k = None
456
+ elif qk_norm == "layer_norm":
457
+ self.norm_add_q = nn.LayerNorm(dim_head, eps=eps)
458
+ self.norm_add_k = nn.LayerNorm(dim_head, eps=eps)
459
+ elif qk_norm == 'rms_norm':
460
+ self.norm_add_q = RMSNorm(dim_head, eps=eps)
461
+ self.norm_add_k = RMSNorm(dim_head, eps=eps)
462
+ else:
463
+ raise ValueError(f"unknown qk_norm: {qk_norm}. Should be None or 'layer_norm'")
464
+
465
+ self.to_out = nn.ModuleList([])
466
+ self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
467
+ self.to_out.append(nn.Dropout(dropout))
468
+
469
+ if not self.context_pre_only:
470
+ self.to_add_out = nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)
471
+
472
+ self.use_flash_attn = use_flash_attn
473
+
474
+ if flash_attn_func is None:
475
+ self.use_flash_attn = False
476
+
477
+ # print(f"Using flash-attention: {self.use_flash_attn}")
478
+ if self.use_flash_attn:
479
+ if is_sequence_parallel_initialized():
480
+ self.var_flash_attn = SequenceParallelVarlenFlashSelfAttentionWithT5Mask()
481
+ else:
482
+ self.var_flash_attn = VarlenFlashSelfAttentionWithT5Mask()
483
+ else:
484
+ if is_sequence_parallel_initialized():
485
+ self.var_len_attn = SequenceParallelVarlenSelfAttentionWithT5Mask()
486
+ else:
487
+ self.var_len_attn = VarlenSelfAttentionWithT5Mask()
488
+
489
+
490
+ def forward(
491
+ self,
492
+ hidden_states: torch.FloatTensor,
493
+ encoder_hidden_states: torch.FloatTensor = None,
494
+ encoder_attention_mask: torch.FloatTensor = None,
495
+ attention_mask: torch.FloatTensor = None, # [B, L, S]
496
+ hidden_length: torch.Tensor = None,
497
+ image_rotary_emb: torch.Tensor = None,
498
+ **kwargs,
499
+ ) -> torch.FloatTensor:
500
+ # This function is only used during training
501
+ # `sample` projections.
502
+ query = self.to_q(hidden_states)
503
+ key = self.to_k(hidden_states)
504
+ value = self.to_v(hidden_states)
505
+
506
+ inner_dim = key.shape[-1]
507
+ head_dim = inner_dim // self.heads
508
+
509
+ query = query.view(query.shape[0], -1, self.heads, head_dim)
510
+ key = key.view(key.shape[0], -1, self.heads, head_dim)
511
+ value = value.view(value.shape[0], -1, self.heads, head_dim)
512
+
513
+ if self.norm_q is not None:
514
+ query = self.norm_q(query)
515
+
516
+ if self.norm_k is not None:
517
+ key = self.norm_k(key)
518
+
519
+ # `context` projections.
520
+ encoder_hidden_states_query_proj = self.add_q_proj(encoder_hidden_states)
521
+ encoder_hidden_states_key_proj = self.add_k_proj(encoder_hidden_states)
522
+ encoder_hidden_states_value_proj = self.add_v_proj(encoder_hidden_states)
523
+
524
+ encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
525
+ encoder_hidden_states_query_proj.shape[0], -1, self.heads, head_dim
526
+ )
527
+ encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
528
+ encoder_hidden_states_key_proj.shape[0], -1, self.heads, head_dim
529
+ )
530
+ encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
531
+ encoder_hidden_states_value_proj.shape[0], -1, self.heads, head_dim
532
+ )
533
+
534
+ if self.norm_add_q is not None:
535
+ encoder_hidden_states_query_proj = self.norm_add_q(encoder_hidden_states_query_proj)
536
+
537
+ if self.norm_add_k is not None:
538
+ encoder_hidden_states_key_proj = self.norm_add_k(encoder_hidden_states_key_proj)
539
+
540
+ # Concatenate the hidden and encoder hidden states, perform the attention computation, and then split them back
541
+ if self.use_flash_attn:
542
+ hidden_states, encoder_hidden_states = self.var_flash_attn(
543
+ query, key, value,
544
+ encoder_hidden_states_query_proj, encoder_hidden_states_key_proj,
545
+ encoder_hidden_states_value_proj, self.heads, self.scale, hidden_length,
546
+ image_rotary_emb, encoder_attention_mask,
547
+ )
548
+ else:
549
+ hidden_states, encoder_hidden_states = self.var_len_attn(
550
+ query, key, value,
551
+ encoder_hidden_states_query_proj, encoder_hidden_states_key_proj,
552
+ encoder_hidden_states_value_proj, self.heads, self.scale, hidden_length,
553
+ image_rotary_emb, attention_mask,
554
+ )
555
+
556
+ # linear proj
557
+ hidden_states = self.to_out[0](hidden_states)
558
+ # dropout
559
+ hidden_states = self.to_out[1](hidden_states)
560
+ if not self.context_pre_only:
561
+ encoder_hidden_states = self.to_add_out(encoder_hidden_states)
562
+
563
+ return hidden_states, encoder_hidden_states
564
+
565
+
566
+ class JointTransformerBlock(nn.Module):
567
+ r"""
568
+ A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.
569
+
570
+ Reference: https://arxiv.org/abs/2403.03206
571
+
572
+ Parameters:
573
+ dim (`int`): The number of channels in the input and output.
574
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
575
+ attention_head_dim (`int`): The number of channels in each head.
576
+ context_pre_only (`bool`): Whether the `context` (text) stream is only pre-processed in this block and
577
+ not updated further; set to `True` for the final block.
578
+ """
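+ # Usage sketch (illustrative only; the sizes below are assumptions, not taken from the original code):
+ # block = JointTransformerBlock(dim=1536, num_attention_heads=24, attention_head_dim=1536,
+ # qk_norm='rms_norm', context_pre_only=False, use_flash_attn=False)
+ # enc_out, hid_out = block(hidden_states, encoder_hidden_states, encoder_attention_mask,
+ # temb, hidden_length=[hidden_states.shape[1]])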
579
+
580
+ def __init__(
581
+ self, dim, num_attention_heads, attention_head_dim, qk_norm=None,
582
+ context_pre_only=False, use_flash_attn=True,
583
+ ):
584
+ super().__init__()
585
+
586
+ self.context_pre_only = context_pre_only
587
+ context_norm_type = "ada_norm_continous" if context_pre_only else "ada_norm_zero"
588
+
589
+ self.norm1 = AdaLayerNormZero(dim)
590
+
591
+ if context_norm_type == "ada_norm_continous":
592
+ self.norm1_context = AdaLayerNormContinuous(
593
+ dim, dim, elementwise_affine=False, eps=1e-6, bias=True, norm_type="layer_norm"
594
+ )
595
+ elif context_norm_type == "ada_norm_zero":
596
+ self.norm1_context = AdaLayerNormZero(dim)
597
+ else:
598
+ raise ValueError(
599
+ f"Unknown context_norm_type: {context_norm_type}, currently only supports `ada_norm_continous`, `ada_norm_zero`"
600
+ )
601
+
602
+ self.attn = JointAttention(
603
+ query_dim=dim,
604
+ cross_attention_dim=None,
605
+ added_kv_proj_dim=dim,
606
+ dim_head=attention_head_dim // num_attention_heads,
607
+ heads=num_attention_heads,
608
+ out_dim=attention_head_dim,
609
+ qk_norm=qk_norm,
610
+ context_pre_only=context_pre_only,
611
+ bias=True,
612
+ use_flash_attn=use_flash_attn,
613
+ )
614
+
615
+ self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
616
+ self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
617
+
618
+ if not context_pre_only:
619
+ self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
620
+ self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
621
+ else:
622
+ self.norm2_context = None
623
+ self.ff_context = None
624
+
625
+ def forward(
626
+ self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor,
627
+ encoder_attention_mask: torch.FloatTensor, temb: torch.FloatTensor,
628
+ attention_mask: torch.FloatTensor = None, hidden_length: List = None,
629
+ image_rotary_emb: torch.FloatTensor = None,
630
+ ):
631
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb, hidden_length=hidden_length)
632
+
633
+ if self.context_pre_only:
634
+ norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, temb)
635
+ else:
636
+ norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
637
+ encoder_hidden_states, emb=temb,
638
+ )
639
+
640
+ # Attention
641
+ attn_output, context_attn_output = self.attn(
642
+ hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states,
643
+ encoder_attention_mask=encoder_attention_mask, attention_mask=attention_mask,
644
+ hidden_length=hidden_length, image_rotary_emb=image_rotary_emb,
645
+ )
646
+
647
+ # Process attention outputs for the `hidden_states`.
648
+ attn_output = gate_msa * attn_output
649
+ hidden_states = hidden_states + attn_output
650
+
651
+ norm_hidden_states = self.norm2(hidden_states)
652
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
653
+
654
+ ff_output = self.ff(norm_hidden_states)
655
+ ff_output = gate_mlp * ff_output
656
+
657
+ hidden_states = hidden_states + ff_output
658
+
659
+ # Process attention outputs for the `encoder_hidden_states`.
660
+ if self.context_pre_only:
661
+ encoder_hidden_states = None
662
+ else:
663
+ context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
664
+ encoder_hidden_states = encoder_hidden_states + context_attn_output
665
+
666
+ norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
667
+ norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
668
+
669
+ context_ff_output = self.ff_context(norm_encoder_hidden_states)
670
+ encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
671
+
672
+ return encoder_hidden_states, hidden_states
pyramid_dit/modeling_normalization.py ADDED
@@ -0,0 +1,179 @@
1
+ import numbers
2
+ from typing import Dict, Optional, Tuple
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from einops import rearrange
8
+ from diffusers.utils import is_torch_version
9
+
10
+
11
+ if is_torch_version(">=", "2.1.0"):
12
+ LayerNorm = nn.LayerNorm
13
+ else:
14
+ # Has optional bias parameter compared to torch layer norm
15
+ # TODO: replace with torch layernorm once min required torch version >= 2.1
16
+ class LayerNorm(nn.Module):
17
+ def __init__(self, dim, eps: float = 1e-5, elementwise_affine: bool = True, bias: bool = True):
18
+ super().__init__()
19
+
20
+ self.eps = eps
21
+
22
+ if isinstance(dim, numbers.Integral):
23
+ dim = (dim,)
24
+
25
+ self.dim = torch.Size(dim)
26
+
27
+ if elementwise_affine:
28
+ self.weight = nn.Parameter(torch.ones(dim))
29
+ self.bias = nn.Parameter(torch.zeros(dim)) if bias else None
30
+ else:
31
+ self.weight = None
32
+ self.bias = None
33
+
34
+ def forward(self, input):
35
+ return F.layer_norm(input, self.dim, self.weight, self.bias, self.eps)
36
+
37
+
38
+ class RMSNorm(nn.Module):
39
+ def __init__(self, dim, eps: float, elementwise_affine: bool = True):
40
+ super().__init__()
41
+
42
+ self.eps = eps
43
+
44
+ if isinstance(dim, numbers.Integral):
45
+ dim = (dim,)
46
+
47
+ self.dim = torch.Size(dim)
48
+
49
+ if elementwise_affine:
50
+ self.weight = nn.Parameter(torch.ones(dim))
51
+ else:
52
+ self.weight = None
53
+
54
+ def forward(self, hidden_states):
55
+ input_dtype = hidden_states.dtype
56
+ variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
57
+ hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
58
+
59
+ if self.weight is not None:
60
+ # convert into half-precision if necessary
61
+ if self.weight.dtype in [torch.float16, torch.bfloat16]:
62
+ hidden_states = hidden_states.to(self.weight.dtype)
63
+ hidden_states = hidden_states * self.weight
64
+
65
+ hidden_states = hidden_states.to(input_dtype)
66
+
67
+ return hidden_states
68
+
69
+
70
+ class AdaLayerNormContinuous(nn.Module):
71
+ def __init__(
72
+ self,
73
+ embedding_dim: int,
74
+ conditioning_embedding_dim: int,
75
+ # NOTE: It is a bit weird that the norm layer can be configured to have scale and shift parameters
76
+ # because the output is immediately scaled and shifted by the projected conditioning embeddings.
77
+ # Note that AdaLayerNorm does not let the norm layer have scale and shift parameters.
78
+ # However, this is how it was implemented in the original code, and it's rather likely you should
79
+ # set `elementwise_affine` to False.
80
+ elementwise_affine=True,
81
+ eps=1e-5,
82
+ bias=True,
83
+ norm_type="layer_norm",
84
+ ):
85
+ super().__init__()
86
+ self.silu = nn.SiLU()
87
+ self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias)
88
+ if norm_type == "layer_norm":
89
+ self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias)
90
+ elif norm_type == "rms_norm":
91
+ self.norm = RMSNorm(embedding_dim, eps, elementwise_affine)
92
+ else:
93
+ raise ValueError(f"unknown norm_type {norm_type}")
94
+
95
+ def forward_with_pad(self, x: torch.Tensor, conditioning_embedding: torch.Tensor, hidden_length=None) -> torch.Tensor:
96
+ assert hidden_length is not None
97
+
98
+ emb = self.linear(self.silu(conditioning_embedding).to(x.dtype))
99
+ batch_emb = torch.zeros_like(x).repeat(1, 1, 2)
100
+
101
+ i_sum = 0
102
+ num_stages = len(hidden_length)
103
+ for i_p, length in enumerate(hidden_length):
104
+ batch_emb[:, i_sum:i_sum+length] = emb[i_p::num_stages][:,None]
105
+ i_sum += length
106
+
107
+ batch_scale, batch_shift = torch.chunk(batch_emb, 2, dim=2)
108
+ x = self.norm(x) * (1 + batch_scale) + batch_shift
109
+ return x
110
+
111
+ def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor, hidden_length=None) -> torch.Tensor:
112
+ # convert back to the original dtype in case `conditioning_embedding` is upcast to float32 (needed for HunyuanDiT)
113
+ if hidden_length is not None:
114
+ return self.forward_with_pad(x, conditioning_embedding, hidden_length)
115
+ emb = self.linear(self.silu(conditioning_embedding).to(x.dtype))
116
+ scale, shift = torch.chunk(emb, 2, dim=1)
117
+ x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
118
+ return x
119
+
120
+
121
+ class AdaLayerNormZero(nn.Module):
122
+ r"""
123
+ Norm layer adaptive layer norm zero (adaLN-Zero).
124
+
125
+ Parameters:
126
+ embedding_dim (`int`): The size of each embedding vector.
127
+ num_embeddings (`int`): The size of the embeddings dictionary.
128
+ """
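+ # Sketch of the adaLN-Zero modulation computed here (for reference; mirrors the forward below):
+ # shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = linear(silu(emb)).chunk(6, dim=1)
+ # x = LayerNorm(x) * (1 + scale_msa) + shift_msa # pre-attention modulation
+ # the gates (gate_msa / gate_mlp) are applied by the caller after attention and feed-forward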
129
+
130
+ def __init__(self, embedding_dim: int, num_embeddings: Optional[int] = None):
131
+ super().__init__()
132
+ self.emb = None
133
+ self.silu = nn.SiLU()
134
+ self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)
135
+ self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
136
+
137
+ def forward_with_pad(
138
+ self,
139
+ x: torch.Tensor,
140
+ timestep: Optional[torch.Tensor] = None,
141
+ class_labels: Optional[torch.LongTensor] = None,
142
+ hidden_dtype: Optional[torch.dtype] = None,
143
+ emb: Optional[torch.Tensor] = None,
144
+ hidden_length: Optional[torch.Tensor] = None,
145
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
146
+ # x: [bs, seq_len, dim]
147
+ if self.emb is not None:
148
+ emb = self.emb(timestep, class_labels, hidden_dtype=hidden_dtype)
149
+
150
+ emb = self.linear(self.silu(emb))
151
+ batch_emb = torch.zeros_like(x).repeat(1, 1, 6)
152
+
153
+ i_sum = 0
154
+ num_stages = len(hidden_length)
155
+ for i_p, length in enumerate(hidden_length):
156
+ batch_emb[:, i_sum:i_sum+length] = emb[i_p::num_stages][:,None]
157
+ i_sum += length
158
+
159
+ batch_shift_msa, batch_scale_msa, batch_gate_msa, batch_shift_mlp, batch_scale_mlp, batch_gate_mlp = batch_emb.chunk(6, dim=2)
160
+ x = self.norm(x) * (1 + batch_scale_msa) + batch_shift_msa
161
+ return x, batch_gate_msa, batch_shift_mlp, batch_scale_mlp, batch_gate_mlp
162
+
163
+ def forward(
164
+ self,
165
+ x: torch.Tensor,
166
+ timestep: Optional[torch.Tensor] = None,
167
+ class_labels: Optional[torch.LongTensor] = None,
168
+ hidden_dtype: Optional[torch.dtype] = None,
169
+ emb: Optional[torch.Tensor] = None,
170
+ hidden_length: Optional[torch.Tensor] = None,
171
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
172
+ if hidden_length is not None:
173
+ return self.forward_with_pad(x, timestep, class_labels, hidden_dtype, emb, hidden_length)
174
+ if self.emb is not None:
175
+ emb = self.emb(timestep, class_labels, hidden_dtype=hidden_dtype)
176
+ emb = self.linear(self.silu(emb))
177
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=1)
178
+ x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
179
+ return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
pyramid_dit/modeling_pyramid_mmdit.py ADDED
@@ -0,0 +1,487 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import os
4
+ import torch.nn.functional as F
5
+
6
+ from einops import rearrange
7
+ from diffusers.utils.torch_utils import randn_tensor
8
+ from diffusers.models.modeling_utils import ModelMixin
9
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
10
+ from diffusers.utils import is_torch_version
11
+ from typing import Any, Callable, Dict, List, Optional, Union
12
+ from tqdm import tqdm
13
+
14
+ from .modeling_embedding import PatchEmbed3D, CombinedTimestepConditionEmbeddings
15
+ from .modeling_normalization import AdaLayerNormContinuous
16
+ from .modeling_mmdit_block import JointTransformerBlock
17
+
18
+ from trainer_misc import (
19
+ is_sequence_parallel_initialized,
20
+ get_sequence_parallel_group,
21
+ get_sequence_parallel_world_size,
22
+ get_sequence_parallel_rank,
23
+ all_to_all,
24
+ )
25
+
26
+ from IPython import embed
27
+
28
+
29
+ def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
30
+ assert dim % 2 == 0, "The dimension must be even."
31
+
32
+ scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
33
+ omega = 1.0 / (theta**scale)
34
+
35
+ batch_size, seq_length = pos.shape
36
+ out = torch.einsum("...n,d->...nd", pos, omega)
37
+ cos_out = torch.cos(out)
38
+ sin_out = torch.sin(out)
39
+
40
+ stacked_out = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1)
41
+ out = stacked_out.view(batch_size, -1, dim // 2, 2, 2)
42
+ return out.float()
43
+
44
+
45
+ class EmbedNDRoPE(nn.Module):
46
+ def __init__(self, dim: int, theta: int, axes_dim: List[int]):
47
+ super().__init__()
48
+ self.dim = dim
49
+ self.theta = theta
50
+ self.axes_dim = axes_dim
51
+
52
+ def forward(self, ids: torch.Tensor) -> torch.Tensor:
53
+ n_axes = ids.shape[-1]
54
+ emb = torch.cat(
55
+ [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
56
+ dim=-3,
57
+ )
58
+ return emb.unsqueeze(2)
59
+
60
+
61
+ class PyramidDiffusionMMDiT(ModelMixin, ConfigMixin):
62
+ _supports_gradient_checkpointing = True
63
+
64
+ @register_to_config
65
+ def __init__(
66
+ self,
67
+ sample_size: int = 128,
68
+ patch_size: int = 2,
69
+ in_channels: int = 16,
70
+ num_layers: int = 24,
71
+ attention_head_dim: int = 64,
72
+ num_attention_heads: int = 24,
73
+ caption_projection_dim: int = 1152,
74
+ pooled_projection_dim: int = 2048,
75
+ pos_embed_max_size: int = 192,
76
+ max_num_frames: int = 200,
77
+ qk_norm: str = 'rms_norm',
78
+ pos_embed_type: str = 'rope',
79
+ temp_pos_embed_type: str = 'sincos',
80
+ joint_attention_dim: int = 4096,
81
+ use_gradient_checkpointing: bool = False,
82
+ use_flash_attn: bool = True,
83
+ use_temporal_causal: bool = False,
84
+ use_t5_mask: bool = False,
85
+ add_temp_pos_embed: bool = False,
86
+ interp_condition_pos: bool = False,
87
+ ):
88
+ super().__init__()
89
+
90
+ self.out_channels = in_channels
91
+ self.inner_dim = num_attention_heads * attention_head_dim
92
+ assert temp_pos_embed_type in ['rope', 'sincos']
93
+
94
+ # The input latent embedder; keep the name pos_embed to stay consistent with SD3
95
+ self.pos_embed = PatchEmbed3D(
96
+ height=sample_size,
97
+ width=sample_size,
98
+ patch_size=patch_size,
99
+ in_channels=in_channels,
100
+ embed_dim=self.inner_dim,
101
+ pos_embed_max_size=pos_embed_max_size, # hard-code for now.
102
+ max_num_frames=max_num_frames,
103
+ pos_embed_type=pos_embed_type,
104
+ temp_pos_embed_type=temp_pos_embed_type,
105
+ add_temp_pos_embed=add_temp_pos_embed,
106
+ interp_condition_pos=interp_condition_pos,
107
+ )
108
+
109
+ # The RoPE embedding
110
+ if pos_embed_type == 'rope':
111
+ self.rope_embed = EmbedNDRoPE(self.inner_dim, 10000, axes_dim=[16, 24, 24])
112
+ else:
113
+ self.rope_embed = None
114
+
115
+ if temp_pos_embed_type == 'rope':
116
+ self.temp_rope_embed = EmbedNDRoPE(self.inner_dim, 10000, axes_dim=[attention_head_dim])
117
+ else:
118
+ self.temp_rope_embed = None
119
+
120
+ self.time_text_embed = CombinedTimestepConditionEmbeddings(
121
+ embedding_dim=self.inner_dim, pooled_projection_dim=self.config.pooled_projection_dim,
122
+ )
123
+ self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.config.caption_projection_dim)
124
+
125
+ self.transformer_blocks = nn.ModuleList(
126
+ [
127
+ JointTransformerBlock(
128
+ dim=self.inner_dim,
129
+ num_attention_heads=num_attention_heads,
130
+ attention_head_dim=self.inner_dim,
131
+ qk_norm=qk_norm,
132
+ context_pre_only=i == num_layers - 1,
133
+ use_flash_attn=use_flash_attn,
134
+ )
135
+ for i in range(num_layers)
136
+ ]
137
+ )
138
+
139
+ self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
140
+ self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
141
+ self.gradient_checkpointing = use_gradient_checkpointing
142
+ self.patch_size = patch_size
143
+ self.use_flash_attn = use_flash_attn
144
+ self.use_temporal_causal = use_temporal_causal
145
+ self.pos_embed_type = pos_embed_type
146
+ self.temp_pos_embed_type = temp_pos_embed_type
147
+ self.add_temp_pos_embed = add_temp_pos_embed
148
+
149
+ if self.use_temporal_causal:
150
+ print("Using temporal causal attention")
151
+ assert self.use_flash_attn is False, "Flash attention does not support temporal causal attention"
152
+
153
+ if interp_condition_pos:
154
+ print("We interpolate the position embedding of the condition latents")
155
+
156
+ # init weights
157
+ self.initialize_weights()
158
+
159
+ def initialize_weights(self):
160
+ # Initialize transformer layers:
161
+ def _basic_init(module):
162
+ if isinstance(module, (nn.Linear, nn.Conv2d, nn.Conv3d)):
163
+ torch.nn.init.xavier_uniform_(module.weight)
164
+ if module.bias is not None:
165
+ nn.init.constant_(module.bias, 0)
166
+ self.apply(_basic_init)
167
+
168
+ # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
169
+ w = self.pos_embed.proj.weight.data
170
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
171
+ nn.init.constant_(self.pos_embed.proj.bias, 0)
172
+
173
+ # Initialize all the conditioning to normal init
174
+ nn.init.normal_(self.time_text_embed.timestep_embedder.linear_1.weight, std=0.02)
175
+ nn.init.normal_(self.time_text_embed.timestep_embedder.linear_2.weight, std=0.02)
176
+ nn.init.normal_(self.time_text_embed.text_embedder.linear_1.weight, std=0.02)
177
+ nn.init.normal_(self.time_text_embed.text_embedder.linear_2.weight, std=0.02)
178
+ nn.init.normal_(self.context_embedder.weight, std=0.02)
179
+
180
+ # Zero-out adaLN modulation layers in DiT blocks:
181
+ for block in self.transformer_blocks:
182
+ nn.init.constant_(block.norm1.linear.weight, 0)
183
+ nn.init.constant_(block.norm1.linear.bias, 0)
184
+ nn.init.constant_(block.norm1_context.linear.weight, 0)
185
+ nn.init.constant_(block.norm1_context.linear.bias, 0)
186
+
187
+ # Zero-out output layers:
188
+ nn.init.constant_(self.norm_out.linear.weight, 0)
189
+ nn.init.constant_(self.norm_out.linear.bias, 0)
190
+ nn.init.constant_(self.proj_out.weight, 0)
191
+ nn.init.constant_(self.proj_out.bias, 0)
192
+
193
+ @torch.no_grad()
194
+ def _prepare_latent_image_ids(self, batch_size, temp, height, width, device):
195
+ latent_image_ids = torch.zeros(temp, height, width, 3)
196
+ latent_image_ids[..., 0] = latent_image_ids[..., 0] + torch.arange(temp)[:, None, None]
197
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[None, :, None]
198
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, None, :]
199
+
200
+ latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1, 1)
201
+ latent_image_ids = rearrange(latent_image_ids, 'b t h w c -> b (t h w) c')
202
+ return latent_image_ids.to(device=device)
203
+
204
+ @torch.no_grad()
205
+ def _prepare_pyramid_latent_image_ids(self, batch_size, temp_list, height_list, width_list, device):
206
+ base_width = width_list[-1]; base_height = height_list[-1]
207
+ assert base_width == max(width_list)
208
+ assert base_height == max(height_list)
209
+
210
+ image_ids_list = []
211
+ for temp, height, width in zip(temp_list, height_list, width_list):
212
+ latent_image_ids = torch.zeros(temp, height, width, 3)
213
+
214
+ if height != base_height:
215
+ height_pos = F.interpolate(torch.arange(base_height)[None, None, :].float(), height, mode='linear').squeeze(0, 1)
216
+ else:
217
+ height_pos = torch.arange(base_height).float()
218
+ if width != base_width:
219
+ width_pos = F.interpolate(torch.arange(base_width)[None, None, :].float(), width, mode='linear').squeeze(0, 1)
220
+ else:
221
+ width_pos = torch.arange(base_width).float()
222
+
223
+ latent_image_ids[..., 0] = latent_image_ids[..., 0] + torch.arange(temp)[:, None, None]
224
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + height_pos[None, :, None]
225
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + width_pos[None, None, :]
226
+ latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1, 1)
227
+ latent_image_ids = rearrange(latent_image_ids, 'b t h w c -> b (t h w) c').to(device)
228
+ image_ids_list.append(latent_image_ids)
229
+
230
+ return image_ids_list
231
+
232
+ @torch.no_grad()
233
+ def _prepare_temporal_rope_ids(self, batch_size, temp, height, width, device, start_time_stamp=0):
234
+ latent_image_ids = torch.zeros(temp, height, width, 1)
235
+ latent_image_ids[..., 0] = latent_image_ids[..., 0] + torch.arange(start_time_stamp, start_time_stamp + temp)[:, None, None]
236
+ latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1, 1)
237
+ latent_image_ids = rearrange(latent_image_ids, 'b t h w c -> b (t h w) c')
238
+ return latent_image_ids.to(device=device)
239
+
240
+ @torch.no_grad()
241
+ def _prepare_pyramid_temporal_rope_ids(self, sample, batch_size, device):
242
+ image_ids_list = []
243
+
244
+ for i_b, sample_ in enumerate(sample):
245
+ if not isinstance(sample_, list):
246
+ sample_ = [sample_]
247
+
248
+ cur_image_ids = []
249
+ start_time_stamp = 0
250
+
251
+ for clip_ in sample_:
252
+ _, _, temp, height, width = clip_.shape
253
+ height = height // self.patch_size
254
+ width = width // self.patch_size
255
+ cur_image_ids.append(self._prepare_temporal_rope_ids(batch_size, temp, height, width, device, start_time_stamp=start_time_stamp))
256
+ start_time_stamp += temp
257
+
258
+ cur_image_ids = torch.cat(cur_image_ids, dim=1)
259
+ image_ids_list.append(cur_image_ids)
260
+
261
+ return image_ids_list
262
+
263
+ def merge_input(self, sample, encoder_hidden_length, encoder_attention_mask):
264
+ """
265
+ Merge the input video latents of different resolutions into one sequence.
266
+ `sample` is ordered from low resolution to high resolution.
267
+ """
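+ # Illustrative structure of `sample` (assumed sizes, not from the original code), for stages [1, 2, 4]
+ # and patch_size 2:
+ # sample = [latents_low, # [b, c, t, 8, 8] -> 16 tokens per frame
+ # latents_mid, # [b, c, t, 16, 16] -> 64 tokens per frame
+ # latents_high] # [b, c, t, 32, 32] -> 256 tokens per frame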
268
+ if isinstance(sample[0], list):
269
+ device = sample[0][-1].device
270
+ pad_batch_size = sample[0][-1].shape[0]
271
+ else:
272
+ device = sample[0].device
273
+ pad_batch_size = sample[0].shape[0]
274
+
275
+ num_stages = len(sample)
276
+ height_list = [];width_list = [];temp_list = []
277
+ trainable_token_list = []
278
+
279
+ for i_b, sample_ in enumerate(sample):
280
+ if isinstance(sample_, list):
281
+ sample_ = sample_[-1]
282
+ _, _, temp, height, width = sample_.shape
283
+ height = height // self.patch_size
284
+ width = width // self.patch_size
285
+ temp_list.append(temp)
286
+ height_list.append(height)
287
+ width_list.append(width)
288
+ trainable_token_list.append(height * width * temp)
289
+
290
+ # prepare the RoPE embedding if needed
291
+ if self.pos_embed_type == 'rope':
292
+ # TODO: support the 3D Rope for video
293
+ raise NotImplementedError("3D RoPE is not supported for video generation yet")
294
+ text_ids = torch.zeros(pad_batch_size, encoder_hidden_length, 3).to(device=device)
295
+ image_ids_list = self._prepare_pyramid_latent_image_ids(pad_batch_size, temp_list, height_list, width_list, device)
296
+ input_ids_list = [torch.cat([text_ids, image_ids], dim=1) for image_ids in image_ids_list]
297
+ image_rotary_emb = [self.rope_embed(input_ids) for input_ids in input_ids_list] # [bs, seq_len, 1, head_dim // 2, 2, 2]
298
+ else:
299
+ if self.temp_pos_embed_type == 'rope' and self.add_temp_pos_embed:
300
+ image_ids_list = self._prepare_pyramid_temporal_rope_ids(sample, pad_batch_size, device)
301
+ text_ids = torch.zeros(pad_batch_size, encoder_attention_mask.shape[1], 1).to(device=device)
302
+ input_ids_list = [torch.cat([text_ids, image_ids], dim=1) for image_ids in image_ids_list]
303
+ image_rotary_emb = [self.temp_rope_embed(input_ids) for input_ids in input_ids_list] # [bs, seq_len, 1, head_dim // 2, 2, 2]
304
+
305
+ if is_sequence_parallel_initialized():
306
+ sp_group = get_sequence_parallel_group()
307
+ sp_group_size = get_sequence_parallel_world_size()
308
+ image_rotary_emb = [all_to_all(x_.repeat(1, 1, sp_group_size, 1, 1, 1), sp_group, sp_group_size, scatter_dim=2, gather_dim=0) for x_ in image_rotary_emb]
309
+ input_ids_list = [all_to_all(input_ids.repeat(1, 1, sp_group_size), sp_group, sp_group_size, scatter_dim=2, gather_dim=0) for input_ids in input_ids_list]
310
+
311
+ else:
312
+ image_rotary_emb = None
313
+
314
+ hidden_states = self.pos_embed(sample) # hidden states is a list of [b c t h w] b = real_b // num_stages
315
+ hidden_length = []
316
+
317
+ for i_b in range(num_stages):
318
+ hidden_length.append(hidden_states[i_b].shape[1])
319
+
320
+ # prepare the attention mask
321
+ if self.use_flash_attn:
322
+ attention_mask = None
323
+ indices_list = []
324
+ for i_p, length in enumerate(hidden_length):
325
+ pad_attention_mask = torch.ones((pad_batch_size, length), dtype=encoder_attention_mask.dtype).to(device)
326
+ pad_attention_mask = torch.cat([encoder_attention_mask[i_p::num_stages], pad_attention_mask], dim=1)
327
+
328
+ if is_sequence_parallel_initialized():
329
+ sp_group = get_sequence_parallel_group()
330
+ sp_group_size = get_sequence_parallel_world_size()
331
+ pad_attention_mask = all_to_all(pad_attention_mask.unsqueeze(2).repeat(1, 1, sp_group_size), sp_group, sp_group_size, scatter_dim=2, gather_dim=0)
332
+ pad_attention_mask = pad_attention_mask.squeeze(2)
333
+
334
+ seqlens_in_batch = pad_attention_mask.sum(dim=-1, dtype=torch.int32)
335
+ indices = torch.nonzero(pad_attention_mask.flatten(), as_tuple=False).flatten()
336
+
337
+ indices_list.append(
338
+ {
339
+ 'indices': indices,
340
+ 'seqlens_in_batch': seqlens_in_batch,
341
+ }
342
+ )
343
+ encoder_attention_mask = indices_list
344
+ else:
345
+ assert encoder_attention_mask.shape[1] == encoder_hidden_length
346
+ real_batch_size = encoder_attention_mask.shape[0]
347
+ # prepare text ids
348
+ text_ids = torch.arange(1, real_batch_size + 1, dtype=encoder_attention_mask.dtype).unsqueeze(1).repeat(1, encoder_hidden_length)
349
+ text_ids = text_ids.to(device)
350
+ text_ids[encoder_attention_mask == 0] = 0
351
+
352
+ # prepare image ids
353
+ image_ids = torch.arange(1, real_batch_size + 1, dtype=encoder_attention_mask.dtype).unsqueeze(1).repeat(1, max(hidden_length))
354
+ image_ids = image_ids.to(device)
355
+ image_ids_list = []
356
+ for i_p, length in enumerate(hidden_length):
357
+ image_ids_list.append(image_ids[i_p::num_stages][:, :length])
358
+
359
+ if is_sequence_parallel_initialized():
360
+ sp_group = get_sequence_parallel_group()
361
+ sp_group_size = get_sequence_parallel_world_size()
362
+ text_ids = all_to_all(text_ids.unsqueeze(2).repeat(1, 1, sp_group_size), sp_group, sp_group_size, scatter_dim=2, gather_dim=0).squeeze(2)
363
+ image_ids_list = [all_to_all(image_ids_.unsqueeze(2).repeat(1, 1, sp_group_size), sp_group, sp_group_size, scatter_dim=2, gather_dim=0).squeeze(2) for image_ids_ in image_ids_list]
364
+
365
+ attention_mask = []
366
+ for i_p in range(len(hidden_length)):
367
+ image_ids = image_ids_list[i_p]
368
+ token_ids = torch.cat([text_ids[i_p::num_stages], image_ids], dim=1)
369
+ stage_attention_mask = rearrange(token_ids, 'b i -> b 1 i 1') == rearrange(token_ids, 'b j -> b 1 1 j') # [bs, 1, q_len, k_len]
370
+ if self.use_temporal_causal:
371
+ input_order_ids = input_ids_list[i_p].squeeze(2)
372
+ temporal_causal_mask = rearrange(input_order_ids, 'b i -> b 1 i 1') >= rearrange(input_order_ids, 'b j -> b 1 1 j')
373
+ stage_attention_mask = stage_attention_mask & temporal_causal_mask
374
+ attention_mask.append(stage_attention_mask)
375
+
376
+ return hidden_states, hidden_length, temp_list, height_list, width_list, trainable_token_list, encoder_attention_mask, attention_mask, image_rotary_emb
377
+
378
+ def split_output(self, batch_hidden_states, hidden_length, temps, heights, widths, trainable_token_list):
379
+ # To split the hidden states
380
+ batch_size = batch_hidden_states.shape[0]
381
+ output_hidden_list = []
382
+ batch_hidden_states = torch.split(batch_hidden_states, hidden_length, dim=1)
383
+
384
+ if is_sequence_parallel_initialized():
385
+ sp_group_size = get_sequence_parallel_world_size()
386
+ batch_size = batch_size // sp_group_size
387
+
388
+ for i_p, length in enumerate(hidden_length):
389
+ width, height, temp = widths[i_p], heights[i_p], temps[i_p]
390
+ trainable_token_num = trainable_token_list[i_p]
391
+ hidden_states = batch_hidden_states[i_p]
392
+
393
+ if is_sequence_parallel_initialized():
394
+ sp_group = get_sequence_parallel_group()
395
+ sp_group_size = get_sequence_parallel_world_size()
396
+ hidden_states = all_to_all(hidden_states, sp_group, sp_group_size, scatter_dim=0, gather_dim=1)
397
+
398
+ # only the trainable tokens take part in the loss computation
399
+ hidden_states = hidden_states[:, -trainable_token_num:]
400
+
401
+ # unpatchify
402
+ hidden_states = hidden_states.reshape(
403
+ shape=(batch_size, temp, height, width, self.patch_size, self.patch_size, self.out_channels)
404
+ )
405
+ hidden_states = rearrange(hidden_states, "b t h w p1 p2 c -> b t (h p1) (w p2) c")
406
+ hidden_states = rearrange(hidden_states, "b t h w c -> b c t h w")
407
+ output_hidden_list.append(hidden_states)
408
+
409
+ return output_hidden_list
410
+
411
+ def forward(
412
+ self,
413
+ sample: torch.FloatTensor, # [num_stages]
414
+ encoder_hidden_states: torch.FloatTensor = None,
415
+ encoder_attention_mask: torch.FloatTensor = None,
416
+ pooled_projections: torch.FloatTensor = None,
417
+ timestep_ratio: torch.FloatTensor = None,
418
+ ):
419
+ # Get the timestep embedding
420
+ temb = self.time_text_embed(timestep_ratio, pooled_projections)
421
+ encoder_hidden_states = self.context_embedder(encoder_hidden_states)
422
+ encoder_hidden_length = encoder_hidden_states.shape[1]
423
+
424
+ # Get the input sequence
425
+ hidden_states, hidden_length, temps, heights, widths, trainable_token_list, encoder_attention_mask, \
426
+ attention_mask, image_rotary_emb = self.merge_input(sample, encoder_hidden_length, encoder_attention_mask)
427
+
428
+ # split the long latents if necessary
429
+ if is_sequence_parallel_initialized():
430
+ sp_group = get_sequence_parallel_group()
431
+ sp_group_size = get_sequence_parallel_world_size()
432
+
433
+ # sync the input hidden states
434
+ batch_hidden_states = []
435
+ for i_p, hidden_states_ in enumerate(hidden_states):
436
+ assert hidden_states_.shape[1] % sp_group_size == 0, "The sequence length should be divisible by the sequence parallel size"
437
+ hidden_states_ = all_to_all(hidden_states_, sp_group, sp_group_size, scatter_dim=1, gather_dim=0)
438
+ hidden_length[i_p] = hidden_length[i_p] // sp_group_size
439
+ batch_hidden_states.append(hidden_states_)
440
+
441
+ # sync the encoder hidden states
442
+ hidden_states = torch.cat(batch_hidden_states, dim=1)
443
+ encoder_hidden_states = all_to_all(encoder_hidden_states, sp_group, sp_group_size, scatter_dim=1, gather_dim=0)
444
+ temb = all_to_all(temb.unsqueeze(1).repeat(1, sp_group_size, 1), sp_group, sp_group_size, scatter_dim=1, gather_dim=0)
445
+ temb = temb.squeeze(1)
446
+ else:
447
+ hidden_states = torch.cat(hidden_states, dim=1)
448
+
449
+ # print(hidden_length)
450
+ for i_b, block in enumerate(self.transformer_blocks):
451
+ if self.training and self.gradient_checkpointing and (i_b >= 2):
452
+ def create_custom_forward(module):
453
+ def custom_forward(*inputs):
454
+ return module(*inputs)
455
+
456
+ return custom_forward
457
+
458
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
459
+ encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
460
+ create_custom_forward(block),
461
+ hidden_states,
462
+ encoder_hidden_states,
463
+ encoder_attention_mask,
464
+ temb,
465
+ attention_mask,
466
+ hidden_length,
467
+ image_rotary_emb,
468
+ **ckpt_kwargs,
469
+ )
470
+
471
+ else:
472
+ encoder_hidden_states, hidden_states = block(
473
+ hidden_states=hidden_states,
474
+ encoder_hidden_states=encoder_hidden_states,
475
+ encoder_attention_mask=encoder_attention_mask,
476
+ temb=temb,
477
+ attention_mask=attention_mask,
478
+ hidden_length=hidden_length,
479
+ image_rotary_emb=image_rotary_emb,
480
+ )
481
+
482
+ hidden_states = self.norm_out(hidden_states, temb, hidden_length=hidden_length)
483
+ hidden_states = self.proj_out(hidden_states)
484
+
485
+ output = self.split_output(hidden_states, hidden_length, temps, heights, widths, trainable_token_list)
486
+
487
+ return output
pyramid_dit/modeling_text_encoder.py ADDED
@@ -0,0 +1,140 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import os
4
+
5
+ from transformers import (
6
+ CLIPTextModelWithProjection,
7
+ CLIPTokenizer,
8
+ T5EncoderModel,
9
+ T5TokenizerFast,
10
+ )
11
+
12
+ from typing import Any, Callable, Dict, List, Optional, Union
13
+
14
+
15
+ class SD3TextEncoderWithMask(nn.Module):
16
+ def __init__(self, model_path, torch_dtype):
17
+ super().__init__()
18
+ # CLIP-L
19
+ self.tokenizer = CLIPTokenizer.from_pretrained(os.path.join(model_path, 'tokenizer'))
20
+ self.tokenizer_max_length = self.tokenizer.model_max_length
21
+ self.text_encoder = CLIPTextModelWithProjection.from_pretrained(os.path.join(model_path, 'text_encoder'), torch_dtype=torch_dtype)
22
+
23
+ # CLIP-G
24
+ self.tokenizer_2 = CLIPTokenizer.from_pretrained(os.path.join(model_path, 'tokenizer_2'))
25
+ self.text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(os.path.join(model_path, 'text_encoder_2'), torch_dtype=torch_dtype)
26
+
27
+ # T5
28
+ self.tokenizer_3 = T5TokenizerFast.from_pretrained(os.path.join(model_path, 'tokenizer_3'))
29
+ self.text_encoder_3 = T5EncoderModel.from_pretrained(os.path.join(model_path, 'text_encoder_3'), torch_dtype=torch_dtype)
30
+
31
+ self._freeze()
32
+
33
+ def _freeze(self):
34
+ for param in self.parameters():
35
+ param.requires_grad = False
36
+
37
+ def _get_t5_prompt_embeds(
38
+ self,
39
+ prompt: Union[str, List[str]] = None,
40
+ num_images_per_prompt: int = 1,
41
+ device: Optional[torch.device] = None,
42
+ max_sequence_length: int = 128,
43
+ ):
44
+ prompt = [prompt] if isinstance(prompt, str) else prompt
45
+ batch_size = len(prompt)
46
+
47
+ text_inputs = self.tokenizer_3(
48
+ prompt,
49
+ padding="max_length",
50
+ max_length=max_sequence_length,
51
+ truncation=True,
52
+ add_special_tokens=True,
53
+ return_tensors="pt",
54
+ )
55
+ text_input_ids = text_inputs.input_ids
56
+ prompt_attention_mask = text_inputs.attention_mask
57
+ prompt_attention_mask = prompt_attention_mask.to(device)
58
+ prompt_embeds = self.text_encoder_3(text_input_ids.to(device), attention_mask=prompt_attention_mask)[0]
59
+ dtype = self.text_encoder_3.dtype
60
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
61
+
62
+ _, seq_len, _ = prompt_embeds.shape
63
+
64
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
65
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
66
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
67
+ prompt_attention_mask = prompt_attention_mask.view(batch_size, -1)
68
+ prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
69
+
70
+ return prompt_embeds, prompt_attention_mask
71
+
72
+ def _get_clip_prompt_embeds(
73
+ self,
74
+ prompt: Union[str, List[str]],
75
+ num_images_per_prompt: int = 1,
76
+ device: Optional[torch.device] = None,
77
+ clip_skip: Optional[int] = None,
78
+ clip_model_index: int = 0,
79
+ ):
80
+
81
+ clip_tokenizers = [self.tokenizer, self.tokenizer_2]
82
+ clip_text_encoders = [self.text_encoder, self.text_encoder_2]
83
+
84
+ tokenizer = clip_tokenizers[clip_model_index]
85
+ text_encoder = clip_text_encoders[clip_model_index]
86
+
87
+ batch_size = len(prompt)
88
+
89
+ text_inputs = tokenizer(
90
+ prompt,
91
+ padding="max_length",
92
+ max_length=self.tokenizer_max_length,
93
+ truncation=True,
94
+ return_tensors="pt",
95
+ )
96
+
97
+ text_input_ids = text_inputs.input_ids
98
+ prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
99
+ pooled_prompt_embeds = prompt_embeds[0]
100
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
101
+ pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
102
+
103
+ return pooled_prompt_embeds
104
+
105
+ def encode_prompt(self,
106
+ prompt,
107
+ num_images_per_prompt=1,
108
+ clip_skip: Optional[int] = None,
109
+ device=None,
110
+ ):
111
+ prompt = [prompt] if isinstance(prompt, str) else prompt
112
+
113
+ pooled_prompt_embed = self._get_clip_prompt_embeds(
114
+ prompt=prompt,
115
+ device=device,
116
+ num_images_per_prompt=num_images_per_prompt,
117
+ clip_skip=clip_skip,
118
+ clip_model_index=0,
119
+ )
120
+ pooled_prompt_2_embed = self._get_clip_prompt_embeds(
121
+ prompt=prompt,
122
+ device=device,
123
+ num_images_per_prompt=num_images_per_prompt,
124
+ clip_skip=clip_skip,
125
+ clip_model_index=1,
126
+ )
127
+ pooled_prompt_embeds = torch.cat([pooled_prompt_embed, pooled_prompt_2_embed], dim=-1)
128
+
129
+ prompt_embeds, prompt_attention_mask = self._get_t5_prompt_embeds(
130
+ prompt=prompt,
131
+ num_images_per_prompt=num_images_per_prompt,
132
+ device=device,
133
+ )
134
+ return prompt_embeds, prompt_attention_mask, pooled_prompt_embeds
135
+
136
+ def forward(self, input_prompts, device):
137
+ with torch.no_grad():
138
+ prompt_embeds, prompt_attention_mask, pooled_prompt_embeds = self.encode_prompt(input_prompts, 1, clip_skip=None, device=device)
139
+
140
+ return prompt_embeds, prompt_attention_mask, pooled_prompt_embeds
pyramid_dit/pyramid_dit_for_video_gen_pipeline.py ADDED
@@ -0,0 +1,672 @@
1
+ import torch
2
+ import os
3
+ import sys
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+ from collections import OrderedDict
8
+ from einops import rearrange
9
+ from diffusers.utils.torch_utils import randn_tensor
10
+ import numpy as np
11
+ import math
12
+ import random
13
+ import PIL
14
+ from PIL import Image
15
+ from tqdm import tqdm
16
+ from torchvision import transforms
17
+ from copy import deepcopy
18
+ from typing import Any, Callable, Dict, List, Optional, Union
19
+ from accelerate import Accelerator
20
+ from diffusion_schedulers import PyramidFlowMatchEulerDiscreteScheduler
21
+ from video_vae.modeling_causal_vae import CausalVideoVAE
22
+
23
+ from trainer_misc import (
24
+ all_to_all,
25
+ is_sequence_parallel_initialized,
26
+ get_sequence_parallel_group,
27
+ get_sequence_parallel_group_rank,
28
+ get_sequence_parallel_rank,
29
+ get_sequence_parallel_world_size,
30
+ get_rank,
31
+ )
32
+
33
+ from .modeling_pyramid_mmdit import PyramidDiffusionMMDiT
34
+ from .modeling_text_encoder import SD3TextEncoderWithMask
35
+
36
+
37
+ def compute_density_for_timestep_sampling(
38
+ weighting_scheme: str, batch_size: int, logit_mean: float = None, logit_std: float = None, mode_scale: float = None
39
+ ):
40
+ if weighting_scheme == "logit_normal":
41
+ # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
42
+ u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu")
43
+ u = torch.nn.functional.sigmoid(u)
44
+ elif weighting_scheme == "mode":
45
+ u = torch.rand(size=(batch_size,), device="cpu")
46
+ u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
47
+ else:
48
+ u = torch.rand(size=(batch_size,), device="cpu")
49
+ return u
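+ # Example call (illustrative): logit-normal timestep density as in SD3,
+ # u = compute_density_for_timestep_sampling('logit_normal', batch_size=4, logit_mean=0.0, logit_std=1.0)
+ # returns a tensor of values in (0, 1) that are later mapped to discrete timesteps.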
50
+
51
+
52
+ class PyramidDiTForVideoGeneration:
53
+ """
54
+ The pyramid DiT wrapper for both image and video generation (the runnable entry class).
55
+ This class mainly implements the fixed-unit schedule: 1 + n + n + n
56
+ """
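+ # Usage sketch (illustrative; the checkpoint path and frame counts are assumptions, not from the original code):
+ # runner = PyramidDiTForVideoGeneration('/path/to/pyramid-flow', model_dtype='bf16',
+ # model_variant='diffusion_transformer_768p')
+ # frames = runner.generate_i2v(prompt='a boat on a lake', input_image=Image.open('first_frame.png'),
+ # temp=16, num_inference_steps=[20, 20, 20])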
57
+ def __init__(self, model_path, model_dtype='bf16', use_gradient_checkpointing=False, return_log=True,
58
+ model_variant="diffusion_transformer_768p", timestep_shift=1.0, stage_range=[0, 1/3, 2/3, 1],
59
+ sample_ratios=[1, 1, 1], scheduler_gamma=1/3, use_mixed_training=False, use_flash_attn=False,
60
+ load_text_encoder=True, load_vae=True, max_temporal_length=31, frame_per_unit=1, use_temporal_causal=True,
61
+ corrupt_ratio=1/3, interp_condition_pos=True, stages=[1, 2, 4], **kwargs,
62
+ ):
63
+ super().__init__()
64
+
65
+ if model_dtype == 'bf16':
66
+ torch_dtype = torch.bfloat16
67
+ elif model_dtype == 'fp16':
68
+ torch_dtype = torch.float16
69
+ else:
70
+ torch_dtype = torch.float32
71
+
72
+ self.stages = stages
73
+ self.sample_ratios = sample_ratios
74
+ self.corrupt_ratio = corrupt_ratio
75
+
76
+ dit_path = os.path.join(model_path, model_variant)
77
+
78
+ # The dit
79
+ if use_mixed_training:
80
+ print("using mixed precision training, not explicitly casting the models")
81
+ self.dit = PyramidDiffusionMMDiT.from_pretrained(
82
+ dit_path, use_gradient_checkpointing=use_gradient_checkpointing,
83
+ use_flash_attn=use_flash_attn, use_t5_mask=True,
84
+ add_temp_pos_embed=True, temp_pos_embed_type='rope',
85
+ use_temporal_causal=use_temporal_causal, interp_condition_pos=interp_condition_pos,
86
+ )
87
+ else:
88
+ print("using half precision")
89
+ self.dit = PyramidDiffusionMMDiT.from_pretrained(
90
+ dit_path, torch_dtype=torch_dtype,
91
+ use_gradient_checkpointing=use_gradient_checkpointing,
92
+ use_flash_attn=use_flash_attn, use_t5_mask=True,
93
+ add_temp_pos_embed=True, temp_pos_embed_type='rope',
94
+ use_temporal_causal=use_temporal_causal, interp_condition_pos=interp_condition_pos,
95
+ )
96
+
97
+ # The text encoder
98
+ if load_text_encoder:
99
+ self.text_encoder = SD3TextEncoderWithMask(model_path, torch_dtype=torch_dtype)
100
+ else:
101
+ self.text_encoder = None
102
+
103
+ # The base video vae decoder
104
+ if load_vae:
105
+ self.vae = CausalVideoVAE.from_pretrained(os.path.join(model_path, 'causal_video_vae'), torch_dtype=torch_dtype, interpolate=False)
106
+ # Freeze vae
107
+ for parameter in self.vae.parameters():
108
+ parameter.requires_grad = False
109
+ else:
110
+ self.vae = None
111
+
112
+ # For the image latent
113
+ self.vae_shift_factor = 0.1490
114
+ self.vae_scale_factor = 1 / 1.8415
115
+
116
+ # For the video latent
117
+ self.vae_video_shift_factor = -0.2343
118
+ self.vae_video_scale_factor = 1 / 3.0986
119
+
120
+ self.downsample = 8
121
+
122
+ # Configure the video training hyper-parameters
123
+ # The video sequence: one frame + N * unit
124
+ self.frame_per_unit = frame_per_unit
125
+ self.max_temporal_length = max_temporal_length
126
+ assert (max_temporal_length - 1) % frame_per_unit == 0, "(max_temporal_length - 1) should be divisible by frame_per_unit"
127
+ self.num_units_per_video = 1 + ((max_temporal_length - 1) // frame_per_unit) + int(sum(sample_ratios))
128
+
129
+ self.scheduler = PyramidFlowMatchEulerDiscreteScheduler(
130
+ shift=timestep_shift, stages=len(self.stages),
131
+ stage_range=stage_range, gamma=scheduler_gamma,
132
+ )
133
+ print(f"The start and end sigmas of each stage are Start: {self.scheduler.start_sigmas}, End: {self.scheduler.end_sigmas}, Ori_start: {self.scheduler.ori_start_sigmas}")
134
+
135
+ self.cfg_rate = 0.1
136
+ self.return_log = return_log
137
+ self.use_flash_attn = use_flash_attn
138
+
139
+ def load_checkpoint(self, checkpoint_path, model_key='model', **kwargs):
140
+ checkpoint = torch.load(checkpoint_path, map_location='cpu')
141
+ dit_checkpoint = OrderedDict()
142
+ for key in checkpoint:
143
+ if key.startswith('vae') or key.startswith('text_encoder'):
144
+ continue
145
+ if key.startswith('dit'):
146
+ new_key = key.split('.')
147
+ new_key = '.'.join(new_key[1:])
148
+ dit_checkpoint[new_key] = checkpoint[key]
149
+ else:
150
+ dit_checkpoint[key] = checkpoint[key]
151
+
152
+ load_result = self.dit.load_state_dict(dit_checkpoint, strict=True)
153
+ print(f"Load checkpoint from {checkpoint_path}, load result: {load_result}")
154
+
155
+ def load_vae_checkpoint(self, vae_checkpoint_path, model_key='model'):
156
+ checkpoint = torch.load(vae_checkpoint_path, map_location='cpu')
157
+ checkpoint = checkpoint[model_key]
158
+ loaded_checkpoint = OrderedDict()
159
+
160
+ for key in checkpoint.keys():
161
+ if key.startswith('vae.'):
162
+ new_key = key.split('.')
163
+ new_key = '.'.join(new_key[1:])
164
+ loaded_checkpoint[new_key] = checkpoint[key]
165
+
166
+ load_result = self.vae.load_state_dict(loaded_checkpoint)
167
+ print(f"Load the VAE from {vae_checkpoint_path}, load result: {load_result}")
168
+
169
+ @torch.no_grad()
170
+ def get_pyramid_latent(self, x, stage_num):
171
+ # x is the origin vae latent
172
+ vae_latent_list = []
173
+ vae_latent_list.append(x)
174
+
175
+ temp, height, width = x.shape[-3], x.shape[-2], x.shape[-1]
176
+ for _ in range(stage_num):
177
+ height //= 2
178
+ width //= 2
179
+ x = rearrange(x, 'b c t h w -> (b t) c h w')
180
+ x = torch.nn.functional.interpolate(x, size=(height, width), mode='bilinear')
181
+ x = rearrange(x, '(b t) c h w -> b c t h w', t=temp)
182
+ vae_latent_list.append(x)
183
+
184
+ vae_latent_list = list(reversed(vae_latent_list))
185
+ return vae_latent_list
186
+
187
+ def prepare_latents(
188
+ self,
189
+ batch_size,
190
+ num_channels_latents,
191
+ temp,
192
+ height,
193
+ width,
194
+ dtype,
195
+ device,
196
+ generator,
197
+ ):
198
+ shape = (
199
+ batch_size,
200
+ num_channels_latents,
201
+ int(temp),
202
+ int(height) // self.downsample,
203
+ int(width) // self.downsample,
204
+ )
205
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
206
+ return latents
207
+
208
+ def sample_block_noise(self, bs, ch, temp, height, width):
209
+ gamma = self.scheduler.config.gamma
210
+ dist = torch.distributions.multivariate_normal.MultivariateNormal(torch.zeros(4), torch.eye(4) * (1 + gamma) - torch.ones(4, 4) * gamma)
211
+ block_number = bs * ch * temp * (height // 2) * (width // 2)
212
+ noise = torch.stack([dist.sample() for _ in range(block_number)]) # [block number, 4]
213
+ noise = rearrange(noise, '(b c t h w) (p q) -> b c t (h p) (w q)',b=bs,c=ch,t=temp,h=height//2,w=width//2,p=2,q=2)
214
+ return noise
215
+
216
+ @torch.no_grad()
217
+ def generate_one_unit(
218
+ self,
219
+ latents,
220
+ past_conditions, # List of past conditions, containing the conditions for each stage
221
+ prompt_embeds,
222
+ prompt_attention_mask,
223
+ pooled_prompt_embeds,
224
+ num_inference_steps,
225
+ height,
226
+ width,
227
+ temp,
228
+ device,
229
+ dtype,
230
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
231
+ is_first_frame: bool = False,
232
+ ):
233
+ stages = self.stages
234
+ intermed_latents = []
235
+
236
+ for i_s in range(len(stages)):
237
+ self.scheduler.set_timesteps(num_inference_steps[i_s], i_s, device=device)
238
+ timesteps = self.scheduler.timesteps
239
+
240
+ if i_s > 0:
241
+ height *= 2; width *= 2
242
+ latents = rearrange(latents, 'b c t h w -> (b t) c h w')
243
+ latents = F.interpolate(latents, size=(height, width), mode='nearest')
244
+ latents = rearrange(latents, '(b t) c h w -> b c t h w', t=temp)
245
+ # Correct the noise level at the stage transition
246
+ ori_sigma = 1 - self.scheduler.ori_start_sigmas[i_s] # the original coeff of signal
247
+ gamma = self.scheduler.config.gamma
248
+ alpha = 1 / (math.sqrt(1 + (1 / gamma)) * (1 - ori_sigma) + ori_sigma)
249
+ beta = alpha * (1 - ori_sigma) / math.sqrt(gamma)
250
+
251
+ bs, ch, temp, height, width = latents.shape
252
+ noise = self.sample_block_noise(bs, ch, temp, height, width)
253
+ noise = noise.to(device=device, dtype=dtype)
254
+ latents = alpha * latents + beta * noise # To fix the block artifact
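+ # Here alpha rescales the upsampled latents and beta scales the block-correlated noise so that the
+ # renoised latents roughly match the noise level expected at the start of this stage (ori_start_sigmas);
+ # sample_block_noise couples each 2x2 block through the gamma term of its covariance, which suppresses
+ # the blocky artifact introduced by nearest-neighbor upsampling.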
255
+
256
+ for idx, t in enumerate(timesteps):
257
+ # expand the latents if we are doing classifier free guidance
258
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
259
+
260
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
261
+ timestep = t.expand(latent_model_input.shape[0]).to(latent_model_input.dtype)
262
+
263
+ latent_model_input = past_conditions[i_s] + [latent_model_input]
264
+
265
+ noise_pred = self.dit(
266
+ sample=[latent_model_input],
267
+ timestep_ratio=timestep,
268
+ encoder_hidden_states=prompt_embeds,
269
+ encoder_attention_mask=prompt_attention_mask,
270
+ pooled_projections=pooled_prompt_embeds,
271
+ )
272
+
273
+ noise_pred = noise_pred[0]
274
+
275
+ # perform guidance
276
+ if self.do_classifier_free_guidance:
277
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
278
+ if is_first_frame:
279
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
280
+ else:
281
+ noise_pred = noise_pred_uncond + self.video_guidance_scale * (noise_pred_text - noise_pred_uncond)
282
+
283
+ # compute the previous noisy sample x_t -> x_t-1
284
+ latents = self.scheduler.step(
285
+ model_output=noise_pred,
286
+ timestep=timestep,
287
+ sample=latents,
288
+ generator=generator,
289
+ ).prev_sample
290
+
291
+ intermed_latents.append(latents)
292
+
293
+ return intermed_latents
294
+
295
+ @torch.no_grad()
296
+ def generate_i2v(
297
+ self,
298
+ prompt: Union[str, List[str]] = '',
299
+ input_image: PIL.Image = None,
300
+ temp: int = 1,
301
+ num_inference_steps: Optional[Union[int, List[int]]] = 28,
302
+ guidance_scale: float = 7.0,
303
+ video_guidance_scale: float = 4.0,
304
+ min_guidance_scale: float = 2.0,
305
+ use_linear_guidance: bool = False,
306
+ alpha: float = 0.5,
307
+ negative_prompt: Optional[Union[str, List[str]]]="cartoon style, worst quality, low quality, blurry, absolute black, absolute white, low res, extra limbs, extra digits, misplaced objects, mutated anatomy, monochrome, horror",
308
+ num_images_per_prompt: Optional[int] = 1,
309
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
310
+ output_type: Optional[str] = "pil",
311
+ save_memory: bool = True,
312
+ ):
313
+ device = self.device
314
+ dtype = self.dtype
315
+
316
+ width = input_image.width
317
+ height = input_image.height
318
+
319
+ assert temp % self.frame_per_unit == 0, "The number of frames should be divisible by frame_per_unit"
320
+
321
+ if isinstance(prompt, str):
322
+ batch_size = 1
323
+ prompt = prompt + ", hyper quality, Ultra HD, 8K" # adding this prompt to improve aesthetics
324
+ else:
325
+ assert isinstance(prompt, list)
326
+ batch_size = len(prompt)
327
+ prompt = [_ + ", hyper quality, Ultra HD, 8K" for _ in prompt]
328
+
329
+ if isinstance(num_inference_steps, int):
330
+ num_inference_steps = [num_inference_steps] * len(self.stages)
331
+
332
+ negative_prompt = negative_prompt or ""
333
+
334
+ # Get the text embeddings
335
+ prompt_embeds, prompt_attention_mask, pooled_prompt_embeds = self.text_encoder(prompt, device)
336
+ negative_prompt_embeds, negative_prompt_attention_mask, negative_pooled_prompt_embeds = self.text_encoder(negative_prompt, device)
337
+
338
+ if use_linear_guidance:
339
+ max_guidance_scale = guidance_scale
340
+ guidance_scale_list = [max(max_guidance_scale - alpha * t_, min_guidance_scale) for t_ in range(temp+1)]
341
+ print(guidance_scale_list)
342
+
343
+ self._guidance_scale = guidance_scale
344
+ self._video_guidance_scale = video_guidance_scale
345
+
346
+ if self.do_classifier_free_guidance:
347
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
348
+ pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
349
+ prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
350
+
351
+ # Create the initial random noise
352
+ num_channels_latents = self.dit.config.in_channels
353
+ latents = self.prepare_latents(
354
+ batch_size * num_images_per_prompt,
355
+ num_channels_latents,
356
+ temp,
357
+ height,
358
+ width,
359
+ prompt_embeds.dtype,
360
+ device,
361
+ generator,
362
+ )
363
+
364
+ temp, height, width = latents.shape[-3], latents.shape[-2], latents.shape[-1]
365
+
366
+ latents = rearrange(latents, 'b c t h w -> (b t) c h w')
367
+ # by default, we need to start from the block noise
368
+ for _ in range(len(self.stages)-1):
369
+ height //= 2; width //= 2
370
+ latents = F.interpolate(latents, size=(height, width), mode='bilinear') * 2
371
+
372
+ latents = rearrange(latents, '(b t) c h w -> b c t h w', t=temp)
373
+
374
+ num_units = temp // self.frame_per_unit
375
+ stages = self.stages
376
+
377
+ # encode the image latents
378
+ image_transform = transforms.Compose([
379
+ transforms.ToTensor(),
380
+ transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
381
+ ])
382
+ input_image_tensor = image_transform(input_image).unsqueeze(0).unsqueeze(2) # [b c 1 h w]
383
+ input_image_latent = (self.vae.encode(input_image_tensor.to(device)).latent_dist.sample() - self.vae_shift_factor) * self.vae_scale_factor # [b c 1 h w]
384
+
385
+ generated_latents_list = [input_image_latent] # The generated results
386
+ last_generated_latents = input_image_latent
387
+
388
+ for unit_index in tqdm(range(1, num_units + 1)):
389
+ if use_linear_guidance:
390
+ self._guidance_scale = guidance_scale_list[unit_index]
391
+ self._video_guidance_scale = guidance_scale_list[unit_index]
392
+
393
+ # prepare the condition latents
394
+ past_condition_latents = []
395
+ clean_latents_list = self.get_pyramid_latent(torch.cat(generated_latents_list, dim=2), len(stages) - 1)
396
+
397
+ for i_s in range(len(stages)):
398
+ last_cond_latent = clean_latents_list[i_s][:,:,-self.frame_per_unit:]
399
+
400
+ stage_input = [torch.cat([last_cond_latent] * 2) if self.do_classifier_free_guidance else last_cond_latent]
401
+
402
+ # pad the past clean latents
403
+ cur_unit_num = unit_index
404
+ cur_stage = i_s
405
+ cur_unit_ptx = 1
406
+
407
+ while cur_unit_ptx < cur_unit_num:
408
+ cur_stage = max(cur_stage - 1, 0)
409
+ if cur_stage == 0:
410
+ break
411
+ cur_unit_ptx += 1
412
+ cond_latents = clean_latents_list[cur_stage][:, :, -(cur_unit_ptx * self.frame_per_unit) : -((cur_unit_ptx - 1) * self.frame_per_unit)]
413
+ stage_input.append(torch.cat([cond_latents] * 2) if self.do_classifier_free_guidance else cond_latents)
414
+
415
+ if cur_stage == 0 and cur_unit_ptx < cur_unit_num:
416
+ cond_latents = clean_latents_list[0][:, :, :-(cur_unit_ptx * self.frame_per_unit)]
417
+ stage_input.append(torch.cat([cond_latents] * 2) if self.do_classifier_free_guidance else cond_latents)
418
+
419
+ stage_input = list(reversed(stage_input))
420
+ past_condition_latents.append(stage_input)
421
+
422
+ intermed_latents = self.generate_one_unit(
423
+ latents[:,:,(unit_index - 1) * self.frame_per_unit:unit_index * self.frame_per_unit],
424
+ past_condition_latents,
425
+ prompt_embeds,
426
+ prompt_attention_mask,
427
+ pooled_prompt_embeds,
428
+ num_inference_steps,
429
+ height,
430
+ width,
431
+ self.frame_per_unit,
432
+ device,
433
+ dtype,
434
+ generator,
435
+ is_first_frame=False,
436
+ )
437
+
438
+ generated_latents_list.append(intermed_latents[-1])
439
+ last_generated_latents = intermed_latents
440
+
441
+ generated_latents = torch.cat(generated_latents_list, dim=2)
442
+
443
+ if output_type == "latent":
444
+ image = generated_latents
445
+ else:
446
+ image = self.decode_latent(generated_latents, save_memory=save_memory)
447
+
448
+ return image
449
+
450
+ @torch.no_grad()
451
+ def generate(
452
+ self,
453
+ prompt: Union[str, List[str]] = None,
454
+ height: Optional[int] = None,
455
+ width: Optional[int] = None,
456
+ temp: int = 1,
457
+ num_inference_steps: Optional[Union[int, List[int]]] = 28,
458
+ video_num_inference_steps: Optional[Union[int, List[int]]] = 28,
459
+ guidance_scale: float = 7.0,
460
+ video_guidance_scale: float = 7.0,
461
+ min_guidance_scale: float = 2.0,
462
+ use_linear_guidance: bool = False,
463
+ alpha: float = 0.5,
464
+ negative_prompt: Optional[Union[str, List[str]]]="cartoon style, worst quality, low quality, blurry, absolute black, absolute white, low res, extra limbs, extra digits, misplaced objects, mutated anatomy, monochrome, horror",
465
+ num_images_per_prompt: Optional[int] = 1,
466
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
467
+ output_type: Optional[str] = "pil",
468
+ save_memory: bool = True,
469
+ ):
470
+ device = self.device
471
+ dtype = self.dtype
472
+
473
+ assert (temp - 1) % self.frame_per_unit == 0, "The number of frames minus one must be divisible by frame_per_unit"
474
+
475
+ if isinstance(prompt, str):
476
+ batch_size = 1
477
+ prompt = prompt + ", hyper quality, Ultra HD, 8K" # adding this prompt to improve aesthetics
478
+ else:
479
+ assert isinstance(prompt, list)
480
+ batch_size = len(prompt)
481
+ prompt = [_ + ", hyper quality, Ultra HD, 8K" for _ in prompt]
482
+
483
+ if isinstance(num_inference_steps, int):
484
+ num_inference_steps = [num_inference_steps] * len(self.stages)
485
+
486
+ if isinstance(video_num_inference_steps, int):
487
+ video_num_inference_steps = [video_num_inference_steps] * len(self.stages)
488
+
489
+ negative_prompt = negative_prompt or ""
490
+
491
+ # Get the text embeddings
492
+ prompt_embeds, prompt_attention_mask, pooled_prompt_embeds = self.text_encoder(prompt, device)
493
+ negative_prompt_embeds, negative_prompt_attention_mask, negative_pooled_prompt_embeds = self.text_encoder(negative_prompt, device)
494
+
495
+ if use_linear_guidance:
496
+ max_guidance_scale = guidance_scale
497
+ # guidance_scale_list = torch.linspace(max_guidance_scale, min_guidance_scale, temp).tolist()
498
+ guidance_scale_list = [max(max_guidance_scale - alpha * t_, min_guidance_scale) for t_ in range(temp)]
499
+ print(guidance_scale_list)
500
+
501
+ self._guidance_scale = guidance_scale
502
+ self._video_guidance_scale = video_guidance_scale
503
+
504
+ if self.do_classifier_free_guidance:
505
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
506
+ pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
507
+ prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
508
+
509
+ # Create the initial random noise
510
+ num_channels_latents = self.dit.config.in_channels
511
+ latents = self.prepare_latents(
512
+ batch_size * num_images_per_prompt,
513
+ num_channels_latents,
514
+ temp,
515
+ height,
516
+ width,
517
+ prompt_embeds.dtype,
518
+ device,
519
+ generator,
520
+ )
521
+
522
+ temp, height, width = latents.shape[-3], latents.shape[-2], latents.shape[-1]
523
+
524
+ latents = rearrange(latents, 'b c t h w -> (b t) c h w')
525
+ # by default, we need to start from the block noise
526
+ for _ in range(len(self.stages)-1):
527
+ height //= 2; width //= 2
528
+ latents = F.interpolate(latents, size=(height, width), mode='bilinear') * 2
529
+
530
+ latents = rearrange(latents, '(b t) c h w -> b c t h w', t=temp)
531
+
532
+ num_units = 1 + (temp - 1) // self.frame_per_unit
533
+ stages = self.stages
534
+
535
+ generated_latents_list = [] # The generated results
536
+ last_generated_latents = None
537
+
538
+ for unit_index in tqdm(range(num_units)):
539
+ if use_linear_guidance:
540
+ self._guidance_scale = guidance_scale_list[unit_index]
541
+ self._video_guidance_scale = guidance_scale_list[unit_index]
542
+
543
+ if unit_index == 0:
544
+ past_condition_latents = [[] for _ in range(len(stages))]
545
+ intermed_latents = self.generate_one_unit(
546
+ latents[:,:,:1],
547
+ past_condition_latents,
548
+ prompt_embeds,
549
+ prompt_attention_mask,
550
+ pooled_prompt_embeds,
551
+ num_inference_steps,
552
+ height,
553
+ width,
554
+ 1,
555
+ device,
556
+ dtype,
557
+ generator,
558
+ is_first_frame=True,
559
+ )
560
+ else:
561
+ # prepare the condition latents
562
+ past_condition_latents = []
563
+ clean_latents_list = self.get_pyramid_latent(torch.cat(generated_latents_list, dim=2), len(stages) - 1)
564
+
565
+ for i_s in range(len(stages)):
566
+ last_cond_latent = clean_latents_list[i_s][:,:,-(self.frame_per_unit):]
567
+
568
+ stage_input = [torch.cat([last_cond_latent] * 2) if self.do_classifier_free_guidance else last_cond_latent]
569
+
570
+ # pad the past clean latents
571
+ cur_unit_num = unit_index
572
+ cur_stage = i_s
573
+ cur_unit_ptx = 1
574
+
575
+ while cur_unit_ptx < cur_unit_num:
576
+ cur_stage = max(cur_stage - 1, 0)
577
+ if cur_stage == 0:
578
+ break
579
+ cur_unit_ptx += 1
580
+ cond_latents = clean_latents_list[cur_stage][:, :, -(cur_unit_ptx * self.frame_per_unit) : -((cur_unit_ptx - 1) * self.frame_per_unit)]
581
+ stage_input.append(torch.cat([cond_latents] * 2) if self.do_classifier_free_guidance else cond_latents)
582
+
583
+ if cur_stage == 0 and cur_unit_ptx < cur_unit_num:
584
+ cond_latents = clean_latents_list[0][:, :, :-(cur_unit_ptx * self.frame_per_unit)]
585
+ stage_input.append(torch.cat([cond_latents] * 2) if self.do_classifier_free_guidance else cond_latents)
586
+
587
+ stage_input = list(reversed(stage_input))
588
+ past_condition_latents.append(stage_input)
589
+
590
+ intermed_latents = self.generate_one_unit(
591
+ latents[:,:, 1 + (unit_index - 1) * self.frame_per_unit:1 + unit_index * self.frame_per_unit],
592
+ past_condition_latents,
593
+ prompt_embeds,
594
+ prompt_attention_mask,
595
+ pooled_prompt_embeds,
596
+ video_num_inference_steps,
597
+ height,
598
+ width,
599
+ self.frame_per_unit,
600
+ device,
601
+ dtype,
602
+ generator,
603
+ is_first_frame=False,
604
+ )
605
+
606
+ generated_latents_list.append(intermed_latents[-1])
607
+ last_generated_latents = intermed_latents
608
+
609
+ generated_latents = torch.cat(generated_latents_list, dim=2)
610
+
611
+ if output_type == "latent":
612
+ image = generated_latents
613
+ else:
614
+ image = self.decode_latent(generated_latents, save_memory=save_memory)
615
+
616
+ return image
617
+
618
+ def decode_latent(self, latents, save_memory=True):
619
+ if latents.shape[2] == 1:
620
+ latents = (latents / self.vae_scale_factor) + self.vae_shift_factor
621
+ else:
622
+ latents[:, :, :1] = (latents[:, :, :1] / self.vae_scale_factor) + self.vae_shift_factor
623
+ latents[:, :, 1:] = (latents[:, :, 1:] / self.vae_video_scale_factor) + self.vae_video_shift_factor
624
+
625
+ if save_memory:
626
+ # reducing the tile size and temporal chunk window size
627
+ image = self.vae.decode(latents, temporal_chunk=True, window_size=1, tile_sample_min_size=256).sample
628
+ else:
629
+ image = self.vae.decode(latents, temporal_chunk=True, window_size=2, tile_sample_min_size=512).sample
630
+
631
+ image = image.float()
632
+ image = (image / 2 + 0.5).clamp(0, 1)
633
+ image = rearrange(image, "B C T H W -> (B T) C H W")
634
+ image = image.cpu().permute(0, 2, 3, 1).numpy()
635
+ image = self.numpy_to_pil(image)
636
+ return image
637
+
638
+ @staticmethod
639
+ def numpy_to_pil(images):
640
+ """
641
+ Convert a numpy image or a batch of images to a PIL image.
642
+ """
643
+ if images.ndim == 3:
644
+ images = images[None, ...]
645
+ images = (images * 255).round().astype("uint8")
646
+ if images.shape[-1] == 1:
647
+ # special case for grayscale (single channel) images
648
+ pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
649
+ else:
650
+ pil_images = [Image.fromarray(image) for image in images]
651
+
652
+ return pil_images
653
+
654
+ @property
655
+ def device(self):
656
+ return next(self.dit.parameters()).device
657
+
658
+ @property
659
+ def dtype(self):
660
+ return next(self.dit.parameters()).dtype
661
+
662
+ @property
663
+ def guidance_scale(self):
664
+ return self._guidance_scale
665
+
666
+ @property
667
+ def video_guidance_scale(self):
668
+ return self._video_guidance_scale
669
+
670
+ @property
671
+ def do_classifier_free_guidance(self):
672
+ return self._guidance_scale > 0
requirements.txt ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ sentencepiece
2
+ tiktoken
3
+ jsonlines
4
+ diffusers==0.30.1
5
+ accelerate==0.30.0
6
+ torchvision
7
+ numpy==1.26.4
8
+ imageio
9
+ imageio-ffmpeg
10
+ timm
11
+ transformers
12
+ opencv-python-headless
13
+ einops
14
+ tensorboardX
15
+ ipython
trainer_misc/__init__.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .utils import (
2
+ create_optimizer,
3
+ get_rank,
4
+ get_world_size,
5
+ is_main_process,
6
+ is_dist_avail_and_initialized,
7
+ init_distributed_mode,
8
+ setup_for_distributed,
9
+ cosine_scheduler,
10
+ constant_scheduler,
11
+ )
12
+
13
+ from .sp_utils import (
14
+ is_sequence_parallel_initialized,
15
+ init_sequence_parallel_group,
16
+ get_sequence_parallel_group,
17
+ get_sequence_parallel_world_size,
18
+ get_sequence_parallel_rank,
19
+ get_sequence_parallel_group_rank,
20
+ get_sequence_parallel_proc_num,
21
+ init_sync_input_group,
22
+ get_sync_input_group,
23
+ )
24
+
25
+ from .communicate import all_to_all
trainer_misc/communicate.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import math
4
+ import torch.distributed as dist
5
+
6
+
7
+ def _all_to_all(
8
+ input_: torch.Tensor,
9
+ world_size: int,
10
+ group: dist.ProcessGroup,
11
+ scatter_dim: int,
12
+ gather_dim: int,
13
+ ):
14
+ if world_size == 1:
15
+ return input_
16
+ input_list = [t.contiguous() for t in torch.tensor_split(input_, world_size, scatter_dim)]
17
+ output_list = [torch.empty_like(input_list[0]) for _ in range(world_size)]
18
+ dist.all_to_all(output_list, input_list, group=group)
19
+ return torch.cat(output_list, dim=gather_dim).contiguous()
20
+
21
+
22
+ class _AllToAll(torch.autograd.Function):
23
+
24
+ @staticmethod
25
+ def forward(ctx, input_, process_group, world_size, scatter_dim, gather_dim):
26
+ ctx.process_group = process_group
27
+ ctx.scatter_dim = scatter_dim
28
+ ctx.gather_dim = gather_dim
29
+ ctx.world_size = world_size
30
+ output = _all_to_all(input_, ctx.world_size, process_group, scatter_dim, gather_dim)
31
+ return output
32
+
33
+ @staticmethod
34
+ def backward(ctx, grad_output):
35
+ grad_output = _all_to_all(
36
+ grad_output,
37
+ ctx.world_size,
38
+ ctx.process_group,
39
+ ctx.gather_dim,
40
+ ctx.scatter_dim,
41
+ )
42
+ return (
43
+ grad_output,
44
+ None,
45
+ None,
46
+ None,
47
+ None,
48
+ )
49
+
50
+
51
+ def all_to_all(
52
+ input_: torch.Tensor,
53
+ process_group: dist.ProcessGroup,
54
+ world_size: int = 1,
55
+ scatter_dim: int = 2,
56
+ gather_dim: int = 1,
57
+ ):
58
+ return _AllToAll.apply(input_, process_group, world_size, scatter_dim, gather_dim)
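+
+ # --- Editor's illustrative sketch (not part of the original commit) ---
+ # Minimal single-process check of the call pattern; the tensor shape below is
+ # hypothetical. With world_size == 1 the op is an identity, so it also runs
+ # without torch.distributed being initialized.
+ if __name__ == "__main__":
+     x = torch.randn(2, 8, 4, 16)  # hypothetical (batch, seq, heads, head_dim)
+     y = all_to_all(x, process_group=None, world_size=1, scatter_dim=2, gather_dim=1)
+     assert y.shape == x.shape  # identity when there is only a single rank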
trainer_misc/sp_utils.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import torch.distributed as dist
4
+ from .utils import is_dist_avail_and_initialized, get_rank
5
+
6
+
7
+ SEQ_PARALLEL_GROUP = None
8
+ SEQ_PARALLEL_SIZE = None
9
+ SEQ_PARALLEL_PROC_NUM = None # number of processes used for sequence parallelism
10
+
11
+ SYNC_INPUT_GROUP = None
12
+ SYNC_INPUT_SIZE = None
13
+
14
+ def is_sequence_parallel_initialized():
15
+ if SEQ_PARALLEL_GROUP is None:
16
+ return False
17
+ else:
18
+ return True
19
+
20
+
21
+ def init_sequence_parallel_group(args):
22
+ global SEQ_PARALLEL_GROUP
23
+ global SEQ_PARALLEL_SIZE
24
+ global SEQ_PARALLEL_PROC_NUM
25
+
26
+ assert SEQ_PARALLEL_GROUP is None, "sequence parallel group is already initialized"
27
+ assert is_dist_avail_and_initialized(), "The pytorch distributed should be initialized"
28
+ SEQ_PARALLEL_SIZE = args.sp_group_size
29
+
30
+ print(f"Setting the Sequence Parallel Size {SEQ_PARALLEL_SIZE}")
31
+
32
+ rank = torch.distributed.get_rank()
33
+ world_size = torch.distributed.get_world_size()
34
+
35
+ if args.sp_proc_num == -1:
36
+ SEQ_PARALLEL_PROC_NUM = world_size
37
+ else:
38
+ SEQ_PARALLEL_PROC_NUM = args.sp_proc_num
39
+
40
+ assert SEQ_PARALLEL_PROC_NUM % SEQ_PARALLEL_SIZE == 0, "The number of sequence-parallel processes must be evenly divisible by the group size"
41
+
42
+ for i in range(0, SEQ_PARALLEL_PROC_NUM, SEQ_PARALLEL_SIZE):
43
+ ranks = list(range(i, i + SEQ_PARALLEL_SIZE))
44
+ group = torch.distributed.new_group(ranks)
45
+ if rank in ranks:
46
+ SEQ_PARALLEL_GROUP = group
47
+ break
48
+
49
+
50
+ def init_sync_input_group(args):
51
+ global SYNC_INPUT_GROUP
52
+ global SYNC_INPUT_SIZE
53
+
54
+ assert SYNC_INPUT_GROUP is None, "sync input group is already initialized"
55
+ assert is_dist_avail_and_initialized(), "The pytorch distributed should be initialized"
56
+ SYNC_INPUT_SIZE = args.max_frames
57
+
58
+ rank = torch.distributed.get_rank()
59
+ world_size = torch.distributed.get_world_size()
60
+
61
+ for i in range(0, world_size, SYNC_INPUT_SIZE):
62
+ ranks = list(range(i, i + SYNC_INPUT_SIZE))
63
+ group = torch.distributed.new_group(ranks)
64
+ if rank in ranks:
65
+ SYNC_INPUT_GROUP = group
66
+ break
67
+
68
+
69
+ def get_sequence_parallel_group():
70
+ assert SEQ_PARALLEL_GROUP is not None, "sequence parallel group is not initialized"
71
+ return SEQ_PARALLEL_GROUP
72
+
73
+
74
+ def get_sync_input_group():
75
+ return SYNC_INPUT_GROUP
76
+
77
+
78
+ def get_sequence_parallel_world_size():
79
+ assert SEQ_PARALLEL_SIZE is not None, "sequence parallel size is not initialized"
80
+ return SEQ_PARALLEL_SIZE
81
+
82
+
83
+ def get_sequence_parallel_rank():
84
+ assert SEQ_PARALLEL_SIZE is not None, "sequence parallel size is not initialized"
85
+ rank = get_rank()
86
+ cp_rank = rank % SEQ_PARALLEL_SIZE
87
+ return cp_rank
88
+
89
+
90
+ def get_sequence_parallel_group_rank():
91
+ assert SEQ_PARALLEL_SIZE is not None, "sequence parallel size is not initialized"
92
+ rank = get_rank()
93
+ cp_group_rank = rank // SEQ_PARALLEL_SIZE
94
+ return cp_group_rank
95
+
96
+
97
+ def get_sequence_parallel_proc_num():
98
+ return SEQ_PARALLEL_PROC_NUM
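+
+ # --- Editor's worked example (hypothetical layout, not part of the original commit) ---
+ # With 8 ranks and sp_group_size = 4, ranks are partitioned into two groups:
+ #   rank r -> sequence-parallel rank r % 4, sequence-parallel group r // 4
+ # so ranks 0-3 form group 0 and ranks 4-7 form group 1, matching
+ # get_sequence_parallel_rank() and get_sequence_parallel_group_rank() above.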
trainer_misc/utils.py ADDED
@@ -0,0 +1,382 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import os
3
+ import math
4
+ import time
5
+ import json
6
+ import glob
7
+ from collections import defaultdict, deque, OrderedDict
8
+ import datetime
9
+ import numpy as np
10
+
11
+
12
+ from pathlib import Path
13
+ import argparse
14
+
15
+ import torch
16
+ from torch import optim as optim
17
+ import torch.distributed as dist
18
+ from tensorboardX import SummaryWriter
19
+
20
+
21
+ def is_dist_avail_and_initialized():
22
+ if not dist.is_available():
23
+ return False
24
+ if not dist.is_initialized():
25
+ return False
26
+ return True
27
+
28
+
29
+ def get_world_size():
30
+ if not is_dist_avail_and_initialized():
31
+ return 1
32
+ return dist.get_world_size()
33
+
34
+
35
+ def get_rank():
36
+ if not is_dist_avail_and_initialized():
37
+ return 0
38
+ return dist.get_rank()
39
+
40
+
41
+ def is_main_process():
42
+ return get_rank() == 0
43
+
44
+
45
+ def save_on_master(*args, **kwargs):
46
+ if is_main_process():
47
+ torch.save(*args, **kwargs)
48
+
49
+
50
+ def setup_for_distributed(is_master):
51
+ """
52
+ This function disables printing when not in the master process
53
+ """
54
+ import builtins as __builtin__
55
+ builtin_print = __builtin__.print
56
+
57
+ def print(*args, **kwargs):
58
+ force = kwargs.pop('force', False)
59
+ if is_master or force:
60
+ builtin_print(*args, **kwargs)
61
+
62
+ __builtin__.print = print
63
+
64
+
65
+ def init_distributed_mode(args):
66
+ if int(os.getenv('OMPI_COMM_WORLD_SIZE', '0')) > 0:
67
+ rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
68
+ local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
69
+ world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
70
+
71
+ os.environ["LOCAL_RANK"] = os.environ['OMPI_COMM_WORLD_LOCAL_RANK']
72
+ os.environ["RANK"] = os.environ['OMPI_COMM_WORLD_RANK']
73
+ os.environ["WORLD_SIZE"] = os.environ['OMPI_COMM_WORLD_SIZE']
74
+
75
+ args.rank = int(os.environ["RANK"])
76
+ args.world_size = int(os.environ["WORLD_SIZE"])
77
+ args.gpu = int(os.environ["LOCAL_RANK"])
78
+
79
+ elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
80
+ args.rank = int(os.environ["RANK"])
81
+ args.world_size = int(os.environ['WORLD_SIZE'])
82
+ args.gpu = int(os.environ['LOCAL_RANK'])
83
+
84
+ else:
85
+ print('Not using distributed mode')
86
+ args.distributed = False
87
+ return
88
+
89
+ args.distributed = True
90
+ args.dist_backend = 'nccl'
91
+ args.dist_url = "env://"
92
+ print('| distributed init (rank {}): {}, gpu {}'.format(
93
+ args.rank, args.dist_url, args.gpu), flush=True)
94
+
95
+
96
+ def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0,
97
+ start_warmup_value=0, warmup_steps=-1):
98
+ warmup_schedule = np.array([])
99
+ warmup_iters = warmup_epochs * niter_per_ep
100
+ if warmup_steps > 0:
101
+ warmup_iters = warmup_steps
102
+ print("Set warmup steps = %d" % warmup_iters)
103
+ if warmup_epochs > 0:
104
+ warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters)
105
+
106
+ iters = np.arange(epochs * niter_per_ep - warmup_iters)
107
+ schedule = np.array(
108
+ [final_value + 0.5 * (base_value - final_value) * (1 + math.cos(math.pi * i / (len(iters)))) for i in iters])
109
+
110
+ schedule = np.concatenate((warmup_schedule, schedule))
111
+
112
+ assert len(schedule) == epochs * niter_per_ep
113
+ return schedule
114
+
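+ # --- Editor's worked example (hypothetical values, not part of the original commit) ---
+ # cosine_scheduler(1e-4, 1e-6, epochs=2, niter_per_ep=5, warmup_epochs=1) returns a
+ # per-iteration array of length 10: the first 5 values ramp linearly from 0 to 1e-4,
+ # and the remaining 5 follow a half-cosine decay from 1e-4 toward 1e-6;
+ # len(schedule) == epochs * niter_per_ep == 10.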
115
+
116
+ def constant_scheduler(base_value, epochs, niter_per_ep, warmup_epochs=0,
117
+ start_warmup_value=1e-6, warmup_steps=-1):
118
+ warmup_schedule = np.array([])
119
+ warmup_iters = warmup_epochs * niter_per_ep
120
+ if warmup_steps > 0:
121
+ warmup_iters = warmup_steps
122
+ print("Set warmup steps = %d" % warmup_iters)
123
+ if warmup_iters > 0:
124
+ warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters)
125
+
126
+ iters = epochs * niter_per_ep - warmup_iters
127
+ schedule = np.array([base_value] * iters)
128
+
129
+ schedule = np.concatenate((warmup_schedule, schedule))
130
+
131
+ assert len(schedule) == epochs * niter_per_ep
132
+ return schedule
133
+
134
+
135
+ def get_parameter_groups(model, weight_decay=1e-5, base_lr=1e-4, skip_list=(), get_num_layer=None, get_layer_scale=None, **kwargs):
136
+ parameter_group_names = {}
137
+ parameter_group_vars = {}
138
+
139
+ for name, param in model.named_parameters():
140
+ if not param.requires_grad:
141
+ continue # frozen weights
142
+ if len(kwargs.get('filter_name', [])) > 0:
143
+ flag = False
144
+ for filter_n in kwargs.get('filter_name', []):
145
+ if filter_n in name:
146
+ print(f"filter {name} because of the pattern {filter_n}")
147
+ flag = True
148
+ if flag:
149
+ continue
150
+
151
+ default_scale=1.
152
+
153
+ if param.ndim <= 1 or name.endswith(".bias") or name in skip_list: # param.ndim <= 1 len(param.shape) == 1
154
+ group_name = "no_decay"
155
+ this_weight_decay = 0.
156
+ else:
157
+ group_name = "decay"
158
+ this_weight_decay = weight_decay
159
+
160
+ if get_num_layer is not None:
161
+ layer_id = get_num_layer(name)
162
+ group_name = "layer_%d_%s" % (layer_id, group_name)
163
+ else:
164
+ layer_id = None
165
+
166
+ if group_name not in parameter_group_names:
167
+ if get_layer_scale is not None:
168
+ scale = get_layer_scale(layer_id)
169
+ else:
170
+ scale = default_scale
171
+
172
+ parameter_group_names[group_name] = {
173
+ "weight_decay": this_weight_decay,
174
+ "params": [],
175
+ "lr": base_lr,
176
+ "lr_scale": scale,
177
+ }
178
+
179
+ parameter_group_vars[group_name] = {
180
+ "weight_decay": this_weight_decay,
181
+ "params": [],
182
+ "lr": base_lr,
183
+ "lr_scale": scale,
184
+ }
185
+
186
+ parameter_group_vars[group_name]["params"].append(param)
187
+ parameter_group_names[group_name]["params"].append(name)
188
+
189
+ print("Param groups = %s" % json.dumps(parameter_group_names, indent=2))
190
+ return list(parameter_group_vars.values())
191
+
192
+
193
+ def create_optimizer(args, model, get_num_layer=None, get_layer_scale=None, filter_bias_and_bn=True, skip_list=None, **kwargs):
194
+ opt_lower = args.opt.lower()
195
+ weight_decay = args.weight_decay
196
+
197
+ skip = {}
198
+ if skip_list is not None:
199
+ skip = skip_list
200
+ elif hasattr(model, 'no_weight_decay'):
201
+ skip = model.no_weight_decay()
202
+ print(f"Skip weight decay name marked in model: {skip}")
203
+ parameters = get_parameter_groups(model, weight_decay, args.lr, skip, get_num_layer, get_layer_scale, **kwargs)
204
+ weight_decay = 0.
205
+
206
+ if 'fused' in opt_lower:
207
+ assert torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers'  # note: apex availability is not checked in this file
208
+
209
+ opt_args = dict(lr=args.lr, weight_decay=weight_decay)
210
+ if hasattr(args, 'opt_eps') and args.opt_eps is not None:
211
+ opt_args['eps'] = args.opt_eps
212
+ if hasattr(args, 'opt_beta1') and args.opt_beta1 is not None:
213
+ opt_args['betas'] = (args.opt_beta1, args.opt_beta2)
214
+
215
+ print('Optimizer config:', opt_args)
216
+ opt_split = opt_lower.split('_')
217
+ opt_lower = opt_split[-1]
218
+ if opt_lower == 'sgd' or opt_lower == 'nesterov':
219
+ opt_args.pop('eps', None)
220
+ optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=True, **opt_args)
221
+ elif opt_lower == 'momentum':
222
+ opt_args.pop('eps', None)
223
+ optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=False, **opt_args)
224
+ elif opt_lower == 'adam':
225
+ optimizer = optim.Adam(parameters, **opt_args)
226
+ elif opt_lower == 'adamw':
227
+ optimizer = optim.AdamW(parameters, **opt_args)
228
+ elif opt_lower == 'adadelta':
229
+ optimizer = optim.Adadelta(parameters, **opt_args)
230
+ elif opt_lower == 'rmsprop':
231
+ optimizer = optim.RMSprop(parameters, alpha=0.9, momentum=args.momentum, **opt_args)
232
+ else:
233
+ assert False, "Invalid optimizer"
234
+ raise ValueError
235
+
236
+ return optimizer
237
+
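+ # --- Editor's usage sketch (hypothetical args, not part of the original commit) ---
+ # get_parameter_groups() places biases and other 1-D parameters into a no-weight-decay
+ # group before the optimizer is built, e.g.:
+ #   args = argparse.Namespace(opt='adamw', lr=1e-4, weight_decay=0.05,
+ #                             opt_eps=1e-8, opt_beta1=0.9, opt_beta2=0.95)
+ #   optimizer = create_optimizer(args, model)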
238
+
239
+ class SmoothedValue(object):
240
+ """Track a series of values and provide access to smoothed values over a
241
+ window or the global series average.
242
+ """
243
+
244
+ def __init__(self, window_size=20, fmt=None):
245
+ if fmt is None:
246
+ fmt = "{median:.4f} ({global_avg:.4f})"
247
+ self.deque = deque(maxlen=window_size)
248
+ self.total = 0.0
249
+ self.count = 0
250
+ self.fmt = fmt
251
+
252
+ def update(self, value, n=1):
253
+ self.deque.append(value)
254
+ self.count += n
255
+ self.total += value * n
256
+
257
+ def synchronize_between_processes(self):
258
+ """
259
+ Warning: does not synchronize the deque!
260
+ """
261
+ if not is_dist_avail_and_initialized():
262
+ return
263
+ t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
264
+ dist.barrier()
265
+ dist.all_reduce(t)
266
+ t = t.tolist()
267
+ self.count = int(t[0])
268
+ self.total = t[1]
269
+
270
+ @property
271
+ def median(self):
272
+ d = torch.tensor(list(self.deque))
273
+ return d.median().item()
274
+
275
+ @property
276
+ def avg(self):
277
+ d = torch.tensor(list(self.deque), dtype=torch.float32)
278
+ return d.mean().item()
279
+
280
+ @property
281
+ def global_avg(self):
282
+ return self.total / self.count
283
+
284
+ @property
285
+ def max(self):
286
+ return max(self.deque)
287
+
288
+ @property
289
+ def value(self):
290
+ return self.deque[-1]
291
+
292
+ def __str__(self):
293
+ return self.fmt.format(
294
+ median=self.median,
295
+ avg=self.avg,
296
+ global_avg=self.global_avg,
297
+ max=self.max,
298
+ value=self.value)
299
+
300
+
301
+ class MetricLogger(object):
302
+ def __init__(self, delimiter="\t"):
303
+ self.meters = defaultdict(SmoothedValue)
304
+ self.delimiter = delimiter
305
+
306
+ def update(self, **kwargs):
307
+ for k, v in kwargs.items():
308
+ if v is None:
309
+ continue
310
+ if isinstance(v, torch.Tensor):
311
+ v = v.item()
312
+ assert isinstance(v, (float, int))
313
+ self.meters[k].update(v)
314
+
315
+ def __getattr__(self, attr):
316
+ if attr in self.meters:
317
+ return self.meters[attr]
318
+ if attr in self.__dict__:
319
+ return self.__dict__[attr]
320
+ raise AttributeError("'{}' object has no attribute '{}'".format(
321
+ type(self).__name__, attr))
322
+
323
+ def __str__(self):
324
+ loss_str = []
325
+ for name, meter in self.meters.items():
326
+ loss_str.append(
327
+ "{}: {}".format(name, str(meter))
328
+ )
329
+ return self.delimiter.join(loss_str)
330
+
331
+ def synchronize_between_processes(self):
332
+ for meter in self.meters.values():
333
+ meter.synchronize_between_processes()
334
+
335
+ def add_meter(self, name, meter):
336
+ self.meters[name] = meter
337
+
338
+ def log_every(self, iterable, print_freq, header=None):
339
+ i = 0
340
+ if not header:
341
+ header = ''
342
+ start_time = time.time()
343
+ end = time.time()
344
+ iter_time = SmoothedValue(fmt='{avg:.4f}')
345
+ data_time = SmoothedValue(fmt='{avg:.4f}')
346
+ space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
347
+ log_msg = [
348
+ header,
349
+ '[{0' + space_fmt + '}/{1}]',
350
+ 'eta: {eta}',
351
+ '{meters}',
352
+ 'time: {time}',
353
+ 'data: {data}'
354
+ ]
355
+ if torch.cuda.is_available():
356
+ log_msg.append('max mem: {memory:.0f}')
357
+ log_msg = self.delimiter.join(log_msg)
358
+ MB = 1024.0 * 1024.0
359
+ for obj in iterable:
360
+ data_time.update(time.time() - end)
361
+ yield obj
362
+ iter_time.update(time.time() - end)
363
+ if i % print_freq == 0 or i == len(iterable) - 1:
364
+ eta_seconds = iter_time.global_avg * (len(iterable) - i)
365
+ eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
366
+ if torch.cuda.is_available():
367
+ print(log_msg.format(
368
+ i, len(iterable), eta=eta_string,
369
+ meters=str(self),
370
+ time=str(iter_time), data=str(data_time),
371
+ memory=torch.cuda.max_memory_allocated() / MB))
372
+ else:
373
+ print(log_msg.format(
374
+ i, len(iterable), eta=eta_string,
375
+ meters=str(self),
376
+ time=str(iter_time), data=str(data_time)))
377
+ i += 1
378
+ end = time.time()
379
+ total_time = time.time() - start_time
380
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
381
+ print('{} Total time: {} ({:.4f} s / it)'.format(
382
+ header, total_time_str, total_time / len(iterable)))
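+
+
+ # --- Editor's illustrative sketch (not part of the original commit) ---
+ # Single-process smoke test of MetricLogger/SmoothedValue; the loss values are
+ # hypothetical and no distributed synchronization is involved.
+ if __name__ == "__main__":
+     logger = MetricLogger(delimiter="  ")
+     for step in logger.log_every(range(10), print_freq=5, header='demo'):
+         logger.update(loss=1.0 / (step + 1))
+     print("final:", str(logger))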
utils.py ADDED
@@ -0,0 +1,457 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import PIL.Image
4
+ import numpy as np
5
+ from torch import nn
6
+ import torch.distributed as dist
7
+ import timm.models.hub as timm_hub
8
+
9
+ """Modified from https://github.com/CompVis/taming-transformers.git"""
10
+
11
+ import hashlib
12
+ import requests
13
+ from tqdm import tqdm
14
+ try:
15
+ import piq
16
+ except ImportError:  # piq is optional
17
+ pass
18
+
19
+ _CONTEXT_PARALLEL_GROUP = None
20
+ _CONTEXT_PARALLEL_SIZE = None
21
+
22
+
23
+ def is_dist_avail_and_initialized():
24
+ if not dist.is_available():
25
+ return False
26
+ if not dist.is_initialized():
27
+ return False
28
+ return True
29
+
30
+
31
+ def get_world_size():
32
+ if not is_dist_avail_and_initialized():
33
+ return 1
34
+ return dist.get_world_size()
35
+
36
+
37
+ def get_rank():
38
+ if not is_dist_avail_and_initialized():
39
+ return 0
40
+ return dist.get_rank()
41
+
42
+
43
+ def is_main_process():
44
+ return get_rank() == 0
45
+
46
+
47
+ def is_context_parallel_initialized():
48
+ if _CONTEXT_PARALLEL_GROUP is None:
49
+ return False
50
+ else:
51
+ return True
52
+
53
+
54
+ def set_context_parallel_group(size, group):
55
+ global _CONTEXT_PARALLEL_GROUP
56
+ global _CONTEXT_PARALLEL_SIZE
57
+ _CONTEXT_PARALLEL_GROUP = group
58
+ _CONTEXT_PARALLEL_SIZE = size
59
+
60
+
61
+ def initialize_context_parallel(context_parallel_size):
62
+ global _CONTEXT_PARALLEL_GROUP
63
+ global _CONTEXT_PARALLEL_SIZE
64
+
65
+ assert _CONTEXT_PARALLEL_GROUP is None, "context parallel group is already initialized"
66
+ _CONTEXT_PARALLEL_SIZE = context_parallel_size
67
+
68
+ rank = torch.distributed.get_rank()
69
+ world_size = torch.distributed.get_world_size()
70
+
71
+ for i in range(0, world_size, context_parallel_size):
72
+ ranks = range(i, i + context_parallel_size)
73
+ group = torch.distributed.new_group(ranks)
74
+ if rank in ranks:
75
+ _CONTEXT_PARALLEL_GROUP = group
76
+ break
77
+
78
+
79
+ def get_context_parallel_group():
80
+ assert _CONTEXT_PARALLEL_GROUP is not None, "context parallel group is not initialized"
81
+
82
+ return _CONTEXT_PARALLEL_GROUP
83
+
84
+
85
+ def get_context_parallel_world_size():
86
+ assert _CONTEXT_PARALLEL_SIZE is not None, "context parallel size is not initialized"
87
+
88
+ return _CONTEXT_PARALLEL_SIZE
89
+
90
+
91
+ def get_context_parallel_rank():
92
+ assert _CONTEXT_PARALLEL_SIZE is not None, "context parallel size is not initialized"
93
+
94
+ rank = get_rank()
95
+ cp_rank = rank % _CONTEXT_PARALLEL_SIZE
96
+ return cp_rank
97
+
98
+
99
+ def get_context_parallel_group_rank():
100
+ assert _CONTEXT_PARALLEL_SIZE is not None, "context parallel size is not initialized"
101
+
102
+ rank = get_rank()
103
+ cp_group_rank = rank // _CONTEXT_PARALLEL_SIZE
104
+
105
+ return cp_group_rank
106
+
107
+
108
+ def download_cached_file(url, check_hash=True, progress=False):
109
+ """
110
+ Download a file from a URL and cache it locally. If the file already exists, it is not downloaded again.
111
+ If distributed, only the main process downloads the file, and the other processes wait for the file to be downloaded.
112
+ """
113
+
114
+ def get_cached_file_path():
115
+ # a hack to sync the file path across processes
116
+ parts = torch.hub.urlparse(url)
117
+ filename = os.path.basename(parts.path)
118
+ cached_file = os.path.join(timm_hub.get_cache_dir(), filename)
119
+
120
+ return cached_file
121
+
122
+ if is_main_process():
123
+ timm_hub.download_cached_file(url, check_hash, progress)
124
+
125
+ if is_dist_avail_and_initialized():
126
+ dist.barrier()
127
+
128
+ return get_cached_file_path()
129
+
130
+
131
+ def convert_weights_to_fp16(model: nn.Module):
132
+ """Convert applicable model parameters to fp16"""
133
+
134
+ def _convert_weights_to_fp16(l):
135
+ if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.Linear)):
136
+ l.weight.data = l.weight.data.to(torch.float16)
137
+ if l.bias is not None:
138
+ l.bias.data = l.bias.data.to(torch.float16)
139
+
140
+ model.apply(_convert_weights_to_fp16)
141
+
142
+
143
+ def convert_weights_to_bf16(model: nn.Module):
144
+ """Convert applicable model parameters to fp16"""
145
+
146
+ def _convert_weights_to_bf16(l):
147
+ if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.Linear)):
148
+ l.weight.data = l.weight.data.to(torch.bfloat16)
149
+ if l.bias is not None:
150
+ l.bias.data = l.bias.data.to(torch.bfloat16)
151
+
152
+ model.apply(_convert_weights_to_bf16)
153
+
154
+
155
+ def save_result(result, result_dir, filename, remove_duplicate="", save_format='json'):
156
+ import json
157
+ import jsonlines
158
+ print("Dump result")
159
+
160
+ # Make the temp dir for saving results
161
+ if not os.path.exists(result_dir):
162
+ if is_main_process():
163
+ os.makedirs(result_dir)
164
+ if is_dist_avail_and_initialized():
165
+ torch.distributed.barrier()
166
+
167
+ result_file = os.path.join(
168
+ result_dir, "%s_rank%d.json" % (filename, get_rank())
169
+ )
170
+
171
+ final_result_file = os.path.join(result_dir, f"{filename}.{save_format}")
172
+
173
+ json.dump(result, open(result_file, "w"))
174
+
175
+ if is_dist_avail_and_initialized():
176
+ torch.distributed.barrier()
177
+
178
+ if is_main_process():
179
+ # print("rank %d starts merging results." % get_rank())
180
+ # combine results from all processes
181
+ result = []
182
+
183
+ for rank in range(get_world_size()):
184
+ result_file = os.path.join(result_dir, "%s_rank%d.json" % (filename, rank))
185
+ res = json.load(open(result_file, "r"))
186
+ result += res
187
+
188
+ # print("Remove duplicate")
189
+ if remove_duplicate:
190
+ result_new = []
191
+ id_set = set()
192
+ for res in result:
193
+ if res[remove_duplicate] not in id_set:
194
+ id_set.add(res[remove_duplicate])
195
+ result_new.append(res)
196
+ result = result_new
197
+
198
+ if save_format == 'json':
199
+ json.dump(result, open(final_result_file, "w"))
200
+ else:
201
+ assert save_format == 'jsonl', "Only json and jsonl formats are supported"
202
+ with jsonlines.open(final_result_file, "w") as writer:
203
+ writer.write_all(result)
204
+
205
+ # print("result file saved to %s" % final_result_file)
206
+
207
+ return final_result_file
208
+
209
+
210
+ # resizing utils
211
+ # TODO: clean up later
212
+ def _resize_with_antialiasing(input, size, interpolation="bicubic", align_corners=True):
213
+ h, w = input.shape[-2:]
214
+ factors = (h / size[0], w / size[1])
215
+
216
+ # First, we have to determine sigma
217
+ # Taken from skimage: https://github.com/scikit-image/scikit-image/blob/v0.19.2/skimage/transform/_warps.py#L171
218
+ sigmas = (
219
+ max((factors[0] - 1.0) / 2.0, 0.001),
220
+ max((factors[1] - 1.0) / 2.0, 0.001),
221
+ )
222
+
223
+ # Now kernel size. Good results are for 3 sigma, but that is kind of slow. Pillow uses 1 sigma
224
+ # https://github.com/python-pillow/Pillow/blob/master/src/libImaging/Resample.c#L206
225
+ # But they do it in two passes, which gives better results. Let's try 2 sigmas for now.
226
+ ks = int(max(2.0 * 2 * sigmas[0], 3)), int(max(2.0 * 2 * sigmas[1], 3))
227
+
228
+ # Make sure it is odd
229
+ if (ks[0] % 2) == 0:
230
+ ks = ks[0] + 1, ks[1]
231
+
232
+ if (ks[1] % 2) == 0:
233
+ ks = ks[0], ks[1] + 1
234
+
235
+ input = _gaussian_blur2d(input, ks, sigmas)
236
+
237
+ output = torch.nn.functional.interpolate(input, size=size, mode=interpolation, align_corners=align_corners)
238
+ return output
239
+
240
+
241
+ def _compute_padding(kernel_size):
242
+ """Compute padding tuple."""
243
+ # 4 or 6 ints: (padding_left, padding_right, padding_top, padding_bottom)
244
+ # https://pytorch.org/docs/stable/nn.html#torch.nn.functional.pad
245
+ if len(kernel_size) < 2:
246
+ raise AssertionError(kernel_size)
247
+ computed = [k - 1 for k in kernel_size]
248
+
249
+ # for even kernels we need to do asymmetric padding :(
250
+ out_padding = 2 * len(kernel_size) * [0]
251
+
252
+ for i in range(len(kernel_size)):
253
+ computed_tmp = computed[-(i + 1)]
254
+
255
+ pad_front = computed_tmp // 2
256
+ pad_rear = computed_tmp - pad_front
257
+
258
+ out_padding[2 * i + 0] = pad_front
259
+ out_padding[2 * i + 1] = pad_rear
260
+
261
+ return out_padding
262
+
263
+
264
+ def _filter2d(input, kernel):
265
+ # prepare kernel
266
+ b, c, h, w = input.shape
267
+ tmp_kernel = kernel[:, None, ...].to(device=input.device, dtype=input.dtype)
268
+
269
+ tmp_kernel = tmp_kernel.expand(-1, c, -1, -1)
270
+
271
+ height, width = tmp_kernel.shape[-2:]
272
+
273
+ padding_shape: list[int] = _compute_padding([height, width])
274
+ input = torch.nn.functional.pad(input, padding_shape, mode="reflect")
275
+
276
+ # kernel and input tensor reshape to align element-wise or batch-wise params
277
+ tmp_kernel = tmp_kernel.reshape(-1, 1, height, width)
278
+ input = input.view(-1, tmp_kernel.size(0), input.size(-2), input.size(-1))
279
+
280
+ # convolve the tensor with the kernel.
281
+ output = torch.nn.functional.conv2d(input, tmp_kernel, groups=tmp_kernel.size(0), padding=0, stride=1)
282
+
283
+ out = output.view(b, c, h, w)
284
+ return out
285
+
286
+
287
+ def _gaussian(window_size: int, sigma):
288
+ if isinstance(sigma, float):
289
+ sigma = torch.tensor([[sigma]])
290
+
291
+ batch_size = sigma.shape[0]
292
+
293
+ x = (torch.arange(window_size, device=sigma.device, dtype=sigma.dtype) - window_size // 2).expand(batch_size, -1)
294
+
295
+ if window_size % 2 == 0:
296
+ x = x + 0.5
297
+
298
+ gauss = torch.exp(-x.pow(2.0) / (2 * sigma.pow(2.0)))
299
+
300
+ return gauss / gauss.sum(-1, keepdim=True)
301
+
302
+
303
+ def _gaussian_blur2d(input, kernel_size, sigma):
304
+ if isinstance(sigma, tuple):
305
+ sigma = torch.tensor([sigma], dtype=input.dtype)
306
+ else:
307
+ sigma = sigma.to(dtype=input.dtype)
308
+
309
+ ky, kx = int(kernel_size[0]), int(kernel_size[1])
310
+ bs = sigma.shape[0]
311
+ kernel_x = _gaussian(kx, sigma[:, 1].view(bs, 1))
312
+ kernel_y = _gaussian(ky, sigma[:, 0].view(bs, 1))
313
+ out_x = _filter2d(input, kernel_x[..., None, :])
314
+ out = _filter2d(out_x, kernel_y[..., None])
315
+
316
+ return out
317
+
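+ # --- Editor's usage sketch (hypothetical shapes, not part of the original commit) ---
+ # Antialiased downsampling of a batch of images: a Gaussian blur with sigma derived
+ # from the scale factor, followed by bicubic interpolation, e.g.
+ #   x = torch.randn(1, 3, 512, 512)
+ #   y = _resize_with_antialiasing(x, (256, 256))   # -> shape (1, 3, 256, 256)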
318
+
319
+ URL_MAP = {
320
+ "vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"
321
+ }
322
+
323
+ CKPT_MAP = {
324
+ "vgg_lpips": "vgg.pth"
325
+ }
326
+
327
+ MD5_MAP = {
328
+ "vgg_lpips": "d507d7349b931f0638a25a48a722f98a"
329
+ }
330
+
331
+
332
+ def download(url, local_path, chunk_size=1024):
333
+ os.makedirs(os.path.split(local_path)[0], exist_ok=True)
334
+ with requests.get(url, stream=True) as r:
335
+ total_size = int(r.headers.get("content-length", 0))
336
+ with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
337
+ with open(local_path, "wb") as f:
338
+ for data in r.iter_content(chunk_size=chunk_size):
339
+ if data:
340
+ f.write(data)
341
+ pbar.update(chunk_size)
342
+
343
+
344
+ def md5_hash(path):
345
+ with open(path, "rb") as f:
346
+ content = f.read()
347
+ return hashlib.md5(content).hexdigest()
348
+
349
+
350
+ def get_ckpt_path(name, root, check=False):
351
+ assert name in URL_MAP
352
+ path = os.path.join(root, CKPT_MAP[name])
353
+ print(md5_hash(path))
354
+ if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]):
355
+ print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path))
356
+ download(URL_MAP[name], path)
357
+ md5 = md5_hash(path)
358
+ assert md5 == MD5_MAP[name], md5
359
+ return path
360
+
361
+
362
+ class KeyNotFoundError(Exception):
363
+ def __init__(self, cause, keys=None, visited=None):
364
+ self.cause = cause
365
+ self.keys = keys
366
+ self.visited = visited
367
+ messages = list()
368
+ if keys is not None:
369
+ messages.append("Key not found: {}".format(keys))
370
+ if visited is not None:
371
+ messages.append("Visited: {}".format(visited))
372
+ messages.append("Cause:\n{}".format(cause))
373
+ message = "\n".join(messages)
374
+ super().__init__(message)
375
+
376
+
377
+ def retrieve(
378
+ list_or_dict, key, splitval="/", default=None, expand=True, pass_success=False
379
+ ):
380
+ """Given a nested list or dict return the desired value at key expanding
381
+ callable nodes if necessary and :attr:`expand` is ``True``. The expansion
382
+ is done in-place.
383
+
384
+ Parameters
385
+ ----------
386
+ list_or_dict : list or dict
387
+ Possibly nested list or dictionary.
388
+ key : str
389
+ key/to/value, path like string describing all keys necessary to
390
+ consider to get to the desired value. List indices can also be
391
+ passed here.
392
+ splitval : str
393
+ String that defines the delimiter between keys of the
394
+ different depth levels in `key`.
395
+ default : obj
396
+ Value returned if :attr:`key` is not found.
397
+ expand : bool
398
+ Whether to expand callable nodes on the path or not.
399
+
400
+ Returns
401
+ -------
402
+ The desired value or if :attr:`default` is not ``None`` and the
403
+ :attr:`key` is not found returns ``default``.
404
+
405
+ Raises
406
+ ------
407
+ Exception if ``key`` not in ``list_or_dict`` and :attr:`default` is
408
+ ``None``.
409
+ """
410
+
411
+ keys = key.split(splitval)
412
+
413
+ success = True
414
+ try:
415
+ visited = []
416
+ parent = None
417
+ last_key = None
418
+ for key in keys:
419
+ if callable(list_or_dict):
420
+ if not expand:
421
+ raise KeyNotFoundError(
422
+ ValueError(
423
+ "Trying to get past callable node with expand=False."
424
+ ),
425
+ keys=keys,
426
+ visited=visited,
427
+ )
428
+ list_or_dict = list_or_dict()
429
+ parent[last_key] = list_or_dict
430
+
431
+ last_key = key
432
+ parent = list_or_dict
433
+
434
+ try:
435
+ if isinstance(list_or_dict, dict):
436
+ list_or_dict = list_or_dict[key]
437
+ else:
438
+ list_or_dict = list_or_dict[int(key)]
439
+ except (KeyError, IndexError, ValueError) as e:
440
+ raise KeyNotFoundError(e, keys=keys, visited=visited)
441
+
442
+ visited += [key]
443
+ # final expansion of retrieved value
444
+ if expand and callable(list_or_dict):
445
+ list_or_dict = list_or_dict()
446
+ parent[last_key] = list_or_dict
447
+ except KeyNotFoundError as e:
448
+ if default is None:
449
+ raise e
450
+ else:
451
+ list_or_dict = default
452
+ success = False
453
+
454
+ if not pass_success:
455
+ return list_or_dict
456
+ else:
457
+ return list_or_dict, success
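+
+
+ # --- Editor's illustrative sketch (hypothetical config, not part of the original commit) ---
+ if __name__ == "__main__":
+     cfg = {"model": {"params": {"lr": 1e-4}}}
+     print(retrieve(cfg, "model/params/lr"))                 # -> 0.0001
+     print(retrieve(cfg, "model/params/beta", default=0.9))  # -> 0.9 (key missing, default used)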
video_generation_demo.ipynb ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "import os\n",
10
+ "import json\n",
11
+ "import torch\n",
12
+ "import numpy as np\n",
13
+ "import PIL\n",
14
+ "from PIL import Image\n",
15
+ "from IPython.display import HTML\n",
16
+ "from pyramid_dit import PyramidDiTForVideoGeneration\n",
17
+ "from IPython.display import Image as ipython_image\n",
18
+ "from diffusers.utils import load_image, export_to_video, export_to_gif"
19
+ ]
20
+ },
21
+ {
22
+ "cell_type": "code",
23
+ "execution_count": null,
24
+ "metadata": {},
25
+ "outputs": [],
26
+ "source": [
27
+ "variant='diffusion_transformer_768p' # For high resolution\n",
28
+ "# variant='diffusion_transformer_384p' # For low resolution\n",
29
+ "\n",
30
+ "model_path = \"/home/jinyang06/models/pyramid-flow\" # The downloaded checkpoint dir\n",
31
+ "model_dtype = 'bf16'\n",
32
+ "\n",
33
+ "device_id = 0\n",
34
+ "torch.cuda.set_device(device_id)\n",
35
+ "\n",
36
+ "model = PyramidDiTForVideoGeneration(\n",
37
+ " model_path,\n",
38
+ " model_dtype,\n",
39
+ " model_variant=variant,\n",
40
+ ")\n",
41
+ "\n",
42
+ "model.vae.to(\"cuda\")\n",
43
+ "model.dit.to(\"cuda\")\n",
44
+ "model.text_encoder.to(\"cuda\")\n",
45
+ "\n",
46
+ "if model_dtype == \"bf16\":\n",
47
+ " torch_dtype = torch.bfloat16 \n",
48
+ "elif model_dtype == \"fp16\":\n",
49
+ " torch_dtype = torch.float16\n",
50
+ "else:\n",
51
+ " torch_dtype = torch.float32\n",
52
+ "\n",
53
+ "\n",
54
+ "def show_video(ori_path, rec_path, width=\"100%\"):\n",
55
+ " html = ''\n",
56
+ " if ori_path is not None:\n",
57
+ " html += f\"\"\"<video controls=\"\" name=\"media\" data-fullscreen-container=\"true\" width=\"{width}\">\n",
58
+ " <source src=\"{ori_path}\" type=\"video/mp4\">\n",
59
+ " </video>\n",
60
+ " \"\"\"\n",
61
+ " \n",
62
+ " html += f\"\"\"<video controls=\"\" name=\"media\" data-fullscreen-container=\"true\" width=\"{width}\">\n",
63
+ " <source src=\"{rec_path}\" type=\"video/mp4\">\n",
64
+ " </video>\n",
65
+ " \"\"\"\n",
66
+ " return HTML(html)"
67
+ ]
68
+ },
69
+ {
70
+ "attachments": {},
71
+ "cell_type": "markdown",
72
+ "metadata": {},
73
+ "source": [
74
+ "#### Text-to-Video"
75
+ ]
76
+ },
77
+ {
78
+ "cell_type": "code",
79
+ "execution_count": null,
80
+ "metadata": {},
81
+ "outputs": [],
82
+ "source": [
83
+ "prompt = \"A movie trailer featuring the adventures of the 30 year old space man wearing a red wool knitted motorcycle helmet, blue sky, salt desert, cinematic style, shot on 35mm film, vivid colors\"\n",
84
+ "\n",
85
+ "# used for 384p model variant\n",
86
+ "# width = 640\n",
87
+ "# height = 384\n",
88
+ "\n",
89
+ "# used for 768p model variant\n",
90
+ "width = 1280\n",
91
+ "height = 768\n",
92
+ "\n",
93
+ "temp = 16 # temp in [1, 31] <=> frame in [1, 241] <=> duration in [0, 10s]\n",
94
+ "\n",
95
+ "model.vae.enable_tiling()\n",
96
+ "\n",
97
+ "with torch.no_grad(), torch.cuda.amp.autocast(enabled=True if model_dtype != 'fp32' else False, dtype=torch_dtype):\n",
98
+ " frames = model.generate(\n",
99
+ " prompt=prompt,\n",
100
+ " num_inference_steps=[20, 20, 20],\n",
101
+ " video_num_inference_steps=[10, 10, 10],\n",
102
+ " height=height,\n",
103
+ " width=width,\n",
104
+ " temp=temp,\n",
105
+ " guidance_scale=9.0, # The guidance for the first frame\n",
106
+ " video_guidance_scale=5.0, # The guidance for the other video latent\n",
107
+ " output_type=\"pil\",\n",
108
+ " save_memory=True, # If you have enough GPU memory, set it to `False` to improve vae decoding speed\n",
109
+ " )\n",
110
+ "\n",
111
+ "export_to_video(frames, \"./text_to_video_sample.mp4\", fps=24)\n",
112
+ "show_video(None, \"./text_to_video_sample.mp4\", \"70%\")"
113
+ ]
114
+ },
115
+ {
116
+ "attachments": {},
117
+ "cell_type": "markdown",
118
+ "metadata": {},
119
+ "source": [
120
+ "#### Image-to-Video"
121
+ ]
122
+ },
123
+ {
124
+ "cell_type": "code",
125
+ "execution_count": null,
126
+ "metadata": {},
127
+ "outputs": [],
128
+ "source": [
129
+ "image_path = 'assets/the_great_wall.jpg'\n",
130
+ "image = Image.open(image_path).convert(\"RGB\")\n",
131
+ "\n",
132
+ "width = 1280\n",
133
+ "height = 768\n",
134
+ "temp = 16\n",
135
+ "\n",
136
+ "image = image.resize((width, height))\n",
137
+ "\n",
138
+ "display(image)\n",
139
+ "\n",
140
+ "prompt = \"FPV flying over the Great Wall\"\n",
141
+ "\n",
142
+ "with torch.no_grad(), torch.cuda.amp.autocast(enabled=True if model_dtype != 'fp32' else False, dtype=torch_dtype):\n",
143
+ " frames = model.generate_i2v(\n",
144
+ " prompt=prompt,\n",
145
+ " input_image=image,\n",
146
+ " num_inference_steps=[10, 10, 10],\n",
147
+ " temp=temp,\n",
148
+ " guidance_scale=7.0,\n",
149
+ " video_guidance_scale=4.0,\n",
150
+ " output_type=\"pil\",\n",
151
+ " save_memory=True, # If you have enough GPU memory, set it to `False` to improve vae decoding speed\n",
152
+ " )\n",
153
+ "\n",
154
+ "export_to_video(frames, \"./image_to_video_sample.mp4\", fps=24)\n",
155
+ "show_video(None, \"./image_to_video_sample.mp4\", \"70%\")"
156
+ ]
157
+ }
158
+ ],
159
+ "metadata": {
160
+ "kernelspec": {
161
+ "display_name": "Python 3",
162
+ "language": "python",
163
+ "name": "python3"
164
+ },
165
+ "language_info": {
166
+ "codemirror_mode": {
167
+ "name": "ipython",
168
+ "version": 3
169
+ },
170
+ "file_extension": ".py",
171
+ "mimetype": "text/x-python",
172
+ "name": "python",
173
+ "nbconvert_exporter": "python",
174
+ "pygments_lexer": "ipython3",
175
+ "version": "3.8.10"
176
+ },
177
+ "orig_nbformat": 4
178
+ },
179
+ "nbformat": 4,
180
+ "nbformat_minor": 2
181
+ }
video_vae/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .modeling_loss import LPIPSWithDiscriminator
2
+ from .modeling_causal_vae import CausalVideoVAE
video_vae/context_parallel_ops.py ADDED
@@ -0,0 +1,172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from CogVideoX
2
+ import torch
3
+ import torch.nn as nn
4
+ import math
5
+
6
+ from utils import (
7
+ get_context_parallel_group,
8
+ get_context_parallel_rank,
9
+ get_context_parallel_world_size,
10
+ get_context_parallel_group_rank,
11
+ )
12
+
13
+
14
+ def _conv_split(input_, dim=2, kernel_size=1):
15
+ cp_world_size = get_context_parallel_world_size()
16
+
17
+ # Bypass the function if context parallel is 1
18
+ if cp_world_size == 1:
19
+ return input_
20
+
21
+ # print('in _conv_split, cp_rank:', cp_rank, 'input_size:', input_.shape)
22
+
23
+ cp_rank = get_context_parallel_rank()
24
+
25
+ dim_size = (input_.size()[dim] - kernel_size) // cp_world_size
26
+
27
+ if cp_rank == 0:
28
+ output = input_.transpose(dim, 0)[: dim_size + kernel_size].transpose(dim, 0)
29
+ else:
30
+ # output = input_.transpose(dim, 0)[cp_rank * dim_size + 1:(cp_rank + 1) * dim_size + kernel_size].transpose(dim, 0)
31
+ output = input_.transpose(dim, 0)[
32
+ cp_rank * dim_size + kernel_size : (cp_rank + 1) * dim_size + kernel_size
33
+ ].transpose(dim, 0)
34
+ output = output.contiguous()
35
+
36
+ # print('out _conv_split, cp_rank:', cp_rank, 'input_size:', output.shape)
37
+
38
+ return output
39
+
40
+
41
+ def _conv_gather(input_, dim=2, kernel_size=1):
42
+ cp_world_size = get_context_parallel_world_size()
43
+
44
+ # Bypass the function if context parallel is 1
45
+ if cp_world_size == 1:
46
+ return input_
47
+
48
+ group = get_context_parallel_group()
49
+ cp_rank = get_context_parallel_rank()
50
+
51
+ # print('in _conv_gather, cp_rank:', cp_rank, 'input_size:', input_.shape)
52
+
53
+ input_first_kernel_ = input_.transpose(0, dim)[:kernel_size].transpose(0, dim).contiguous()
54
+ if cp_rank == 0:
55
+ input_ = input_.transpose(0, dim)[kernel_size:].transpose(0, dim).contiguous()
56
+ else:
57
+ input_ = input_.transpose(0, dim)[max(kernel_size - 1, 0) :].transpose(0, dim).contiguous()
58
+
59
+ tensor_list = [torch.empty_like(torch.cat([input_first_kernel_, input_], dim=dim))] + [
60
+ torch.empty_like(input_) for _ in range(cp_world_size - 1)
61
+ ]
62
+ if cp_rank == 0:
63
+ input_ = torch.cat([input_first_kernel_, input_], dim=dim)
64
+
65
+ tensor_list[cp_rank] = input_
66
+ torch.distributed.all_gather(tensor_list, input_, group=group)
67
+
68
+ # Note: torch.cat already creates a contiguous tensor.
69
+ output = torch.cat(tensor_list, dim=dim).contiguous()
70
+
71
+ # print('out _conv_gather, cp_rank:', cp_rank, 'input_size:', output.shape)
72
+
73
+ return output
74
+
75
+
76
+ def _cp_pass_from_previous_rank(input_, dim, kernel_size):
77
+ # Bypass the function if kernel size is 1
78
+ if kernel_size == 1:
79
+ return input_
80
+
81
+ group = get_context_parallel_group()
82
+ cp_rank = get_context_parallel_rank()
83
+ cp_group_rank = get_context_parallel_group_rank()
84
+ cp_world_size = get_context_parallel_world_size()
85
+
86
+ # print('in _pass_from_previous_rank, cp_rank:', cp_rank, 'input_size:', input_.shape)
87
+
88
+ global_rank = torch.distributed.get_rank()
89
+ global_world_size = torch.distributed.get_world_size()
90
+
91
+ input_ = input_.transpose(0, dim)
92
+
93
+     # pass each rank's trailing frames to the next rank (i.e. receive from the previous rank)
94
+ send_rank = global_rank + 1
95
+ recv_rank = global_rank - 1
96
+ if send_rank % cp_world_size == 0:
97
+ send_rank -= cp_world_size
98
+ if recv_rank % cp_world_size == cp_world_size - 1:
99
+ recv_rank += cp_world_size
100
+
101
+ recv_buffer = torch.empty_like(input_[-kernel_size + 1 :]).contiguous()
102
+ if cp_rank < cp_world_size - 1:
103
+ req_send = torch.distributed.isend(input_[-kernel_size + 1 :].contiguous(), send_rank, group=group)
104
+ if cp_rank > 0:
105
+ req_recv = torch.distributed.irecv(recv_buffer, recv_rank, group=group)
106
+
107
+ if cp_rank == 0:
108
+ input_ = torch.cat([torch.zeros_like(input_[:1])] * (kernel_size - 1) + [input_], dim=0)
109
+ else:
110
+ req_recv.wait()
111
+ input_ = torch.cat([recv_buffer, input_], dim=0)
112
+
113
+ input_ = input_.transpose(0, dim).contiguous()
114
+ return input_
115
+
116
+
117
+ def _drop_from_previous_rank(input_, dim, kernel_size):
118
+ input_ = input_.transpose(0, dim)[kernel_size - 1 :].transpose(0, dim)
119
+ return input_
120
+
121
+
122
+ class _ConvolutionScatterToContextParallelRegion(torch.autograd.Function):
123
+ @staticmethod
124
+ def forward(ctx, input_, dim, kernel_size):
125
+ ctx.dim = dim
126
+ ctx.kernel_size = kernel_size
127
+ return _conv_split(input_, dim, kernel_size)
128
+
129
+ @staticmethod
130
+ def backward(ctx, grad_output):
131
+ return _conv_gather(grad_output, ctx.dim, ctx.kernel_size), None, None
132
+
133
+
134
+ class _ConvolutionGatherFromContextParallelRegion(torch.autograd.Function):
135
+ @staticmethod
136
+ def forward(ctx, input_, dim, kernel_size):
137
+ ctx.dim = dim
138
+ ctx.kernel_size = kernel_size
139
+ return _conv_gather(input_, dim, kernel_size)
140
+
141
+ @staticmethod
142
+ def backward(ctx, grad_output):
143
+ return _conv_split(grad_output, ctx.dim, ctx.kernel_size), None, None
144
+
145
+
146
+ class _CPConvolutionPassFromPreviousRank(torch.autograd.Function):
147
+ @staticmethod
148
+ def forward(ctx, input_, dim, kernel_size):
149
+ ctx.dim = dim
150
+ ctx.kernel_size = kernel_size
151
+ return _cp_pass_from_previous_rank(input_, dim, kernel_size)
152
+
153
+ @staticmethod
154
+ def backward(ctx, grad_output):
155
+ return _drop_from_previous_rank(grad_output, ctx.dim, ctx.kernel_size), None, None
156
+
157
+
158
+ def conv_scatter_to_context_parallel_region(input_, dim, kernel_size):
159
+ return _ConvolutionScatterToContextParallelRegion.apply(input_, dim, kernel_size)
160
+
161
+
162
+ def conv_gather_from_context_parallel_region(input_, dim, kernel_size):
163
+ return _ConvolutionGatherFromContextParallelRegion.apply(input_, dim, kernel_size)
164
+
165
+
166
+ def cp_pass_from_previous_rank(input_, dim, kernel_size):
167
+ return _CPConvolutionPassFromPreviousRank.apply(input_, dim, kernel_size)
168
+
169
+
170
+
171
+
172
+
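These helpers are what the causal convolutions added below rely on when a context-parallel group is active: the scatter op shards the frame axis across ranks (keeping a kernel-sized halo), the gather op reassembles the full clip, and `cp_pass_from_previous_rank` moves the boundary frames a causal temporal kernel needs from the neighbouring rank. A rough usage sketch follows, assuming `torch.distributed` and the context-parallel group exposed by the repo's `utils` module have already been initialized (the initialization code is not part of this diff); `shard_process_merge` and `temporal_module` are hypothetical names.

# Sketch only: requires an initialized context-parallel process group.
import torch
from video_vae.context_parallel_ops import (
    conv_scatter_to_context_parallel_region,
    conv_gather_from_context_parallel_region,
)

def shard_process_merge(video: torch.Tensor, temporal_module, kernel_size: int = 3) -> torch.Tensor:
    # video: (batch, channels, frames, height, width)
    # Each rank keeps its slice of the frame axis plus the kernel_size - 1 halo
    # frames a temporal convolution needs; rank 0 also keeps the causal head.
    shard = conv_scatter_to_context_parallel_region(video, 2, kernel_size)
    shard = temporal_module(shard)  # any module that preserves the halo convention
    # All-gather the per-rank results back into a full-length clip (autograd-aware).
    return conv_gather_from_context_parallel_region(shard, 2, kernel_size)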
video_vae/modeling_block.py ADDED
@@ -0,0 +1,760 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import Any, Dict, Optional, Tuple, Union
15
+
16
+ import numpy as np
17
+ import torch
18
+ import torch.nn.functional as F
19
+ from torch import nn
20
+ from einops import rearrange
21
+
22
+ from diffusers.utils import logging
23
+ from diffusers.models.attention_processor import Attention
24
+ from .modeling_resnet import (
25
+ Downsample2D, ResnetBlock2D, CausalResnetBlock3D, Upsample2D,
26
+ TemporalDownsample2x, TemporalUpsample2x,
27
+ CausalDownsample2x, CausalTemporalDownsample2x,
28
+ CausalUpsample2x, CausalTemporalUpsample2x,
29
+ )
30
+
31
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
32
+
33
+
34
+ def get_input_layer(
35
+ in_channels: int,
36
+ out_channels: int,
37
+ norm_num_groups: int,
38
+ layer_type: str,
39
+ norm_type: str = 'group',
40
+ affine: bool = True,
41
+ ):
42
+ if layer_type == 'conv':
43
+ input_layer = nn.Conv3d(
44
+ in_channels,
45
+ out_channels,
46
+ kernel_size=3,
47
+ stride=1,
48
+ padding=1,
49
+ )
50
+
51
+ elif layer_type == 'pixel_shuffle':
52
+ input_layer = nn.Sequential(
53
+ nn.PixelUnshuffle(2),
54
+ nn.Conv2d(in_channels * 4, out_channels, kernel_size=1),
55
+ )
56
+ else:
57
+         raise NotImplementedError(f"Unsupported input layer type: {layer_type}")
58
+
59
+ return input_layer
60
+
61
+
62
+ def get_output_layer(
63
+ in_channels: int,
64
+ out_channels: int,
65
+ norm_num_groups: int,
66
+ layer_type: str,
67
+ norm_type: str = 'group',
68
+ affine: bool = True,
69
+ ):
70
+ if layer_type == 'norm_act_conv':
71
+ output_layer = nn.Sequential(
72
+ nn.GroupNorm(num_channels=in_channels, num_groups=norm_num_groups, eps=1e-6, affine=affine),
73
+ nn.SiLU(),
74
+ nn.Conv3d(in_channels, out_channels, 3, stride=1, padding=1),
75
+ )
76
+
77
+ elif layer_type == 'pixel_shuffle':
78
+ output_layer = nn.Sequential(
79
+ nn.Conv2d(in_channels, out_channels * 4, kernel_size=1),
80
+ nn.PixelShuffle(2),
81
+ )
82
+
83
+ else:
84
+         raise NotImplementedError(f"Unsupported output layer type: {layer_type}")
85
+
86
+ return output_layer
87
+
88
+
89
+ def get_down_block(
90
+ down_block_type: str,
91
+ num_layers: int,
92
+ in_channels: int,
93
+ out_channels: int = None,
94
+ temb_channels: int = None,
95
+ add_spatial_downsample: bool = None,
96
+ add_temporal_downsample: bool = None,
97
+ resnet_eps: float = 1e-6,
98
+ resnet_act_fn: str = 'silu',
99
+ resnet_groups: Optional[int] = None,
100
+ downsample_padding: Optional[int] = None,
101
+ resnet_time_scale_shift: str = "default",
102
+ attention_head_dim: Optional[int] = None,
103
+ dropout: float = 0.0,
104
+ norm_affline: bool = True,
105
+ norm_layer: str = 'layer',
106
+ ):
107
+
108
+ if down_block_type == "DownEncoderBlock2D":
109
+ return DownEncoderBlock2D(
110
+ num_layers=num_layers,
111
+ in_channels=in_channels,
112
+ out_channels=out_channels,
113
+ dropout=dropout,
114
+ add_spatial_downsample=add_spatial_downsample,
115
+ add_temporal_downsample=add_temporal_downsample,
116
+ resnet_eps=resnet_eps,
117
+ resnet_act_fn=resnet_act_fn,
118
+ resnet_groups=resnet_groups,
119
+ downsample_padding=downsample_padding,
120
+ resnet_time_scale_shift=resnet_time_scale_shift,
121
+ )
122
+
123
+ elif down_block_type == "DownEncoderBlockCausal3D":
124
+ return DownEncoderBlockCausal3D(
125
+ num_layers=num_layers,
126
+ in_channels=in_channels,
127
+ out_channels=out_channels,
128
+ dropout=dropout,
129
+ add_spatial_downsample=add_spatial_downsample,
130
+ add_temporal_downsample=add_temporal_downsample,
131
+ resnet_eps=resnet_eps,
132
+ resnet_act_fn=resnet_act_fn,
133
+ resnet_groups=resnet_groups,
134
+ downsample_padding=downsample_padding,
135
+ resnet_time_scale_shift=resnet_time_scale_shift,
136
+ )
137
+
138
+ raise ValueError(f"{down_block_type} does not exist.")
139
+
140
+
141
+ def get_up_block(
142
+ up_block_type: str,
143
+ num_layers: int,
144
+ in_channels: int,
145
+ out_channels: int,
146
+ prev_output_channel: int = None,
147
+ temb_channels: int = None,
148
+ add_spatial_upsample: bool = None,
149
+ add_temporal_upsample: bool = None,
150
+ resnet_eps: float = 1e-6,
151
+ resnet_act_fn: str = 'silu',
152
+ resolution_idx: Optional[int] = None,
153
+ resnet_groups: Optional[int] = None,
154
+ resnet_time_scale_shift: str = "default",
155
+ attention_head_dim: Optional[int] = None,
156
+ dropout: float = 0.0,
157
+ interpolate: bool = True,
158
+ norm_affline: bool = True,
159
+ norm_layer: str = 'layer',
160
+ ) -> nn.Module:
161
+
162
+ if up_block_type == "UpDecoderBlock2D":
163
+ return UpDecoderBlock2D(
164
+ num_layers=num_layers,
165
+ in_channels=in_channels,
166
+ out_channels=out_channels,
167
+ resolution_idx=resolution_idx,
168
+ dropout=dropout,
169
+ add_spatial_upsample=add_spatial_upsample,
170
+ add_temporal_upsample=add_temporal_upsample,
171
+ resnet_eps=resnet_eps,
172
+ resnet_act_fn=resnet_act_fn,
173
+ resnet_groups=resnet_groups,
174
+ resnet_time_scale_shift=resnet_time_scale_shift,
175
+ temb_channels=temb_channels,
176
+ interpolate=interpolate,
177
+ )
178
+
179
+ elif up_block_type == "UpDecoderBlockCausal3D":
180
+ return UpDecoderBlockCausal3D(
181
+ num_layers=num_layers,
182
+ in_channels=in_channels,
183
+ out_channels=out_channels,
184
+ resolution_idx=resolution_idx,
185
+ dropout=dropout,
186
+ add_spatial_upsample=add_spatial_upsample,
187
+ add_temporal_upsample=add_temporal_upsample,
188
+ resnet_eps=resnet_eps,
189
+ resnet_act_fn=resnet_act_fn,
190
+ resnet_groups=resnet_groups,
191
+ resnet_time_scale_shift=resnet_time_scale_shift,
192
+ temb_channels=temb_channels,
193
+ interpolate=interpolate,
194
+ )
195
+
196
+ raise ValueError(f"{up_block_type} does not exist.")
197
+
198
+
199
+
200
+ class UNetMidBlock2D(nn.Module):
201
+ """
202
+ A 2D UNet mid-block [`UNetMidBlock2D`] with multiple residual blocks and optional attention blocks.
203
+
204
+ Args:
205
+ in_channels (`int`): The number of input channels.
206
+ temb_channels (`int`): The number of temporal embedding channels.
207
+ dropout (`float`, *optional*, defaults to 0.0): The dropout rate.
208
+ num_layers (`int`, *optional*, defaults to 1): The number of residual blocks.
209
+         resnet_eps (`float`, *optional*, defaults to 1e-6): The epsilon value for the resnet blocks.
210
+ resnet_time_scale_shift (`str`, *optional*, defaults to `default`):
211
+ The type of normalization to apply to the time embeddings. This can help to improve the performance of the
212
+ model on tasks with long-range temporal dependencies.
213
+ resnet_act_fn (`str`, *optional*, defaults to `swish`): The activation function for the resnet blocks.
214
+ resnet_groups (`int`, *optional*, defaults to 32):
215
+ The number of groups to use in the group normalization layers of the resnet blocks.
216
+ attn_groups (`Optional[int]`, *optional*, defaults to None): The number of groups for the attention blocks.
217
+ resnet_pre_norm (`bool`, *optional*, defaults to `True`):
218
+ Whether to use pre-normalization for the resnet blocks.
219
+ add_attention (`bool`, *optional*, defaults to `True`): Whether to add attention blocks.
220
+ attention_head_dim (`int`, *optional*, defaults to 1):
221
+ Dimension of a single attention head. The number of attention heads is determined based on this value and
222
+ the number of input channels.
223
+ output_scale_factor (`float`, *optional*, defaults to 1.0): The output scale factor.
224
+
225
+ Returns:
226
+ `torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size,
227
+ in_channels, height, width)`.
228
+
229
+ """
230
+
231
+ def __init__(
232
+ self,
233
+ in_channels: int,
234
+ temb_channels: int,
235
+ dropout: float = 0.0,
236
+ num_layers: int = 1,
237
+ resnet_eps: float = 1e-6,
238
+ resnet_time_scale_shift: str = "default", # default, spatial
239
+ resnet_act_fn: str = "swish",
240
+ resnet_groups: int = 32,
241
+ attn_groups: Optional[int] = None,
242
+ resnet_pre_norm: bool = True,
243
+ add_attention: bool = True,
244
+ attention_head_dim: int = 1,
245
+ output_scale_factor: float = 1.0,
246
+ ):
247
+ super().__init__()
248
+ resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
249
+ self.add_attention = add_attention
250
+
251
+ if attn_groups is None:
252
+ attn_groups = resnet_groups if resnet_time_scale_shift == "default" else None
253
+
254
+ # there is always at least one resnet
255
+ resnets = [
256
+ ResnetBlock2D(
257
+ in_channels=in_channels,
258
+ out_channels=in_channels,
259
+ temb_channels=temb_channels,
260
+ eps=resnet_eps,
261
+ groups=resnet_groups,
262
+ dropout=dropout,
263
+ time_embedding_norm=resnet_time_scale_shift,
264
+ non_linearity=resnet_act_fn,
265
+ output_scale_factor=output_scale_factor,
266
+ pre_norm=resnet_pre_norm,
267
+ )
268
+ ]
269
+ attentions = []
270
+
271
+ if attention_head_dim is None:
272
+             logger.warning(
273
+                 f"It is not recommended to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {in_channels}."
274
+ )
275
+ attention_head_dim = in_channels
276
+
277
+ for _ in range(num_layers):
278
+ if self.add_attention:
279
+ # Spatial attention
280
+ attentions.append(
281
+ Attention(
282
+ in_channels,
283
+ heads=in_channels // attention_head_dim,
284
+ dim_head=attention_head_dim,
285
+ rescale_output_factor=output_scale_factor,
286
+ eps=resnet_eps,
287
+ norm_num_groups=attn_groups,
288
+ spatial_norm_dim=temb_channels if resnet_time_scale_shift == "spatial" else None,
289
+ residual_connection=True,
290
+ bias=True,
291
+ upcast_softmax=True,
292
+ _from_deprecated_attn_block=True,
293
+ )
294
+ )
295
+ else:
296
+ attentions.append(None)
297
+
298
+ resnets.append(
299
+ ResnetBlock2D(
300
+ in_channels=in_channels,
301
+ out_channels=in_channels,
302
+ temb_channels=temb_channels,
303
+ eps=resnet_eps,
304
+ groups=resnet_groups,
305
+ dropout=dropout,
306
+ time_embedding_norm=resnet_time_scale_shift,
307
+ non_linearity=resnet_act_fn,
308
+ output_scale_factor=output_scale_factor,
309
+ pre_norm=resnet_pre_norm,
310
+ )
311
+ )
312
+
313
+ self.attentions = nn.ModuleList(attentions)
314
+ self.resnets = nn.ModuleList(resnets)
315
+
316
+ def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
317
+ hidden_states = self.resnets[0](hidden_states, temb)
318
+ t = hidden_states.shape[2]
319
+
320
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
321
+ if attn is not None:
322
+ hidden_states = rearrange(hidden_states, 'b c t h w -> b t c h w')
323
+ hidden_states = rearrange(hidden_states, 'b t c h w -> (b t) c h w')
324
+ hidden_states = attn(hidden_states, temb=temb)
325
+ hidden_states = rearrange(hidden_states, '(b t) c h w -> b t c h w', t=t)
326
+ hidden_states = rearrange(hidden_states, 'b t c h w -> b c t h w')
327
+
328
+ hidden_states = resnet(hidden_states, temb)
329
+
330
+ return hidden_states
331
+
332
+
333
+ class CausalUNetMidBlock2D(nn.Module):
334
+ """
335
+ A 2D UNet mid-block [`UNetMidBlock2D`] with multiple residual blocks and optional attention blocks.
336
+
337
+ Args:
338
+ in_channels (`int`): The number of input channels.
339
+ temb_channels (`int`): The number of temporal embedding channels.
340
+ dropout (`float`, *optional*, defaults to 0.0): The dropout rate.
341
+ num_layers (`int`, *optional*, defaults to 1): The number of residual blocks.
342
+         resnet_eps (`float`, *optional*, defaults to 1e-6): The epsilon value for the resnet blocks.
343
+ resnet_time_scale_shift (`str`, *optional*, defaults to `default`):
344
+ The type of normalization to apply to the time embeddings. This can help to improve the performance of the
345
+ model on tasks with long-range temporal dependencies.
346
+ resnet_act_fn (`str`, *optional*, defaults to `swish`): The activation function for the resnet blocks.
347
+ resnet_groups (`int`, *optional*, defaults to 32):
348
+ The number of groups to use in the group normalization layers of the resnet blocks.
349
+ attn_groups (`Optional[int]`, *optional*, defaults to None): The number of groups for the attention blocks.
350
+ resnet_pre_norm (`bool`, *optional*, defaults to `True`):
351
+ Whether to use pre-normalization for the resnet blocks.
352
+ add_attention (`bool`, *optional*, defaults to `True`): Whether to add attention blocks.
353
+ attention_head_dim (`int`, *optional*, defaults to 1):
354
+ Dimension of a single attention head. The number of attention heads is determined based on this value and
355
+ the number of input channels.
356
+ output_scale_factor (`float`, *optional*, defaults to 1.0): The output scale factor.
357
+
358
+ Returns:
359
+ `torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size,
360
+ in_channels, height, width)`.
361
+
362
+ """
363
+
364
+ def __init__(
365
+ self,
366
+ in_channels: int,
367
+ temb_channels: int,
368
+ dropout: float = 0.0,
369
+ num_layers: int = 1,
370
+ resnet_eps: float = 1e-6,
371
+ resnet_time_scale_shift: str = "default", # default, spatial
372
+ resnet_act_fn: str = "swish",
373
+ resnet_groups: int = 32,
374
+ attn_groups: Optional[int] = None,
375
+ resnet_pre_norm: bool = True,
376
+ add_attention: bool = True,
377
+ attention_head_dim: int = 1,
378
+ output_scale_factor: float = 1.0,
379
+ ):
380
+ super().__init__()
381
+ resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
382
+ self.add_attention = add_attention
383
+
384
+ if attn_groups is None:
385
+ attn_groups = resnet_groups if resnet_time_scale_shift == "default" else None
386
+
387
+ # there is always at least one resnet
388
+ resnets = [
389
+ CausalResnetBlock3D(
390
+ in_channels=in_channels,
391
+ out_channels=in_channels,
392
+ temb_channels=temb_channels,
393
+ eps=resnet_eps,
394
+ groups=resnet_groups,
395
+ dropout=dropout,
396
+ time_embedding_norm=resnet_time_scale_shift,
397
+ non_linearity=resnet_act_fn,
398
+ output_scale_factor=output_scale_factor,
399
+ pre_norm=resnet_pre_norm,
400
+ )
401
+ ]
402
+ attentions = []
403
+
404
+ if attention_head_dim is None:
405
+             logger.warning(
406
+                 f"It is not recommended to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {in_channels}."
407
+ )
408
+ attention_head_dim = in_channels
409
+
410
+ for _ in range(num_layers):
411
+ if self.add_attention:
412
+ # Spatial attention
413
+ attentions.append(
414
+ Attention(
415
+ in_channels,
416
+ heads=in_channels // attention_head_dim,
417
+ dim_head=attention_head_dim,
418
+ rescale_output_factor=output_scale_factor,
419
+ eps=resnet_eps,
420
+ norm_num_groups=attn_groups,
421
+ spatial_norm_dim=temb_channels if resnet_time_scale_shift == "spatial" else None,
422
+ residual_connection=True,
423
+ bias=True,
424
+ upcast_softmax=True,
425
+ _from_deprecated_attn_block=True,
426
+ )
427
+ )
428
+ else:
429
+ attentions.append(None)
430
+
431
+ resnets.append(
432
+ CausalResnetBlock3D(
433
+ in_channels=in_channels,
434
+ out_channels=in_channels,
435
+ temb_channels=temb_channels,
436
+ eps=resnet_eps,
437
+ groups=resnet_groups,
438
+ dropout=dropout,
439
+ time_embedding_norm=resnet_time_scale_shift,
440
+ non_linearity=resnet_act_fn,
441
+ output_scale_factor=output_scale_factor,
442
+ pre_norm=resnet_pre_norm,
443
+ )
444
+ )
445
+
446
+ self.attentions = nn.ModuleList(attentions)
447
+ self.resnets = nn.ModuleList(resnets)
448
+
449
+ def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None,
450
+ is_init_image=True, temporal_chunk=False) -> torch.FloatTensor:
451
+ hidden_states = self.resnets[0](hidden_states, temb, is_init_image=is_init_image, temporal_chunk=temporal_chunk)
452
+ t = hidden_states.shape[2]
453
+
454
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
455
+ if attn is not None:
456
+ hidden_states = rearrange(hidden_states, 'b c t h w -> b t c h w')
457
+ hidden_states = rearrange(hidden_states, 'b t c h w -> (b t) c h w')
458
+ hidden_states = attn(hidden_states, temb=temb)
459
+ hidden_states = rearrange(hidden_states, '(b t) c h w -> b t c h w', t=t)
460
+ hidden_states = rearrange(hidden_states, 'b t c h w -> b c t h w')
461
+
462
+ hidden_states = resnet(hidden_states, temb, is_init_image=is_init_image, temporal_chunk=temporal_chunk)
463
+
464
+ return hidden_states
465
+
466
+
467
+ class DownEncoderBlockCausal3D(nn.Module):
468
+ def __init__(
469
+ self,
470
+ in_channels: int,
471
+ out_channels: int,
472
+ dropout: float = 0.0,
473
+ num_layers: int = 1,
474
+ resnet_eps: float = 1e-6,
475
+ resnet_time_scale_shift: str = "default",
476
+ resnet_act_fn: str = "swish",
477
+ resnet_groups: int = 32,
478
+ resnet_pre_norm: bool = True,
479
+ output_scale_factor: float = 1.0,
480
+ add_spatial_downsample: bool = True,
481
+ add_temporal_downsample: bool = False,
482
+ downsample_padding: int = 1,
483
+ ):
484
+ super().__init__()
485
+ resnets = []
486
+
487
+ for i in range(num_layers):
488
+ in_channels = in_channels if i == 0 else out_channels
489
+ resnets.append(
490
+ CausalResnetBlock3D(
491
+ in_channels=in_channels,
492
+ out_channels=out_channels,
493
+ temb_channels=None,
494
+ eps=resnet_eps,
495
+ groups=resnet_groups,
496
+ dropout=dropout,
497
+ time_embedding_norm=resnet_time_scale_shift,
498
+ non_linearity=resnet_act_fn,
499
+ output_scale_factor=output_scale_factor,
500
+ pre_norm=resnet_pre_norm,
501
+ )
502
+ )
503
+
504
+ self.resnets = nn.ModuleList(resnets)
505
+
506
+ if add_spatial_downsample:
507
+ self.downsamplers = nn.ModuleList(
508
+ [
509
+ CausalDownsample2x(
510
+ out_channels, use_conv=True, out_channels=out_channels,
511
+ )
512
+ ]
513
+ )
514
+ else:
515
+ self.downsamplers = None
516
+
517
+ if add_temporal_downsample:
518
+ self.temporal_downsamplers = nn.ModuleList(
519
+ [
520
+ CausalTemporalDownsample2x(
521
+ out_channels, use_conv=True, out_channels=out_channels,
522
+ )
523
+ ]
524
+ )
525
+ else:
526
+ self.temporal_downsamplers = None
527
+
528
+ def forward(self, hidden_states: torch.FloatTensor, is_init_image=True, temporal_chunk=False) -> torch.FloatTensor:
529
+ for resnet in self.resnets:
530
+ hidden_states = resnet(hidden_states, temb=None, is_init_image=is_init_image, temporal_chunk=temporal_chunk)
531
+
532
+ if self.downsamplers is not None:
533
+ for downsampler in self.downsamplers:
534
+ hidden_states = downsampler(hidden_states, is_init_image=is_init_image, temporal_chunk=temporal_chunk)
535
+
536
+ if self.temporal_downsamplers is not None:
537
+ for temporal_downsampler in self.temporal_downsamplers:
538
+ hidden_states = temporal_downsampler(hidden_states, is_init_image=is_init_image, temporal_chunk=temporal_chunk)
539
+
540
+ return hidden_states
541
+
542
+
543
+ class DownEncoderBlock2D(nn.Module):
544
+ def __init__(
545
+ self,
546
+ in_channels: int,
547
+ out_channels: int,
548
+ dropout: float = 0.0,
549
+ num_layers: int = 1,
550
+ resnet_eps: float = 1e-6,
551
+ resnet_time_scale_shift: str = "default",
552
+ resnet_act_fn: str = "swish",
553
+ resnet_groups: int = 32,
554
+ resnet_pre_norm: bool = True,
555
+ output_scale_factor: float = 1.0,
556
+ add_spatial_downsample: bool = True,
557
+ add_temporal_downsample: bool = False,
558
+ downsample_padding: int = 1,
559
+ ):
560
+ super().__init__()
561
+ resnets = []
562
+
563
+ for i in range(num_layers):
564
+ in_channels = in_channels if i == 0 else out_channels
565
+ resnets.append(
566
+ ResnetBlock2D(
567
+ in_channels=in_channels,
568
+ out_channels=out_channels,
569
+ temb_channels=None,
570
+ eps=resnet_eps,
571
+ groups=resnet_groups,
572
+ dropout=dropout,
573
+ time_embedding_norm=resnet_time_scale_shift,
574
+ non_linearity=resnet_act_fn,
575
+ output_scale_factor=output_scale_factor,
576
+ pre_norm=resnet_pre_norm,
577
+ )
578
+ )
579
+
580
+ self.resnets = nn.ModuleList(resnets)
581
+
582
+ if add_spatial_downsample:
583
+ self.downsamplers = nn.ModuleList(
584
+ [
585
+ Downsample2D(
586
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
587
+ )
588
+ ]
589
+ )
590
+ else:
591
+ self.downsamplers = None
592
+
593
+ if add_temporal_downsample:
594
+ self.temporal_downsamplers = nn.ModuleList(
595
+ [
596
+ TemporalDownsample2x(
597
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding,
598
+ )
599
+ ]
600
+ )
601
+ else:
602
+ self.temporal_downsamplers = None
603
+
604
+ def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
605
+ for resnet in self.resnets:
606
+ hidden_states = resnet(hidden_states, temb=None)
607
+
608
+ if self.downsamplers is not None:
609
+ for downsampler in self.downsamplers:
610
+ hidden_states = downsampler(hidden_states)
611
+
612
+ if self.temporal_downsamplers is not None:
613
+ for temporal_downsampler in self.temporal_downsamplers:
614
+ hidden_states = temporal_downsampler(hidden_states)
615
+
616
+ return hidden_states
617
+
618
+
619
+ class UpDecoderBlock2D(nn.Module):
620
+ def __init__(
621
+ self,
622
+ in_channels: int,
623
+ out_channels: int,
624
+ resolution_idx: Optional[int] = None,
625
+ dropout: float = 0.0,
626
+ num_layers: int = 1,
627
+ resnet_eps: float = 1e-6,
628
+ resnet_time_scale_shift: str = "default", # default, spatial
629
+ resnet_act_fn: str = "swish",
630
+ resnet_groups: int = 32,
631
+ resnet_pre_norm: bool = True,
632
+ output_scale_factor: float = 1.0,
633
+ add_spatial_upsample: bool = True,
634
+ add_temporal_upsample: bool = False,
635
+ temb_channels: Optional[int] = None,
636
+ interpolate: bool = True,
637
+ ):
638
+ super().__init__()
639
+ resnets = []
640
+
641
+ for i in range(num_layers):
642
+ input_channels = in_channels if i == 0 else out_channels
643
+
644
+ resnets.append(
645
+ ResnetBlock2D(
646
+ in_channels=input_channels,
647
+ out_channels=out_channels,
648
+ temb_channels=temb_channels,
649
+ eps=resnet_eps,
650
+ groups=resnet_groups,
651
+ dropout=dropout,
652
+ time_embedding_norm=resnet_time_scale_shift,
653
+ non_linearity=resnet_act_fn,
654
+ output_scale_factor=output_scale_factor,
655
+ pre_norm=resnet_pre_norm,
656
+ )
657
+ )
658
+
659
+ self.resnets = nn.ModuleList(resnets)
660
+
661
+ if add_spatial_upsample:
662
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels, interpolate=interpolate)])
663
+ else:
664
+ self.upsamplers = None
665
+
666
+ if add_temporal_upsample:
667
+ self.temporal_upsamplers = nn.ModuleList([TemporalUpsample2x(out_channels, use_conv=True, out_channels=out_channels, interpolate=interpolate)])
668
+ else:
669
+ self.temporal_upsamplers = None
670
+
671
+ self.resolution_idx = resolution_idx
672
+
673
+ def forward(
674
+ self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0, is_image: bool = False,
675
+ ) -> torch.FloatTensor:
676
+ for resnet in self.resnets:
677
+ hidden_states = resnet(hidden_states, temb=temb, scale=scale)
678
+
679
+ if self.upsamplers is not None:
680
+ for upsampler in self.upsamplers:
681
+ hidden_states = upsampler(hidden_states)
682
+
683
+ if self.temporal_upsamplers is not None:
684
+ for temporal_upsampler in self.temporal_upsamplers:
685
+ hidden_states = temporal_upsampler(hidden_states, is_image=is_image)
686
+
687
+ return hidden_states
688
+
689
+
690
+ class UpDecoderBlockCausal3D(nn.Module):
691
+ def __init__(
692
+ self,
693
+ in_channels: int,
694
+ out_channels: int,
695
+ resolution_idx: Optional[int] = None,
696
+ dropout: float = 0.0,
697
+ num_layers: int = 1,
698
+ resnet_eps: float = 1e-6,
699
+ resnet_time_scale_shift: str = "default", # default, spatial
700
+ resnet_act_fn: str = "swish",
701
+ resnet_groups: int = 32,
702
+ resnet_pre_norm: bool = True,
703
+ output_scale_factor: float = 1.0,
704
+ add_spatial_upsample: bool = True,
705
+ add_temporal_upsample: bool = False,
706
+ temb_channels: Optional[int] = None,
707
+ interpolate: bool = True,
708
+ ):
709
+ super().__init__()
710
+ resnets = []
711
+
712
+ for i in range(num_layers):
713
+ input_channels = in_channels if i == 0 else out_channels
714
+
715
+ resnets.append(
716
+ CausalResnetBlock3D(
717
+ in_channels=input_channels,
718
+ out_channels=out_channels,
719
+ temb_channels=temb_channels,
720
+ eps=resnet_eps,
721
+ groups=resnet_groups,
722
+ dropout=dropout,
723
+ time_embedding_norm=resnet_time_scale_shift,
724
+ non_linearity=resnet_act_fn,
725
+ output_scale_factor=output_scale_factor,
726
+ pre_norm=resnet_pre_norm,
727
+ )
728
+ )
729
+
730
+ self.resnets = nn.ModuleList(resnets)
731
+
732
+ if add_spatial_upsample:
733
+ self.upsamplers = nn.ModuleList([CausalUpsample2x(out_channels, use_conv=True, out_channels=out_channels, interpolate=interpolate)])
734
+ else:
735
+ self.upsamplers = None
736
+
737
+ if add_temporal_upsample:
738
+ self.temporal_upsamplers = nn.ModuleList([CausalTemporalUpsample2x(out_channels, use_conv=True, out_channels=out_channels, interpolate=interpolate)])
739
+ else:
740
+ self.temporal_upsamplers = None
741
+
742
+ self.resolution_idx = resolution_idx
743
+
744
+ def forward(
745
+ self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None,
746
+ is_init_image=True, temporal_chunk=False,
747
+ ) -> torch.FloatTensor:
748
+ for resnet in self.resnets:
749
+ hidden_states = resnet(hidden_states, temb=temb, is_init_image=is_init_image, temporal_chunk=temporal_chunk)
750
+
751
+ if self.upsamplers is not None:
752
+ for upsampler in self.upsamplers:
753
+ hidden_states = upsampler(hidden_states, is_init_image=is_init_image, temporal_chunk=temporal_chunk)
754
+
755
+ if self.temporal_upsamplers is not None:
756
+ for temporal_upsampler in self.temporal_upsamplers:
757
+ hidden_states = temporal_upsampler(hidden_states, is_init_image=is_init_image, temporal_chunk=temporal_chunk)
758
+
759
+ return hidden_states
760
+
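`get_down_block` / `get_up_block` are the factories the encoder and decoder wrappers (added in `modeling_enc_dec.py`, not shown in this excerpt) use to assemble their stages. The sketch below builds and runs one causal encoder stage in isolation; the channel counts, frame count, and resolution are illustrative assumptions rather than the repo's defaults, and `modeling_resnet.py` must be importable for the causal resnet/downsample layers.

# Minimal sketch of one causal encoder stage (single process, no context parallelism).
import torch
from video_vae.modeling_block import get_down_block

block = get_down_block(
    "DownEncoderBlockCausal3D",
    num_layers=2,
    in_channels=128,
    out_channels=256,
    add_spatial_downsample=True,
    add_temporal_downsample=True,
    resnet_eps=1e-6,
    resnet_act_fn="silu",
    resnet_groups=32,
    downsample_padding=1,
    dropout=0.0,
)

x = torch.randn(1, 128, 9, 64, 64)  # (batch, channels, frames, height, width)
y = block(x, is_init_image=True, temporal_chunk=False)
# y should be downsampled 2x spatially and 2x temporally (keeping the leading causal frame).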
video_vae/modeling_causal_conv.py ADDED
@@ -0,0 +1,139 @@
1
+ from typing import Tuple, Union
2
+ import torch
3
+ import torch.nn as nn
4
+ from torch.utils.checkpoint import checkpoint
5
+ import torch.nn.functional as F
6
+ from collections import deque
7
+ from einops import rearrange
8
+ from timm.models.layers import trunc_normal_
9
+ from IPython import embed
10
+ from torch import Tensor
11
+
12
+ from utils import (
13
+ is_context_parallel_initialized,
14
+ get_context_parallel_group,
15
+ get_context_parallel_world_size,
16
+ get_context_parallel_rank,
17
+ get_context_parallel_group_rank,
18
+ )
19
+
20
+ from .context_parallel_ops import (
21
+ conv_scatter_to_context_parallel_region,
22
+ conv_gather_from_context_parallel_region,
23
+ cp_pass_from_previous_rank,
24
+ )
25
+
26
+
27
+ def divisible_by(num, den):
28
+ return (num % den) == 0
29
+
30
+ def cast_tuple(t, length = 1):
31
+ return t if isinstance(t, tuple) else ((t,) * length)
32
+
33
+ def is_odd(n):
34
+ return not divisible_by(n, 2)
35
+
36
+
37
+ class CausalGroupNorm(nn.GroupNorm):
38
+
39
+ def forward(self, x: Tensor) -> Tensor:
40
+ t = x.shape[2]
41
+ x = rearrange(x, 'b c t h w -> (b t) c h w')
42
+ x = super().forward(x)
43
+ x = rearrange(x, '(b t) c h w -> b c t h w', t=t)
44
+ return x
45
+
46
+
47
+ class CausalConv3d(nn.Module):
48
+
49
+ def __init__(
50
+ self,
51
+ in_channels,
52
+ out_channels,
53
+ kernel_size: Union[int, Tuple[int, int, int]],
54
+ stride: Union[int, Tuple[int, int, int]] = 1,
55
+ pad_mode: str ='constant',
56
+ **kwargs
57
+ ):
58
+ super().__init__()
59
+ if isinstance(kernel_size, int):
60
+ kernel_size = cast_tuple(kernel_size, 3)
61
+
62
+ time_kernel_size, height_kernel_size, width_kernel_size = kernel_size
63
+ self.time_kernel_size = time_kernel_size
64
+ assert is_odd(height_kernel_size) and is_odd(width_kernel_size)
65
+ dilation = kwargs.pop('dilation', 1)
66
+ self.pad_mode = pad_mode
67
+
68
+ if isinstance(stride, int):
69
+ stride = (stride, 1, 1)
70
+
71
+ time_pad = dilation * (time_kernel_size - 1)
72
+ height_pad = height_kernel_size // 2
73
+ width_pad = width_kernel_size // 2
74
+
75
+ self.temporal_stride = stride[0]
76
+ self.time_pad = time_pad
77
+ self.time_causal_padding = (width_pad, width_pad, height_pad, height_pad, time_pad, 0)
78
+ self.time_uncausal_padding = (width_pad, width_pad, height_pad, height_pad, 0, 0)
79
+
80
+ self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=0, dilation=dilation, **kwargs)
81
+ self.cache_front_feat = deque()
82
+
83
+ def _clear_context_parallel_cache(self):
84
+ del self.cache_front_feat
85
+ self.cache_front_feat = deque()
86
+
87
+ def _init_weights(self, m):
88
+ if isinstance(m, (nn.Linear, nn.Conv2d, nn.Conv3d)):
89
+ trunc_normal_(m.weight, std=.02)
90
+ if m.bias is not None:
91
+ nn.init.constant_(m.bias, 0)
92
+ elif isinstance(m, (nn.LayerNorm, nn.GroupNorm)):
93
+ nn.init.constant_(m.bias, 0)
94
+ nn.init.constant_(m.weight, 1.0)
95
+
96
+ def context_parallel_forward(self, x):
97
+ x = cp_pass_from_previous_rank(x, dim=2, kernel_size=self.time_kernel_size)
98
+
99
+ x = F.pad(x, self.time_uncausal_padding, mode='constant')
100
+
101
+ cp_rank = get_context_parallel_rank()
102
+ if cp_rank != 0:
103
+ if self.temporal_stride == 2 and self.time_kernel_size == 3:
104
+ x = x[:,:,1:]
105
+
106
+ x = self.conv(x)
107
+ return x
108
+
109
+ def forward(self, x, is_init_image=True, temporal_chunk=False):
110
+         # temporal_chunk: whether to process the video in temporal chunks, caching boundary frames between calls
111
+
112
+ if is_context_parallel_initialized():
113
+ return self.context_parallel_forward(x)
114
+
115
+ pad_mode = self.pad_mode if self.time_pad < x.shape[2] else 'constant'
116
+
117
+ if not temporal_chunk:
118
+ x = F.pad(x, self.time_causal_padding, mode=pad_mode)
119
+ else:
120
+ assert not self.training, "The feature cache should not be used in training"
121
+ if is_init_image:
122
+ # Encode the first chunk
123
+ x = F.pad(x, self.time_causal_padding, mode=pad_mode)
124
+ self._clear_context_parallel_cache()
125
+ self.cache_front_feat.append(x[:, :, -2:].clone().detach())
126
+ else:
127
+ x = F.pad(x, self.time_uncausal_padding, mode=pad_mode)
128
+ video_front_context = self.cache_front_feat.pop()
129
+ self._clear_context_parallel_cache()
130
+
131
+ if self.temporal_stride == 1 and self.time_kernel_size == 3:
132
+ x = torch.cat([video_front_context, x], dim=2)
133
+ elif self.temporal_stride == 2 and self.time_kernel_size == 3:
134
+ x = torch.cat([video_front_context[:,:,-1:], x], dim=2)
135
+
136
+ self.cache_front_feat.append(x[:, :, -2:].clone().detach())
137
+
138
+ x = self.conv(x)
139
+ return x
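The point of `CausalConv3d` is that the temporal axis is padded only on the left, so an output frame never depends on later input frames, and during chunked inference the left padding of a follow-up chunk is replaced by the last cached frames of the previous chunk. A small single-process sketch (context parallelism disabled; tensor sizes are illustrative assumptions):

# Minimal sketch of causal padding vs. temporal chunking.
import torch
from video_vae.modeling_causal_conv import CausalConv3d

conv = CausalConv3d(in_channels=8, out_channels=8, kernel_size=3, stride=1)
conv.eval()  # the chunked feature cache is inference-only

x = torch.randn(1, 8, 17, 32, 32)  # (batch, channels, frames, height, width)

with torch.no_grad():
    # Full clip: two zero frames are padded on the left of the time axis,
    # so the output keeps all 17 frames and frame t never sees frames > t.
    y_full = conv(x, is_init_image=True, temporal_chunk=False)

    # Chunked: the first call caches the last two frames of its padded input,
    # and the second call prepends them instead of zero padding.
    y0 = conv(x[:, :, :9], is_init_image=True, temporal_chunk=True)
    y1 = conv(x[:, :, 9:], is_init_image=False, temporal_chunk=True)
    y_chunked = torch.cat([y0, y1], dim=2)

# y_chunked should match y_full frame for frame.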
video_vae/modeling_causal_vae.py ADDED
@@ -0,0 +1,625 @@
1
+ from typing import Dict, Optional, Tuple, Union
2
+ import torch
3
+ import torch.nn as nn
4
+
5
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
6
+ from diffusers.models.attention_processor import (
7
+ ADDED_KV_ATTENTION_PROCESSORS,
8
+ CROSS_ATTENTION_PROCESSORS,
9
+ Attention,
10
+ AttentionProcessor,
11
+ AttnAddedKVProcessor,
12
+ AttnProcessor,
13
+ )
14
+
15
+ from diffusers.models.modeling_outputs import AutoencoderKLOutput
16
+ from diffusers.models.modeling_utils import ModelMixin
17
+
18
+ from timm.models.layers import drop_path, to_2tuple, trunc_normal_
19
+ from .modeling_enc_dec import (
20
+ DecoderOutput, DiagonalGaussianDistribution,
21
+ CausalVaeDecoder, CausalVaeEncoder,
22
+ )
23
+ from .modeling_causal_conv import CausalConv3d
24
+ from IPython import embed
25
+
26
+ from utils import (
27
+ is_context_parallel_initialized,
28
+ get_context_parallel_group,
29
+ get_context_parallel_world_size,
30
+ get_context_parallel_rank,
31
+ get_context_parallel_group_rank,
32
+ )
33
+
34
+ from .context_parallel_ops import (
35
+ conv_scatter_to_context_parallel_region,
36
+ conv_gather_from_context_parallel_region,
37
+ )
38
+
39
+
40
+ class CausalVideoVAE(ModelMixin, ConfigMixin):
41
+ r"""
42
+ A VAE model with KL loss for encoding images into latents and decoding latent representations into images.
43
+
44
+     This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
45
+ for all models (such as downloading or saving).
46
+
47
+ Parameters:
48
+ in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
49
+ out_channels (int, *optional*, defaults to 3): Number of channels in the output.
50
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
51
+ Tuple of downsample block types.
52
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
53
+ Tuple of upsample block types.
54
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
55
+ Tuple of block output channels.
56
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
57
+ latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space.
58
+ sample_size (`int`, *optional*, defaults to `32`): Sample input size.
59
+ scaling_factor (`float`, *optional*, defaults to 0.18215):
60
+ The component-wise standard deviation of the trained latent space computed using the first batch of the
61
+ training set. This is used to scale the latent space to have unit variance when training the diffusion
62
+ model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
63
+ diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
64
+ / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
65
+ Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
66
+ force_upcast (`bool`, *optional*, default to `True`):
67
+ If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE
68
+         can be fine-tuned / trained to a lower range without losing too much precision, in which case
69
+ `force_upcast` can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix
70
+ """
71
+
72
+ _supports_gradient_checkpointing = True
73
+
74
+ @register_to_config
75
+ def __init__(
76
+ self,
77
+ # encoder related parameters
78
+ encoder_in_channels: int = 3,
79
+ encoder_out_channels: int = 4,
80
+ encoder_layers_per_block: Tuple[int, ...] = (2, 2, 2, 2),
81
+ encoder_down_block_types: Tuple[str, ...] = (
82
+ "DownEncoderBlockCausal3D",
83
+ "DownEncoderBlockCausal3D",
84
+ "DownEncoderBlockCausal3D",
85
+ "DownEncoderBlockCausal3D",
86
+ ),
87
+ encoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
88
+ encoder_spatial_down_sample: Tuple[bool, ...] = (True, True, True, False),
89
+ encoder_temporal_down_sample: Tuple[bool, ...] = (True, True, True, False),
90
+ encoder_block_dropout: Tuple[int, ...] = (0.0, 0.0, 0.0, 0.0),
91
+ encoder_act_fn: str = "silu",
92
+ encoder_norm_num_groups: int = 32,
93
+ encoder_double_z: bool = True,
94
+ encoder_type: str = 'causal_vae_conv',
95
+ # decoder related
96
+ decoder_in_channels: int = 4,
97
+ decoder_out_channels: int = 3,
98
+ decoder_layers_per_block: Tuple[int, ...] = (3, 3, 3, 3),
99
+ decoder_up_block_types: Tuple[str, ...] = (
100
+ "UpDecoderBlockCausal3D",
101
+ "UpDecoderBlockCausal3D",
102
+ "UpDecoderBlockCausal3D",
103
+ "UpDecoderBlockCausal3D",
104
+ ),
105
+ decoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
106
+ decoder_spatial_up_sample: Tuple[bool, ...] = (True, True, True, False),
107
+ decoder_temporal_up_sample: Tuple[bool, ...] = (True, True, True, False),
108
+ decoder_block_dropout: Tuple[int, ...] = (0.0, 0.0, 0.0, 0.0),
109
+ decoder_act_fn: str = "silu",
110
+ decoder_norm_num_groups: int = 32,
111
+ decoder_type: str = 'causal_vae_conv',
112
+ sample_size: int = 256,
113
+ scaling_factor: float = 0.18215,
114
+ add_post_quant_conv: bool = True,
115
+ interpolate: bool = False,
116
+ downsample_scale: int = 8,
117
+ ):
118
+ super().__init__()
119
+
120
+         print(f"The latent channel dimension is {encoder_out_channels}")
121
+ # pass init params to Encoder
122
+
123
+ self.encoder = CausalVaeEncoder(
124
+ in_channels=encoder_in_channels,
125
+ out_channels=encoder_out_channels,
126
+ down_block_types=encoder_down_block_types,
127
+ spatial_down_sample=encoder_spatial_down_sample,
128
+ temporal_down_sample=encoder_temporal_down_sample,
129
+ block_out_channels=encoder_block_out_channels,
130
+ layers_per_block=encoder_layers_per_block,
131
+ act_fn=encoder_act_fn,
132
+ norm_num_groups=encoder_norm_num_groups,
133
+ double_z=True,
134
+ block_dropout=encoder_block_dropout,
135
+ )
136
+
137
+ # pass init params to Decoder
138
+ self.decoder = CausalVaeDecoder(
139
+ in_channels=decoder_in_channels,
140
+ out_channels=decoder_out_channels,
141
+ up_block_types=decoder_up_block_types,
142
+ spatial_up_sample=decoder_spatial_up_sample,
143
+ temporal_up_sample=decoder_temporal_up_sample,
144
+ block_out_channels=decoder_block_out_channels,
145
+ layers_per_block=decoder_layers_per_block,
146
+ norm_num_groups=decoder_norm_num_groups,
147
+ act_fn=decoder_act_fn,
148
+ interpolate=interpolate,
149
+ block_dropout=decoder_block_dropout,
150
+ )
151
+
152
+ self.quant_conv = CausalConv3d(2 * encoder_out_channels, 2 * encoder_out_channels, kernel_size=1, stride=1)
153
+ self.post_quant_conv = CausalConv3d(encoder_out_channels, encoder_out_channels, kernel_size=1, stride=1)
154
+ self.use_tiling = False
155
+
156
+ # only relevant if vae tiling is enabled
157
+ self.tile_sample_min_size = self.config.sample_size
158
+
159
+ sample_size = (
160
+ self.config.sample_size[0]
161
+ if isinstance(self.config.sample_size, (list, tuple))
162
+ else self.config.sample_size
163
+ )
164
+ self.tile_latent_min_size = int(sample_size / downsample_scale)
165
+ self.encode_tile_overlap_factor = 1 / 8
166
+ self.decode_tile_overlap_factor = 1 / 8
167
+ self.downsample_scale = downsample_scale
168
+
169
+ self.apply(self._init_weights)
170
+
171
+ def _init_weights(self, m):
172
+ if isinstance(m, (nn.Linear, nn.Conv2d, nn.Conv3d)):
173
+ trunc_normal_(m.weight, std=.02)
174
+ if m.bias is not None:
175
+ nn.init.constant_(m.bias, 0)
176
+ elif isinstance(m, (nn.LayerNorm, nn.GroupNorm)):
177
+ nn.init.constant_(m.bias, 0)
178
+ nn.init.constant_(m.weight, 1.0)
179
+
180
+ def _set_gradient_checkpointing(self, module, value=False):
181
+ if isinstance(module, (Encoder, Decoder)):
182
+ module.gradient_checkpointing = value
183
+
184
+ def enable_tiling(self, use_tiling: bool = True):
185
+ r"""
186
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
187
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
188
+ processing larger images.
189
+ """
190
+ self.use_tiling = use_tiling
191
+
192
+ def disable_tiling(self):
193
+ r"""
194
+ Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
195
+ decoding in one step.
196
+ """
197
+ self.enable_tiling(False)
198
+
199
+ @property
200
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
201
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
202
+ r"""
203
+ Returns:
204
+             `dict` of attention processors: A dictionary containing all attention processors used in the model,
205
+             indexed by their weight names.
206
+ """
207
+ # set recursively
208
+ processors = {}
209
+
210
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
211
+ if hasattr(module, "get_processor"):
212
+ processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
213
+
214
+ for sub_name, child in module.named_children():
215
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
216
+
217
+ return processors
218
+
219
+ for name, module in self.named_children():
220
+ fn_recursive_add_processors(name, module, processors)
221
+
222
+ return processors
223
+
224
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
225
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
226
+ r"""
227
+ Sets the attention processor to use to compute attention.
228
+
229
+ Parameters:
230
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
231
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
232
+ for **all** `Attention` layers.
233
+
234
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
235
+ processor. This is strongly recommended when setting trainable attention processors.
236
+
237
+ """
238
+ count = len(self.attn_processors.keys())
239
+
240
+ if isinstance(processor, dict) and len(processor) != count:
241
+ raise ValueError(
242
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
243
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
244
+ )
245
+
246
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
247
+ if hasattr(module, "set_processor"):
248
+ if not isinstance(processor, dict):
249
+ module.set_processor(processor)
250
+ else:
251
+ module.set_processor(processor.pop(f"{name}.processor"))
252
+
253
+ for sub_name, child in module.named_children():
254
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
255
+
256
+ for name, module in self.named_children():
257
+ fn_recursive_attn_processor(name, module, processor)
258
+
259
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
260
+ def set_default_attn_processor(self):
261
+ """
262
+ Disables custom attention processors and sets the default attention implementation.
263
+ """
264
+ if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
265
+ processor = AttnAddedKVProcessor()
266
+ elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
267
+ processor = AttnProcessor()
268
+ else:
269
+ raise ValueError(
270
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
271
+ )
272
+
273
+ self.set_attn_processor(processor)
274
+
275
+ def encode(
276
+ self, x: torch.FloatTensor, return_dict: bool = True,
277
+ is_init_image=True, temporal_chunk=False, window_size=16, tile_sample_min_size=256,
278
+ ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
279
+ """
280
+ Encode a batch of images into latents.
281
+
282
+ Args:
283
+ x (`torch.FloatTensor`): Input batch of images.
284
+ return_dict (`bool`, *optional*, defaults to `True`):
285
+ Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
286
+
287
+ Returns:
288
+ The latent representations of the encoded images. If `return_dict` is True, a
289
+ [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
290
+ """
291
+ self.tile_sample_min_size = tile_sample_min_size
292
+ self.tile_latent_min_size = int(tile_sample_min_size / self.downsample_scale)
293
+
294
+ if self.use_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
295
+ return self.tiled_encode(x, return_dict=return_dict, is_init_image=is_init_image,
296
+ temporal_chunk=temporal_chunk, window_size=window_size)
297
+
298
+ if temporal_chunk:
299
+ moments = self.chunk_encode(x, window_size=window_size)
300
+ else:
301
+ h = self.encoder(x, is_init_image=is_init_image, temporal_chunk=False)
302
+ moments = self.quant_conv(h, is_init_image=is_init_image, temporal_chunk=False)
303
+
304
+ posterior = DiagonalGaussianDistribution(moments)
305
+
306
+ if not return_dict:
307
+ return (posterior,)
308
+
309
+ return AutoencoderKLOutput(latent_dist=posterior)
310
+
311
+ @torch.no_grad()
312
+ def chunk_encode(self, x: torch.FloatTensor, window_size=16):
313
+ # Only used during inference
314
+ # Encode a long video clips through sliding window
315
+ num_frames = x.shape[2]
316
+ assert (num_frames - 1) % self.downsample_scale == 0
317
+ init_window_size = window_size + 1
318
+ frame_list = [x[:,:,:init_window_size]]
319
+
320
+ # To chunk the long video
321
+ full_chunk_size = (num_frames - init_window_size) // window_size
322
+ fid = init_window_size
323
+ for idx in range(full_chunk_size):
324
+ frame_list.append(x[:, :, fid:fid+window_size])
325
+ fid += window_size
326
+
327
+ if fid < num_frames:
328
+ frame_list.append(x[:, :, fid:])
329
+
330
+ latent_list = []
331
+ for idx, frames in enumerate(frame_list):
332
+ if idx == 0:
333
+ h = self.encoder(frames, is_init_image=True, temporal_chunk=True)
334
+ moments = self.quant_conv(h, is_init_image=True, temporal_chunk=True)
335
+ else:
336
+ h = self.encoder(frames, is_init_image=False, temporal_chunk=True)
337
+ moments = self.quant_conv(h, is_init_image=False, temporal_chunk=True)
338
+
339
+ latent_list.append(moments)
340
+
341
+ latent = torch.cat(latent_list, dim=2)
342
+ return latent
343
+
344
+ def get_last_layer(self):
345
+ return self.decoder.conv_out.conv.weight
346
+
347
+ @torch.no_grad()
348
+ def chunk_decode(self, z: torch.FloatTensor, window_size=2):
349
+ num_frames = z.shape[2]
350
+ init_window_size = window_size + 1
351
+ frame_list = [z[:,:,:init_window_size]]
352
+
353
+ # To chunk the long video
354
+ full_chunk_size = (num_frames - init_window_size) // window_size
355
+ fid = init_window_size
356
+ for idx in range(full_chunk_size):
357
+ frame_list.append(z[:, :, fid:fid+window_size])
358
+ fid += window_size
359
+
360
+ if fid < num_frames:
361
+ frame_list.append(z[:, :, fid:])
362
+
363
+ dec_list = []
364
+ for idx, frames in enumerate(frame_list):
365
+ if idx == 0:
366
+ z_h = self.post_quant_conv(frames, is_init_image=True, temporal_chunk=True)
367
+ dec = self.decoder(z_h, is_init_image=True, temporal_chunk=True)
368
+ else:
369
+ z_h = self.post_quant_conv(frames, is_init_image=False, temporal_chunk=True)
370
+ dec = self.decoder(z_h, is_init_image=False, temporal_chunk=True)
371
+
372
+ dec_list.append(dec)
373
+
374
+ dec = torch.cat(dec_list, dim=2)
375
+ return dec
376
+
377
+ def decode(self, z: torch.FloatTensor, is_init_image=True, temporal_chunk=False,
378
+ return_dict: bool = True, window_size: int = 2, tile_sample_min_size: int = 256,) -> Union[DecoderOutput, torch.FloatTensor]:
379
+
380
+ self.tile_sample_min_size = tile_sample_min_size
381
+ self.tile_latent_min_size = int(tile_sample_min_size / self.downsample_scale)
382
+
383
+ if self.use_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size):
384
+ return self.tiled_decode(z, is_init_image=is_init_image,
385
+ temporal_chunk=temporal_chunk, window_size=window_size, return_dict=return_dict)
386
+
387
+ if temporal_chunk:
388
+ dec = self.chunk_decode(z, window_size=window_size)
389
+ else:
390
+ z = self.post_quant_conv(z, is_init_image=is_init_image, temporal_chunk=False)
391
+ dec = self.decoder(z, is_init_image=is_init_image, temporal_chunk=False)
392
+
393
+ if not return_dict:
394
+ return (dec,)
395
+
396
+ return DecoderOutput(sample=dec)
397
+
398
+ def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
399
+ blend_extent = min(a.shape[3], b.shape[3], blend_extent)
400
+ for y in range(blend_extent):
401
+ b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (y / blend_extent)
402
+ return b
403
+
404
+ def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
405
+ blend_extent = min(a.shape[4], b.shape[4], blend_extent)
406
+ for x in range(blend_extent):
407
+ b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (x / blend_extent)
408
+ return b
409
+
410
+ def tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True,
411
+ is_init_image=True, temporal_chunk=False, window_size=16,) -> AutoencoderKLOutput:
412
+ r"""Encode a batch of images using a tiled encoder.
413
+
414
+ When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
415
+ steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is
416
+ different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the
417
+ tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
418
+ output, but they should be much less noticeable.
419
+
420
+ Args:
421
+ x (`torch.FloatTensor`): Input batch of images.
422
+ return_dict (`bool`, *optional*, defaults to `True`):
423
+ Whether or not to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
424
+
425
+ Returns:
426
+ [`~models.autoencoder_kl.AutoencoderKLOutput`] or `tuple`:
427
+ If return_dict is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain
428
+ `tuple` is returned.
429
+ """
430
+ overlap_size = int(self.tile_sample_min_size * (1 - self.encode_tile_overlap_factor))
431
+ blend_extent = int(self.tile_latent_min_size * self.encode_tile_overlap_factor)
432
+ row_limit = self.tile_latent_min_size - blend_extent
433
+
434
+         # Split the image into overlapping tiles of size `tile_sample_min_size` and encode them separately.
435
+ rows = []
436
+ for i in range(0, x.shape[3], overlap_size):
437
+ row = []
438
+ for j in range(0, x.shape[4], overlap_size):
439
+ tile = x[:, :, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
440
+ if temporal_chunk:
441
+ tile = self.chunk_encode(tile, window_size=window_size)
442
+ else:
443
+ tile = self.encoder(tile, is_init_image=True, temporal_chunk=False)
444
+ tile = self.quant_conv(tile, is_init_image=True, temporal_chunk=False)
445
+ row.append(tile)
446
+ rows.append(row)
447
+ result_rows = []
448
+ for i, row in enumerate(rows):
449
+ result_row = []
450
+ for j, tile in enumerate(row):
451
+ # blend the above tile and the left tile
452
+ # to the current tile and add the current tile to the result row
453
+ if i > 0:
454
+ tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
455
+ if j > 0:
456
+ tile = self.blend_h(row[j - 1], tile, blend_extent)
457
+ result_row.append(tile[:, :, :, :row_limit, :row_limit])
458
+ result_rows.append(torch.cat(result_row, dim=4))
459
+
460
+ moments = torch.cat(result_rows, dim=3)
461
+
462
+ posterior = DiagonalGaussianDistribution(moments)
463
+
464
+ if not return_dict:
465
+ return (posterior,)
466
+
467
+ return AutoencoderKLOutput(latent_dist=posterior)
468
+
469
+ def tiled_decode(self, z: torch.FloatTensor, is_init_image=True,
470
+ temporal_chunk=False, window_size=2, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
471
+ r"""
472
+ Decode a batch of images using a tiled decoder.
473
+
474
+ Args:
475
+ z (`torch.FloatTensor`): Input batch of latent vectors.
476
+ return_dict (`bool`, *optional*, defaults to `True`):
477
+ Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
478
+
479
+ Returns:
480
+ [`~models.vae.DecoderOutput`] or `tuple`:
481
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
482
+ returned.
483
+ """
484
+ overlap_size = int(self.tile_latent_min_size * (1 - self.decode_tile_overlap_factor))
485
+ blend_extent = int(self.tile_sample_min_size * self.decode_tile_overlap_factor)
486
+ row_limit = self.tile_sample_min_size - blend_extent
487
+
488
+ # Split z into tiles of size tile_latent_min_size and decode them separately.
489
+ # The tiles have an overlap to avoid seams between tiles.
490
+ rows = []
491
+ for i in range(0, z.shape[3], overlap_size):
492
+ row = []
493
+ for j in range(0, z.shape[4], overlap_size):
494
+ tile = z[:, :, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size]
495
+ if temporal_chunk:
496
+ decoded = self.chunk_decode(tile, window_size=window_size)
497
+ else:
498
+ tile = self.post_quant_conv(tile, is_init_image=True, temporal_chunk=False)
499
+ decoded = self.decoder(tile, is_init_image=True, temporal_chunk=False)
500
+ row.append(decoded)
501
+ rows.append(row)
502
+ result_rows = []
503
+
504
+ for i, row in enumerate(rows):
505
+ result_row = []
506
+ for j, tile in enumerate(row):
507
+ # blend the above tile and the left tile
508
+ # to the current tile and add the current tile to the result row
509
+ if i > 0:
510
+ tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
511
+ if j > 0:
512
+ tile = self.blend_h(row[j - 1], tile, blend_extent)
513
+ result_row.append(tile[:, :, :, :row_limit, :row_limit])
514
+ result_rows.append(torch.cat(result_row, dim=4))
515
+
516
+ dec = torch.cat(result_rows, dim=3)
517
+ if not return_dict:
518
+ return (dec,)
519
+
520
+ return DecoderOutput(sample=dec)
521
+
522
+ def forward(
523
+ self,
524
+ sample: torch.FloatTensor,
525
+ sample_posterior: bool = True,
526
+ generator: Optional[torch.Generator] = None,
527
+ freeze_encoder: bool = False,
528
+ is_init_image=True,
529
+ temporal_chunk=False,
530
+ ) -> Union[DecoderOutput, torch.FloatTensor]:
531
+ r"""
532
+ Args:
533
+ sample (`torch.FloatTensor`): Input sample.
534
+ sample_posterior (`bool`, *optional*, defaults to `False`):
535
+ Whether to sample from the posterior.
536
+ return_dict (`bool`, *optional*, defaults to `True`):
537
+ Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
538
+ """
539
+ x = sample
540
+
541
+ if is_context_parallel_initialized():
542
+ assert self.training, "Context-parallel VAE is only supported during training for now"
543
+
544
+ if freeze_encoder:
545
+ with torch.no_grad():
546
+ h = self.encoder(x, is_init_image=True, temporal_chunk=False)
547
+ moments = self.quant_conv(h, is_init_image=True, temporal_chunk=False)
548
+ posterior = DiagonalGaussianDistribution(moments)
549
+ global_posterior = posterior
550
+ else:
551
+ h = self.encoder(x, is_init_image=True, temporal_chunk=False)
552
+ moments = self.quant_conv(h, is_init_image=True, temporal_chunk=False)
553
+ posterior = DiagonalGaussianDistribution(moments)
554
+ global_moments = conv_gather_from_context_parallel_region(moments, dim=2, kernel_size=1)
555
+ global_posterior = DiagonalGaussianDistribution(global_moments)
556
+
557
+ if sample_posterior:
558
+ z = posterior.sample(generator=generator)
559
+ else:
560
+ z = posterior.mode()
561
+
562
+ if get_context_parallel_rank() == 0:
563
+ dec = self.decode(z, is_init_image=True).sample
564
+ else:
565
+ # Do not drop the first upsampled frame
566
+ dec = self.decode(z, is_init_image=False).sample
567
+
568
+ return global_posterior, dec
569
+
570
+ else:
571
+ # The normal training
572
+ if freeze_encoder:
573
+ with torch.no_grad():
574
+ posterior = self.encode(x, is_init_image=is_init_image,
575
+ temporal_chunk=temporal_chunk).latent_dist
576
+ else:
577
+ posterior = self.encode(x, is_init_image=is_init_image,
578
+ temporal_chunk=temporal_chunk).latent_dist
579
+
580
+ if sample_posterior:
581
+ z = posterior.sample(generator=generator)
582
+ else:
583
+ z = posterior.mode()
584
+
585
+ dec = self.decode(z, is_init_image=is_init_image, temporal_chunk=temporal_chunk).sample
586
+
587
+ return posterior, dec
588
+
589
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
590
+ def fuse_qkv_projections(self):
591
+ """
592
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
593
+ key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
594
+
595
+ <Tip warning={true}>
596
+
597
+ This API is 🧪 experimental.
598
+
599
+ </Tip>
600
+ """
601
+ self.original_attn_processors = None
602
+
603
+ for _, attn_processor in self.attn_processors.items():
604
+ if "Added" in str(attn_processor.__class__.__name__):
605
+ raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
606
+
607
+ self.original_attn_processors = self.attn_processors
608
+
609
+ for module in self.modules():
610
+ if isinstance(module, Attention):
611
+ module.fuse_projections(fuse=True)
612
+
613
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
614
+ def unfuse_qkv_projections(self):
615
+ """Disables the fused QKV projection if enabled.
616
+
617
+ <Tip warning={true}>
618
+
619
+ This API is 🧪 experimental.
620
+
621
+ </Tip>
622
+
623
+ """
624
+ if self.original_attn_processors is not None:
625
+ self.set_attn_processor(self.original_attn_processors)
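For orientation, a minimal decode-side sketch of the autoencoder defined above. It is a sketch only: `decode_latents` and its `vae` argument are hypothetical names, `vae` is assumed to be an instance of this causal video VAE, and the keyword arguments simply mirror the `decode` signature shown above.

import torch
from typing import Any


def decode_latents(vae: Any, z: torch.Tensor) -> torch.Tensor:
    # `vae` is assumed to be an instance of the causal video VAE above;
    # z has layout (batch, channels, frames, height, width).
    # Memory-friendly path: temporal chunking decodes `window_size` latents at a
    # time, and (when tiling is enabled on the model) tiled_decode takes over once
    # the latent's spatial size exceeds tile_latent_min_size.
    out = vae.decode(
        z,
        is_init_image=True,
        temporal_chunk=True,
        window_size=2,
        tile_sample_min_size=256,
    )
    return out.sample  # DecoderOutput.sample is the pixel-space video tensor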
video_vae/modeling_discriminator.py ADDED
@@ -0,0 +1,122 @@
1
+ import functools
2
+ import torch.nn as nn
3
+ from einops import rearrange
4
+ import torch
5
+
6
+
7
+ def weights_init(m):
8
+ classname = m.__class__.__name__
9
+ if classname.find('Conv') != -1:
10
+ nn.init.normal_(m.weight.data, 0.0, 0.02)
11
+ nn.init.constant_(m.bias.data, 0)
12
+ elif classname.find('BatchNorm') != -1:
13
+ nn.init.normal_(m.weight.data, 1.0, 0.02)
14
+ nn.init.constant_(m.bias.data, 0)
15
+
16
+
17
+ class NLayerDiscriminator(nn.Module):
18
+ """Defines a PatchGAN discriminator as in Pix2Pix
19
+ --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
20
+ """
21
+ def __init__(self, input_nc=3, ndf=64, n_layers=4):
22
+ """Construct a PatchGAN discriminator
23
+ Parameters:
24
+ input_nc (int) -- the number of channels in input images
25
+ ndf (int) -- the number of filters in the last conv layer
26
+ n_layers (int) -- the number of conv layers in the discriminator
27
+ norm_layer -- normalization layer
28
+ """
29
+ super(NLayerDiscriminator, self).__init__()
30
+
31
+ # norm_layer = nn.BatchNorm2d
32
+ norm_layer = nn.InstanceNorm2d
33
+
34
+ if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
35
+ use_bias = norm_layer.func != nn.BatchNorm2d
36
+ else:
37
+ use_bias = norm_layer != nn.BatchNorm2d
38
+
39
+ kw = 4
40
+ padw = 1
41
+ sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
42
+ nf_mult = 1
43
+ nf_mult_prev = 1
44
+ for n in range(1, n_layers): # gradually increase the number of filters
45
+ nf_mult_prev = nf_mult
46
+ nf_mult = min(2 ** n, 8)
47
+ sequence += [
48
+ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
49
+ norm_layer(ndf * nf_mult),
50
+ nn.LeakyReLU(0.2, True)
51
+ ]
52
+
53
+ nf_mult_prev = nf_mult
54
+ nf_mult = min(2 ** n_layers, 8)
55
+ sequence += [
56
+ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
57
+ norm_layer(ndf * nf_mult),
58
+ nn.LeakyReLU(0.2, True)
59
+ ]
60
+
61
+ sequence += [
62
+ nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map
63
+ self.main = nn.Sequential(*sequence)
64
+
65
+ def forward(self, input):
66
+ """Standard forward."""
67
+ return self.main(input)
68
+
69
+
70
+ class NLayerDiscriminator3D(nn.Module):
71
+ """Defines a 3D PatchGAN discriminator as in Pix2Pix but for 3D inputs."""
72
+ def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
73
+ """
74
+ Construct a 3D PatchGAN discriminator
75
+
76
+ Parameters:
77
+ input_nc (int) -- the number of channels in input volumes
78
+ ndf (int) -- the number of filters in the last conv layer
79
+ n_layers (int) -- the number of conv layers in the discriminator
80
+ use_actnorm (bool) -- flag to use actnorm instead of batchnorm
81
+ """
82
+ super(NLayerDiscriminator3D, self).__init__()
83
+ # if not use_actnorm:
84
+ # norm_layer = nn.BatchNorm3d
85
+ # else:
86
+ # raise NotImplementedError("Not implemented.")
87
+
88
+ norm_layer = nn.InstanceNorm3d
89
+
90
+ if type(norm_layer) == functools.partial:
91
+ use_bias = norm_layer.func != nn.BatchNorm3d
92
+ else:
93
+ use_bias = norm_layer != nn.BatchNorm3d
94
+
95
+ kw = 4
96
+ padw = 1
97
+ sequence = [nn.Conv3d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
98
+ nf_mult = 1
99
+ nf_mult_prev = 1
100
+ for n in range(1, n_layers): # gradually increase the number of filters
101
+ nf_mult_prev = nf_mult
102
+ nf_mult = min(2 ** n, 8)
103
+ sequence += [
104
+ nn.Conv3d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=(kw, kw, kw), stride=(1,2,2), padding=padw, bias=use_bias),
105
+ norm_layer(ndf * nf_mult),
106
+ nn.LeakyReLU(0.2, True)
107
+ ]
108
+
109
+ nf_mult_prev = nf_mult
110
+ nf_mult = min(2 ** n_layers, 8)
111
+ sequence += [
112
+ nn.Conv3d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=(kw, kw, kw), stride=1, padding=padw, bias=use_bias),
113
+ norm_layer(ndf * nf_mult),
114
+ nn.LeakyReLU(0.2, True)
115
+ ]
116
+
117
+ sequence += [nn.Conv3d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map
118
+ self.main = nn.Sequential(*sequence)
119
+
120
+ def forward(self, input):
121
+ """Standard forward."""
122
+ return self.main(input)
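A small usage sketch for the two discriminators above, assuming the repository root is on the Python path so that `video_vae` imports as a package; the tensor sizes are arbitrary but large enough for the stacked 4x4 strided convolutions.

import torch
from video_vae.modeling_discriminator import (
    NLayerDiscriminator,
    NLayerDiscriminator3D,
    weights_init,
)

# 2D PatchGAN: scores individual frames, shape (batch, channels, height, width)
frames = torch.randn(4, 3, 64, 64)
disc_2d = NLayerDiscriminator(input_nc=3, ndf=64, n_layers=4).apply(weights_init)
print(disc_2d(frames).shape)   # patch-wise logits, (4, 1, 2, 2) for this input

# 3D PatchGAN: scores whole clips, shape (batch, channels, frames, height, width),
# so it can also penalise temporal artifacts
video = torch.randn(1, 3, 16, 64, 64)
disc_3d = NLayerDiscriminator3D(input_nc=3, ndf=64, n_layers=3).apply(weights_init)
print(disc_3d(video).shape)    # (1, 1, 4, 6, 6) for this input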
video_vae/modeling_enc_dec.py ADDED
@@ -0,0 +1,422 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from dataclasses import dataclass
15
+ from typing import Optional, Tuple
16
+
17
+ import numpy as np
18
+ import torch
19
+ import torch.nn as nn
20
+ from einops import rearrange
21
+
22
+ from diffusers.utils import BaseOutput, is_torch_version
23
+ from diffusers.utils.torch_utils import randn_tensor
24
+ from diffusers.models.attention_processor import SpatialNorm
25
+ from .modeling_block import (
26
+ UNetMidBlock2D,
27
+ CausalUNetMidBlock2D,
28
+ get_down_block,
29
+ get_up_block,
30
+ get_input_layer,
31
+ get_output_layer,
32
+ )
33
+ from .modeling_resnet import (
34
+ Downsample2D,
35
+ Upsample2D,
36
+ TemporalDownsample2x,
37
+ TemporalUpsample2x,
38
+ )
39
+ from .modeling_causal_conv import CausalConv3d, CausalGroupNorm
40
+
41
+
42
+ @dataclass
43
+ class DecoderOutput(BaseOutput):
44
+ r"""
45
+ Output of decoding method.
46
+
47
+ Args:
48
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
49
+ The decoded output sample from the last layer of the model.
50
+ """
51
+
52
+ sample: torch.FloatTensor
53
+
54
+
55
+ class CausalVaeEncoder(nn.Module):
56
+ r"""
57
+ The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation.
58
+
59
+ Args:
60
+ in_channels (`int`, *optional*, defaults to 3):
61
+ The number of input channels.
62
+ out_channels (`int`, *optional*, defaults to 3):
63
+ The number of output channels.
64
+ down_block_types (`Tuple[str, ...]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
65
+ The types of down blocks to use. See `~diffusers.models.unet_2d_blocks.get_down_block` for available
66
+ options.
67
+ block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
68
+ The number of output channels for each block.
69
+ layers_per_block (`int`, *optional*, defaults to 2):
70
+ The number of layers per block.
71
+ norm_num_groups (`int`, *optional*, defaults to 32):
72
+ The number of groups for normalization.
73
+ act_fn (`str`, *optional*, defaults to `"silu"`):
74
+ The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
75
+ double_z (`bool`, *optional*, defaults to `True`):
76
+ Whether to double the number of output channels for the last block.
77
+ """
78
+
79
+ def __init__(
80
+ self,
81
+ in_channels: int = 3,
82
+ out_channels: int = 3,
83
+ down_block_types: Tuple[str, ...] = ("DownEncoderBlockCausal3D",),
84
+ spatial_down_sample: Tuple[bool, ...] = (True,),
85
+ temporal_down_sample: Tuple[bool, ...] = (False,),
86
+ block_out_channels: Tuple[int, ...] = (64,),
87
+ layers_per_block: Tuple[int, ...] = (2,),
88
+ norm_num_groups: int = 32,
89
+ act_fn: str = "silu",
90
+ double_z: bool = True,
91
+ block_dropout: Tuple[int, ...] = (0.0,),
92
+ mid_block_add_attention=True,
93
+ ):
94
+ super().__init__()
95
+ self.layers_per_block = layers_per_block
96
+
97
+ self.conv_in = CausalConv3d(
98
+ in_channels,
99
+ block_out_channels[0],
100
+ kernel_size=3,
101
+ stride=1,
102
+ )
103
+
104
+ self.mid_block = None
105
+ self.down_blocks = nn.ModuleList([])
106
+
107
+ # down
108
+ output_channel = block_out_channels[0]
109
+ for i, down_block_type in enumerate(down_block_types):
110
+ input_channel = output_channel
111
+ output_channel = block_out_channels[i]
112
+
113
+ down_block = get_down_block(
114
+ down_block_type,
115
+ num_layers=self.layers_per_block[i],
116
+ in_channels=input_channel,
117
+ out_channels=output_channel,
118
+ add_spatial_downsample=spatial_down_sample[i],
119
+ add_temporal_downsample=temporal_down_sample[i],
120
+ resnet_eps=1e-6,
121
+ downsample_padding=0,
122
+ resnet_act_fn=act_fn,
123
+ resnet_groups=norm_num_groups,
124
+ attention_head_dim=output_channel,
125
+ temb_channels=None,
126
+ dropout=block_dropout[i],
127
+ )
128
+ self.down_blocks.append(down_block)
129
+
130
+ # mid
131
+ self.mid_block = CausalUNetMidBlock2D(
132
+ in_channels=block_out_channels[-1],
133
+ resnet_eps=1e-6,
134
+ resnet_act_fn=act_fn,
135
+ output_scale_factor=1,
136
+ resnet_time_scale_shift="default",
137
+ attention_head_dim=block_out_channels[-1],
138
+ resnet_groups=norm_num_groups,
139
+ temb_channels=None,
140
+ add_attention=mid_block_add_attention,
141
+ dropout=block_dropout[-1],
142
+ )
143
+
144
+ # out
145
+
146
+ self.conv_norm_out = CausalGroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
147
+ self.conv_act = nn.SiLU()
148
+
149
+ conv_out_channels = 2 * out_channels if double_z else out_channels
150
+ self.conv_out = CausalConv3d(block_out_channels[-1], conv_out_channels, kernel_size=3, stride=1)
151
+
152
+ self.gradient_checkpointing = False
153
+
154
+ def forward(self, sample: torch.FloatTensor, is_init_image=True, temporal_chunk=False) -> torch.FloatTensor:
155
+ r"""The forward method of the `Encoder` class."""
156
+
157
+ sample = self.conv_in(sample, is_init_image=is_init_image, temporal_chunk=temporal_chunk)
158
+
159
+ if self.training and self.gradient_checkpointing:
160
+
161
+ def create_custom_forward(module):
162
+ def custom_forward(*inputs):
163
+ return module(*inputs)
164
+
165
+ return custom_forward
166
+
167
+ # down
168
+ if is_torch_version(">=", "1.11.0"):
169
+ for down_block in self.down_blocks:
170
+ sample = torch.utils.checkpoint.checkpoint(
171
+ create_custom_forward(down_block), sample, is_init_image,
172
+ temporal_chunk, use_reentrant=False
173
+ )
174
+ # middle
175
+ sample = torch.utils.checkpoint.checkpoint(
176
+ create_custom_forward(self.mid_block), sample, is_init_image,
177
+ temporal_chunk, use_reentrant=False
178
+ )
179
+ else:
180
+ for down_block in self.down_blocks:
181
+ sample = torch.utils.checkpoint.checkpoint(create_custom_forward(down_block), sample, is_init_image, temporal_chunk)
182
+ # middle
183
+ sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample, is_init_image, temporal_chunk)
184
+
185
+ else:
186
+ # down
187
+ for down_block in self.down_blocks:
188
+ sample = down_block(sample, is_init_image=is_init_image, temporal_chunk=temporal_chunk)
189
+
190
+ # middle
191
+ sample = self.mid_block(sample, is_init_image=is_init_image, temporal_chunk=temporal_chunk)
192
+
193
+ # post-process
194
+ sample = self.conv_norm_out(sample)
195
+ sample = self.conv_act(sample)
196
+ sample = self.conv_out(sample, is_init_image=is_init_image, temporal_chunk=temporal_chunk)
197
+
198
+ return sample
199
+
200
+
201
+ class CausalVaeDecoder(nn.Module):
202
+ r"""
203
+ The `Decoder` layer of a variational autoencoder that decodes its latent representation into an output sample.
204
+
205
+ Args:
206
+ in_channels (`int`, *optional*, defaults to 3):
207
+ The number of input channels.
208
+ out_channels (`int`, *optional*, defaults to 3):
209
+ The number of output channels.
210
+ up_block_types (`Tuple[str, ...]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
211
+ The types of up blocks to use. See `~diffusers.models.unet_2d_blocks.get_up_block` for available options.
212
+ block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
213
+ The number of output channels for each block.
214
+ layers_per_block (`int`, *optional*, defaults to 2):
215
+ The number of layers per block.
216
+ norm_num_groups (`int`, *optional*, defaults to 32):
217
+ The number of groups for normalization.
218
+ act_fn (`str`, *optional*, defaults to `"silu"`):
219
+ The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
220
+ norm_type (`str`, *optional*, defaults to `"group"`):
221
+ The normalization type to use. Can be either `"group"` or `"spatial"`.
222
+ """
223
+
224
+ def __init__(
225
+ self,
226
+ in_channels: int = 3,
227
+ out_channels: int = 3,
228
+ up_block_types: Tuple[str, ...] = ("UpDecoderBlockCausal3D",),
229
+ spatial_up_sample: Tuple[bool, ...] = (True,),
230
+ temporal_up_sample: Tuple[bool, ...] = (False,),
231
+ block_out_channels: Tuple[int, ...] = (64,),
232
+ layers_per_block: Tuple[int, ...] = (2,),
233
+ norm_num_groups: int = 32,
234
+ act_fn: str = "silu",
235
+ mid_block_add_attention=True,
236
+ interpolate: bool = True,
237
+ block_dropout: Tuple[int, ...] = (0.0,),
238
+ ):
239
+ super().__init__()
240
+ self.layers_per_block = layers_per_block
241
+
242
+ self.conv_in = CausalConv3d(
243
+ in_channels,
244
+ block_out_channels[-1],
245
+ kernel_size=3,
246
+ stride=1,
247
+ )
248
+
249
+ self.mid_block = None
250
+ self.up_blocks = nn.ModuleList([])
251
+
252
+ # mid
253
+ self.mid_block = CausalUNetMidBlock2D(
254
+ in_channels=block_out_channels[-1],
255
+ resnet_eps=1e-6,
256
+ resnet_act_fn=act_fn,
257
+ output_scale_factor=1,
258
+ resnet_time_scale_shift="default",
259
+ attention_head_dim=block_out_channels[-1],
260
+ resnet_groups=norm_num_groups,
261
+ temb_channels=None,
262
+ add_attention=mid_block_add_attention,
263
+ dropout=block_dropout[-1],
264
+ )
265
+
266
+ # up
267
+ reversed_block_out_channels = list(reversed(block_out_channels))
268
+ output_channel = reversed_block_out_channels[0]
269
+ for i, up_block_type in enumerate(up_block_types):
270
+ prev_output_channel = output_channel
271
+ output_channel = reversed_block_out_channels[i]
272
+
273
+ is_final_block = i == len(block_out_channels) - 1
274
+
275
+ up_block = get_up_block(
276
+ up_block_type,
277
+ num_layers=self.layers_per_block[i],
278
+ in_channels=prev_output_channel,
279
+ out_channels=output_channel,
280
+ prev_output_channel=None,
281
+ add_spatial_upsample=spatial_up_sample[i],
282
+ add_temporal_upsample=temporal_up_sample[i],
283
+ resnet_eps=1e-6,
284
+ resnet_act_fn=act_fn,
285
+ resnet_groups=norm_num_groups,
286
+ attention_head_dim=output_channel,
287
+ temb_channels=None,
288
+ resnet_time_scale_shift='default',
289
+ interpolate=interpolate,
290
+ dropout=block_dropout[i],
291
+ )
292
+ self.up_blocks.append(up_block)
293
+ prev_output_channel = output_channel
294
+
295
+ # out
296
+ self.conv_norm_out = CausalGroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
297
+ self.conv_act = nn.SiLU()
298
+ self.conv_out = CausalConv3d(block_out_channels[0], out_channels, kernel_size=3, stride=1)
299
+
300
+ self.gradient_checkpointing = False
301
+
302
+ def forward(
303
+ self,
304
+ sample: torch.FloatTensor,
305
+ is_init_image=True,
306
+ temporal_chunk=False,
307
+ ) -> torch.FloatTensor:
308
+ r"""The forward method of the `Decoder` class."""
309
+
310
+ sample = self.conv_in(sample, is_init_image=is_init_image, temporal_chunk=temporal_chunk)
311
+
312
+ upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
313
+ if self.training and self.gradient_checkpointing:
314
+
315
+ def create_custom_forward(module):
316
+ def custom_forward(*inputs):
317
+ return module(*inputs)
318
+
319
+ return custom_forward
320
+
321
+ if is_torch_version(">=", "1.11.0"):
322
+ # middle
323
+ sample = torch.utils.checkpoint.checkpoint(
324
+ create_custom_forward(self.mid_block),
325
+ sample,
326
+ is_init_image=is_init_image,
327
+ temporal_chunk=temporal_chunk,
328
+ use_reentrant=False,
329
+ )
330
+ sample = sample.to(upscale_dtype)
331
+
332
+ # up
333
+ for up_block in self.up_blocks:
334
+ sample = torch.utils.checkpoint.checkpoint(
335
+ create_custom_forward(up_block),
336
+ sample,
337
+ is_init_image=is_init_image,
338
+ temporal_chunk=temporal_chunk,
339
+ use_reentrant=False,
340
+ )
341
+ else:
342
+ # middle
343
+ sample = torch.utils.checkpoint.checkpoint(
344
+ create_custom_forward(self.mid_block), sample, is_init_image=is_init_image, temporal_chunk=temporal_chunk,
345
+ )
346
+ sample = sample.to(upscale_dtype)
347
+
348
+ # up
349
+ for up_block in self.up_blocks:
350
+ sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample,
351
+ is_init_image=is_init_image, temporal_chunk=temporal_chunk,)
352
+ else:
353
+ # middle
354
+ sample = self.mid_block(sample, is_init_image=is_init_image, temporal_chunk=temporal_chunk)
355
+ sample = sample.to(upscale_dtype)
356
+
357
+ # up
358
+ for up_block in self.up_blocks:
359
+ sample = up_block(sample, is_init_image=is_init_image, temporal_chunk=temporal_chunk,)
360
+
361
+ # post-process
362
+ sample = self.conv_norm_out(sample)
363
+ sample = self.conv_act(sample)
364
+ sample = self.conv_out(sample, is_init_image=is_init_image, temporal_chunk=temporal_chunk)
365
+
366
+ return sample
367
+
368
+
369
+ class DiagonalGaussianDistribution(object):
370
+ def __init__(self, parameters: torch.Tensor, deterministic: bool = False):
371
+ self.parameters = parameters
372
+ self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
373
+ self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
374
+ self.deterministic = deterministic
375
+ self.std = torch.exp(0.5 * self.logvar)
376
+ self.var = torch.exp(self.logvar)
377
+ if self.deterministic:
378
+ self.var = self.std = torch.zeros_like(
379
+ self.mean, device=self.parameters.device, dtype=self.parameters.dtype
380
+ )
381
+
382
+ def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor:
383
+ # make sure sample is on the same device as the parameters and has same dtype
384
+ sample = randn_tensor(
385
+ self.mean.shape,
386
+ generator=generator,
387
+ device=self.parameters.device,
388
+ dtype=self.parameters.dtype,
389
+ )
390
+ x = self.mean + self.std * sample
391
+ return x
392
+
393
+ def kl(self, other: "DiagonalGaussianDistribution" = None) -> torch.Tensor:
394
+ if self.deterministic:
395
+ return torch.Tensor([0.0])
396
+ else:
397
+ if other is None:
398
+ return 0.5 * torch.sum(
399
+ torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
400
+ dim=[2, 3, 4],
401
+ )
402
+ else:
403
+ return 0.5 * torch.sum(
404
+ torch.pow(self.mean - other.mean, 2) / other.var
405
+ + self.var / other.var
406
+ - 1.0
407
+ - self.logvar
408
+ + other.logvar,
409
+ dim=[2, 3, 4],
410
+ )
411
+
412
+ def nll(self, sample: torch.Tensor, dims: Tuple[int, ...] = [1, 2, 3]) -> torch.Tensor:
413
+ if self.deterministic:
414
+ return torch.Tensor([0.0])
415
+ logtwopi = np.log(2.0 * np.pi)
416
+ return 0.5 * torch.sum(
417
+ logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
418
+ dim=dims,
419
+ )
420
+
421
+ def mode(self) -> torch.Tensor:
422
+ return self.mean
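As a quick reference for the posterior class above, a minimal sketch; the shapes are illustrative and `video_vae` is again assumed to be importable as a package.

import torch
from video_vae.modeling_enc_dec import DiagonalGaussianDistribution

# `moments` is what quant_conv produces: mean and logvar concatenated on dim 1
moments = torch.randn(2, 2 * 16, 4, 8, 8)   # (batch, 2 * latent_channels, frames, h, w)
posterior = DiagonalGaussianDistribution(moments)

z = posterior.sample()    # mean + std * eps, shape (2, 16, 4, 8, 8)
z_det = posterior.mode()  # deterministic: just the mean
kl = posterior.kl()       # KL to N(0, I), summed over (frames, h, w) -> shape (2, 16)
print(z.shape, kl.shape)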
video_vae/modeling_loss.py ADDED
@@ -0,0 +1,192 @@
1
+ import os
2
+ import torch
3
+ from torch import nn
4
+ import torch.nn.functional as F
5
+ from einops import rearrange
6
+ from .modeling_lpips import LPIPS
7
+ from .modeling_discriminator import NLayerDiscriminator, NLayerDiscriminator3D, weights_init
8
+ from IPython import embed
9
+
10
+
11
+ class AdaptiveLossWeight:
12
+ def __init__(self, timestep_range=[0, 1], buckets=300, weight_range=[1e-7, 1e7]):
13
+ self.bucket_ranges = torch.linspace(timestep_range[0], timestep_range[1], buckets-1)
14
+ self.bucket_losses = torch.ones(buckets)
15
+ self.weight_range = weight_range
16
+
17
+ def weight(self, timestep):
18
+ indices = torch.searchsorted(self.bucket_ranges.to(timestep.device), timestep)
19
+ return (1/self.bucket_losses.to(timestep.device)[indices]).clamp(*self.weight_range)
20
+
21
+ def update_buckets(self, timestep, loss, beta=0.99):
22
+ indices = torch.searchsorted(self.bucket_ranges.to(timestep.device), timestep).cpu()
23
+ self.bucket_losses[indices] = self.bucket_losses[indices]*beta + loss.detach().cpu() * (1-beta)
24
+
25
+
26
+ def hinge_d_loss(logits_real, logits_fake):
27
+ loss_real = torch.mean(F.relu(1.0 - logits_real))
28
+ loss_fake = torch.mean(F.relu(1.0 + logits_fake))
29
+ d_loss = 0.5 * (loss_real + loss_fake)
30
+ return d_loss
31
+
32
+
33
+ def vanilla_d_loss(logits_real, logits_fake):
34
+ d_loss = 0.5 * (
35
+ torch.mean(torch.nn.functional.softplus(-logits_real))
36
+ + torch.mean(torch.nn.functional.softplus(logits_fake))
37
+ )
38
+ return d_loss
39
+
40
+
41
+ def adopt_weight(weight, global_step, threshold=0, value=0.0):
42
+ if global_step < threshold:
43
+ weight = value
44
+ return weight
45
+
46
+
47
+ class LPIPSWithDiscriminator(nn.Module):
48
+ def __init__(
49
+ self,
50
+ disc_start,
51
+ logvar_init=0.0,
52
+ kl_weight=1.0,
53
+ pixelloss_weight=1.0,
54
+ perceptual_weight=1.0,
55
+ # --- Discriminator Loss ---
56
+ disc_num_layers=4,
57
+ disc_in_channels=3,
58
+ disc_factor=1.0,
59
+ disc_weight=0.5,
60
+ disc_loss="hinge",
61
+ add_discriminator=True,
62
+ using_3d_discriminator=False,
63
+ ):
64
+
65
+ super().__init__()
66
+ assert disc_loss in ["hinge", "vanilla"]
67
+ self.kl_weight = kl_weight
68
+ self.pixel_weight = pixelloss_weight
69
+ self.perceptual_loss = LPIPS().eval()
70
+ self.perceptual_weight = perceptual_weight
71
+ self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)
72
+
73
+ if add_discriminator:
74
+ disc_cls = NLayerDiscriminator3D if using_3d_discriminator else NLayerDiscriminator
75
+ self.discriminator = disc_cls(
76
+ input_nc=disc_in_channels, n_layers=disc_num_layers,
77
+ ).apply(weights_init)
78
+ else:
79
+ self.discriminator = None
80
+
81
+ self.discriminator_iter_start = disc_start
82
+ self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
83
+ self.disc_factor = disc_factor
84
+ self.discriminator_weight = disc_weight
85
+ self.using_3d_discriminator = using_3d_discriminator
86
+
87
+ def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
88
+ if last_layer is not None:
89
+ nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
90
+ g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
91
+ else:
92
+ nll_grads = torch.autograd.grad(
93
+ nll_loss, self.last_layer[0], retain_graph=True
94
+ )[0]
95
+ g_grads = torch.autograd.grad(
96
+ g_loss, self.last_layer[0], retain_graph=True
97
+ )[0]
98
+
99
+ d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
100
+ d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
101
+ d_weight = d_weight * self.discriminator_weight
102
+ return d_weight
103
+
104
+ def forward(
105
+ self,
106
+ inputs,
107
+ reconstructions,
108
+ posteriors,
109
+ optimizer_idx,
110
+ global_step,
111
+ split="train",
112
+ last_layer=None,
113
+ ):
114
+ t = reconstructions.shape[2]
115
+ inputs = rearrange(inputs, "b c t h w -> (b t) c h w").contiguous()
116
+ reconstructions = rearrange(reconstructions, "b c t h w -> (b t) c h w").contiguous()
117
+
118
+ if optimizer_idx == 0:
119
+ # rec_loss = torch.mean(torch.abs(inputs - reconstructions), dim=(1,2,3), keepdim=True)
120
+ rec_loss = torch.mean(F.mse_loss(inputs, reconstructions, reduction='none'), dim=(1,2,3), keepdim=True)
121
+
122
+ if self.perceptual_weight > 0:
123
+ p_loss = self.perceptual_loss(inputs, reconstructions)
124
+ nll_loss = self.pixel_weight * rec_loss + self.perceptual_weight * p_loss
+ else:
+ p_loss = torch.zeros_like(rec_loss)
+ nll_loss = self.pixel_weight * rec_loss
125
+
126
+ nll_loss = nll_loss / torch.exp(self.logvar) + self.logvar
127
+ weighted_nll_loss = nll_loss
128
+ weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
129
+ nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
130
+
131
+ kl_loss = posteriors.kl()
132
+ kl_loss = torch.mean(kl_loss)
133
+
134
+ disc_factor = adopt_weight(
135
+ self.disc_factor, global_step, threshold=self.discriminator_iter_start
136
+ )
137
+
138
+ if disc_factor > 0.0:
139
+ if self.using_3d_discriminator:
140
+ reconstructions = rearrange(reconstructions, '(b t) c h w -> b c t h w', t=t)
141
+
142
+ logits_fake = self.discriminator(reconstructions.contiguous())
143
+ g_loss = -torch.mean(logits_fake)
144
+ try:
145
+ d_weight = self.calculate_adaptive_weight(
146
+ nll_loss, g_loss, last_layer=last_layer
147
+ )
148
+ except RuntimeError:
149
+ assert not self.training
150
+ d_weight = torch.tensor(0.0)
151
+ else:
152
+ d_weight = torch.tensor(0.0)
153
+ g_loss = torch.tensor(0.0)
154
+
155
+
156
+ loss = (
157
+ weighted_nll_loss
158
+ + self.kl_weight * kl_loss
159
+ + d_weight * disc_factor * g_loss
160
+ )
161
+ log = {
162
+ "{}/total_loss".format(split): loss.clone().detach().mean(),
163
+ "{}/logvar".format(split): self.logvar.detach(),
164
+ "{}/kl_loss".format(split): kl_loss.detach().mean(),
165
+ "{}/nll_loss".format(split): nll_loss.detach().mean(),
166
+ "{}/rec_loss".format(split): rec_loss.detach().mean(),
167
+ "{}/perception_loss".format(split): p_loss.detach().mean(),
168
+ "{}/d_weight".format(split): d_weight.detach(),
169
+ "{}/disc_factor".format(split): torch.tensor(disc_factor),
170
+ "{}/g_loss".format(split): g_loss.detach().mean(),
171
+ }
172
+ return loss, log
173
+
174
+ if optimizer_idx == 1:
175
+ if self.using_3d_discriminator:
176
+ inputs = rearrange(inputs, '(b t) c h w -> b c t h w', t=t)
177
+ reconstructions = rearrange(reconstructions, '(b t) c h w -> b c t h w', t=t)
178
+
179
+ logits_real = self.discriminator(inputs.contiguous().detach())
180
+ logits_fake = self.discriminator(reconstructions.contiguous().detach())
181
+
182
+ disc_factor = adopt_weight(
183
+ self.disc_factor, global_step, threshold=self.discriminator_iter_start
184
+ )
185
+ d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
186
+
187
+ log = {
188
+ "{}/disc_loss".format(split): d_loss.clone().detach().mean(),
189
+ "{}/logits_real".format(split): logits_real.detach().mean(),
190
+ "{}/logits_fake".format(split): logits_fake.detach().mean(),
191
+ }
192
+ return d_loss, log
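The adversarial schedule above reduces to two small helpers: `adopt_weight` keeps the GAN term at zero until `disc_start` steps have passed, and `hinge_d_loss` / `vanilla_d_loss` score the discriminator. A minimal sketch, under the same package-import assumption as before:

import torch
from video_vae.modeling_loss import adopt_weight, hinge_d_loss, vanilla_d_loss

logits_real = torch.randn(4, 1, 6, 6) + 1.0   # discriminator logits on real frames
logits_fake = torch.randn(4, 1, 6, 6) - 1.0   # discriminator logits on reconstructions

d_hinge = hinge_d_loss(logits_real, logits_fake)      # 0.5 * (mean(relu(1 - real)) + mean(relu(1 + fake)))
d_vanilla = vanilla_d_loss(logits_real, logits_fake)  # softplus-based alternative

# the GAN terms stay disabled until global_step reaches the threshold step
print(adopt_weight(1.0, global_step=100, threshold=50001))    # 0.0
print(adopt_weight(1.0, global_step=60000, threshold=50001))  # 1.0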
video_vae/modeling_lpips.py ADDED
@@ -0,0 +1,120 @@
1
+ """Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models"""
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from torchvision import models
6
+ from collections import namedtuple
7
+
8
+
9
+ class LPIPS(nn.Module):
10
+ # Learned perceptual metric
11
+ def __init__(self, use_dropout=True):
12
+ super().__init__()
13
+ self.scaling_layer = ScalingLayer()
14
+ self.chns = [64, 128, 256, 512, 512] # vgg16 features
15
+ self.net = vgg16(pretrained=False, requires_grad=False)
16
+ self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
17
+ self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
18
+ self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
19
+ self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
20
+ self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
21
+ self.load_from_pretrained()
22
+ for param in self.parameters():
23
+ param.requires_grad = False
24
+
25
+ def load_from_pretrained(self):
26
+ ckpt = "/home/jinyang/models/vae/video_vae_baseline/vgg_lpips.pth" # replace with your lpips
27
+ self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=True)
28
+ print("loaded pretrained LPIPS loss from {}".format(ckpt))
29
+
30
+ def forward(self, input, target):
31
+ in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target))
32
+ outs0, outs1 = self.net(in0_input), self.net(in1_input)
33
+ feats0, feats1, diffs = {}, {}, {}
34
+ lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
35
+ for kk in range(len(self.chns)):
36
+ feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk])
37
+ diffs[kk] = (feats0[kk] - feats1[kk]) ** 2
38
+
39
+ res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))]
40
+ val = res[0]
41
+ for l in range(1, len(self.chns)):
42
+ val += res[l]
43
+ return val
44
+
45
+
46
+ class ScalingLayer(nn.Module):
47
+ def __init__(self):
48
+ super(ScalingLayer, self).__init__()
49
+ self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
50
+ self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None])
51
+
52
+ def forward(self, inp):
53
+ return (inp - self.shift) / self.scale
54
+
55
+
56
+ class NetLinLayer(nn.Module):
57
+ """ A single linear layer which does a 1x1 conv """
58
+ def __init__(self, chn_in, chn_out=1, use_dropout=False):
59
+ super(NetLinLayer, self).__init__()
60
+ layers = [nn.Dropout(), ] if (use_dropout) else []
61
+ layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ]
62
+ self.model = nn.Sequential(*layers)
63
+
64
+
65
+ class vgg16(torch.nn.Module):
66
+ def __init__(self, requires_grad=False, pretrained=True):
67
+ super(vgg16, self).__init__()
68
+ vgg_pretrained_features = models.vgg16(pretrained=pretrained).features
69
+ self.slice1 = torch.nn.Sequential()
70
+ self.slice2 = torch.nn.Sequential()
71
+ self.slice3 = torch.nn.Sequential()
72
+ self.slice4 = torch.nn.Sequential()
73
+ self.slice5 = torch.nn.Sequential()
74
+ self.N_slices = 5
75
+ for x in range(4):
76
+ self.slice1.add_module(str(x), vgg_pretrained_features[x])
77
+ for x in range(4, 9):
78
+ self.slice2.add_module(str(x), vgg_pretrained_features[x])
79
+ for x in range(9, 16):
80
+ self.slice3.add_module(str(x), vgg_pretrained_features[x])
81
+ for x in range(16, 23):
82
+ self.slice4.add_module(str(x), vgg_pretrained_features[x])
83
+ for x in range(23, 30):
84
+ self.slice5.add_module(str(x), vgg_pretrained_features[x])
85
+ if not requires_grad:
86
+ for param in self.parameters():
87
+ param.requires_grad = False
88
+
89
+ def forward(self, X):
90
+ h = self.slice1(X)
91
+ h_relu1_2 = h
92
+ h = self.slice2(h)
93
+ h_relu2_2 = h
94
+ h = self.slice3(h)
95
+ h_relu3_3 = h
96
+ h = self.slice4(h)
97
+ h_relu4_3 = h
98
+ h = self.slice5(h)
99
+ h_relu5_3 = h
100
+ vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
101
+ out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
102
+ return out
103
+
104
+
105
+ def normalize_tensor(x,eps=1e-10):
106
+ norm_factor = torch.sqrt(torch.sum(x**2,dim=1,keepdim=True))
107
+ return x/(norm_factor+eps)
108
+
109
+
110
+ def spatial_average(x, keepdim=True):
111
+ return x.mean([2,3],keepdim=keepdim)
112
+
113
+
114
+ if __name__ == "__main__":
115
+ model = LPIPS().eval()
116
+ _ = torch.manual_seed(123)
117
+ img1 = (torch.rand(10, 3, 100, 100) * 2) - 1
118
+ img2 = (torch.rand(10, 3, 100, 100) * 2) - 1
119
+ print(model(img1, img2).shape)
120
+ # embed()
video_vae/modeling_resnet.py ADDED
@@ -0,0 +1,729 @@
1
+ from functools import partial
2
+ from typing import Optional, Tuple, Union
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from einops import rearrange
8
+ from diffusers.models.activations import get_activation
9
+ from diffusers.models.attention_processor import SpatialNorm
10
+ from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear
11
+ from diffusers.models.normalization import AdaGroupNorm
12
+ from timm.models.layers import drop_path, to_2tuple, trunc_normal_
13
+ from .modeling_causal_conv import CausalConv3d, CausalGroupNorm
14
+
15
+
16
+ class CausalResnetBlock3D(nn.Module):
17
+ r"""
18
+ A Resnet block.
19
+
20
+ Parameters:
21
+ in_channels (`int`): The number of channels in the input.
22
+ out_channels (`int`, *optional*, default to be `None`):
23
+ The number of output channels for the first conv2d layer. If None, same as `in_channels`.
24
+ dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
25
+ temb_channels (`int`, *optional*, default to `512`): the number of channels in timestep embedding.
26
+ groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer.
27
+ groups_out (`int`, *optional*, default to None):
28
+ The number of groups to use for the second normalization layer. if set to None, same as `groups`.
29
+ eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
30
+ non_linearity (`str`, *optional*, default to `"swish"`): the activation function to use.
31
+ time_embedding_norm (`str`, *optional*, default to `"default"` ): Time scale shift config.
32
+ By default, apply timestep embedding conditioning with a simple shift mechanism. Choose "scale_shift" or
33
+ "ada_group" for a stronger conditioning with scale and shift.
34
+ kernel (`torch.FloatTensor`, optional, default to None): FIR filter, see
35
+ [`~models.resnet.FirUpsample2D`] and [`~models.resnet.FirDownsample2D`].
36
+ output_scale_factor (`float`, *optional*, default to be `1.0`): the scale factor to use for the output.
37
+ use_in_shortcut (`bool`, *optional*, default to `True`):
38
+ If `True`, add a 1x1 nn.conv2d layer for skip-connection.
39
+ up (`bool`, *optional*, default to `False`): If `True`, add an upsample layer.
40
+ down (`bool`, *optional*, default to `False`): If `True`, add a downsample layer.
41
+ conv_shortcut_bias (`bool`, *optional*, default to `True`): If `True`, adds a learnable bias to the
42
+ `conv_shortcut` output.
43
+ conv_2d_out_channels (`int`, *optional*, default to `None`): the number of channels in the output.
44
+ If None, same as `out_channels`.
45
+ """
46
+
47
+ def __init__(
48
+ self,
49
+ *,
50
+ in_channels: int,
51
+ out_channels: Optional[int] = None,
52
+ conv_shortcut: bool = False,
53
+ dropout: float = 0.0,
54
+ temb_channels: int = 512,
55
+ groups: int = 32,
56
+ groups_out: Optional[int] = None,
57
+ pre_norm: bool = True,
58
+ eps: float = 1e-6,
59
+ non_linearity: str = "swish",
60
+ time_embedding_norm: str = "default", # default, scale_shift, ada_group, spatial
61
+ output_scale_factor: float = 1.0,
62
+ use_in_shortcut: Optional[bool] = None,
63
+ conv_shortcut_bias: bool = True,
64
+ conv_2d_out_channels: Optional[int] = None,
65
+ ):
66
+ super().__init__()
67
+ self.pre_norm = pre_norm
68
+ self.pre_norm = True
69
+ self.in_channels = in_channels
70
+ out_channels = in_channels if out_channels is None else out_channels
71
+ self.out_channels = out_channels
72
+ self.use_conv_shortcut = conv_shortcut
73
+ self.output_scale_factor = output_scale_factor
74
+ self.time_embedding_norm = time_embedding_norm
75
+
76
+ linear_cls = nn.Linear
77
+
78
+ if groups_out is None:
79
+ groups_out = groups
80
+
81
+ if self.time_embedding_norm == "ada_group":
82
+ self.norm1 = AdaGroupNorm(temb_channels, in_channels, groups, eps=eps)
83
+ elif self.time_embedding_norm == "spatial":
84
+ self.norm1 = SpatialNorm(in_channels, temb_channels)
85
+ else:
86
+ self.norm1 = CausalGroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
87
+
88
+ self.conv1 = CausalConv3d(in_channels, out_channels, kernel_size=3, stride=1)
89
+
90
+ if self.time_embedding_norm == "ada_group":
91
+ self.norm2 = AdaGroupNorm(temb_channels, out_channels, groups_out, eps=eps)
92
+ elif self.time_embedding_norm == "spatial":
93
+ self.norm2 = SpatialNorm(out_channels, temb_channels)
94
+ else:
95
+ self.norm2 = CausalGroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
96
+
97
+ self.dropout = torch.nn.Dropout(dropout)
98
+ conv_2d_out_channels = conv_2d_out_channels or out_channels
99
+ self.conv2 = CausalConv3d(out_channels, conv_2d_out_channels, kernel_size=3, stride=1)
100
+
101
+ self.nonlinearity = get_activation(non_linearity)
102
+ self.upsample = self.downsample = None
103
+ self.use_in_shortcut = self.in_channels != conv_2d_out_channels if use_in_shortcut is None else use_in_shortcut
104
+
105
+ self.conv_shortcut = None
106
+ if self.use_in_shortcut:
107
+ self.conv_shortcut = CausalConv3d(
108
+ in_channels,
109
+ conv_2d_out_channels,
110
+ kernel_size=1,
111
+ stride=1,
112
+ bias=conv_shortcut_bias,
113
+ )
114
+
115
+ def forward(
116
+ self,
117
+ input_tensor: torch.FloatTensor,
118
+ temb: torch.FloatTensor = None,
119
+ is_init_image=True,
120
+ temporal_chunk=False,
121
+ ) -> torch.FloatTensor:
122
+ hidden_states = input_tensor
123
+
124
+ if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
125
+ hidden_states = self.norm1(hidden_states, temb)
126
+ else:
127
+ hidden_states = self.norm1(hidden_states)
128
+
129
+ hidden_states = self.nonlinearity(hidden_states)
130
+
131
+ hidden_states = self.conv1(hidden_states, is_init_image=is_init_image, temporal_chunk=temporal_chunk)
132
+
133
+ if temb is not None and self.time_embedding_norm == "default":
134
+ hidden_states = hidden_states + temb
135
+
136
+ if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
137
+ hidden_states = self.norm2(hidden_states, temb)
138
+ else:
139
+ hidden_states = self.norm2(hidden_states)
140
+
141
+ hidden_states = self.nonlinearity(hidden_states)
142
+ hidden_states = self.dropout(hidden_states)
143
+ hidden_states = self.conv2(hidden_states, is_init_image=is_init_image, temporal_chunk=temporal_chunk)
144
+
145
+ if self.conv_shortcut is not None:
146
+ input_tensor = self.conv_shortcut(input_tensor, is_init_image=is_init_image, temporal_chunk=temporal_chunk)
147
+
148
+ output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
149
+
150
+ return output_tensor
151
+
152
+
153
+ class ResnetBlock2D(nn.Module):
154
+ r"""
155
+ A Resnet block.
156
+
157
+ Parameters:
158
+ in_channels (`int`): The number of channels in the input.
159
+ out_channels (`int`, *optional*, default to be `None`):
160
+ The number of output channels for the first conv2d layer. If None, same as `in_channels`.
161
+ dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
162
+ temb_channels (`int`, *optional*, default to `512`): the number of channels in timestep embedding.
163
+ groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer.
164
+ groups_out (`int`, *optional*, default to None):
165
+ The number of groups to use for the second normalization layer. if set to None, same as `groups`.
166
+ eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
167
+ non_linearity (`str`, *optional*, default to `"swish"`): the activation function to use.
168
+ time_embedding_norm (`str`, *optional*, default to `"default"` ): Time scale shift config.
169
+ By default, apply timestep embedding conditioning with a simple shift mechanism. Choose "scale_shift" or
170
+ "ada_group" for a stronger conditioning with scale and shift.
171
+ kernel (`torch.FloatTensor`, optional, default to None): FIR filter, see
172
+ [`~models.resnet.FirUpsample2D`] and [`~models.resnet.FirDownsample2D`].
173
+ output_scale_factor (`float`, *optional*, default to be `1.0`): the scale factor to use for the output.
174
+ use_in_shortcut (`bool`, *optional*, default to `True`):
175
+ If `True`, add a 1x1 nn.conv2d layer for skip-connection.
176
+ up (`bool`, *optional*, default to `False`): If `True`, add an upsample layer.
177
+ down (`bool`, *optional*, default to `False`): If `True`, add a downsample layer.
178
+ conv_shortcut_bias (`bool`, *optional*, default to `True`): If `True`, adds a learnable bias to the
179
+ `conv_shortcut` output.
180
+ conv_2d_out_channels (`int`, *optional*, default to `None`): the number of channels in the output.
181
+ If None, same as `out_channels`.
182
+ """
183
+
184
+ def __init__(
185
+ self,
186
+ *,
187
+ in_channels: int,
188
+ out_channels: Optional[int] = None,
189
+ conv_shortcut: bool = False,
190
+ dropout: float = 0.0,
191
+ temb_channels: int = 512,
192
+ groups: int = 32,
193
+ groups_out: Optional[int] = None,
194
+ pre_norm: bool = True,
195
+ eps: float = 1e-6,
196
+ non_linearity: str = "swish",
197
+ time_embedding_norm: str = "default", # default, scale_shift, ada_group, spatial
198
+ output_scale_factor: float = 1.0,
199
+ use_in_shortcut: Optional[bool] = None,
200
+ conv_shortcut_bias: bool = True,
201
+ conv_2d_out_channels: Optional[int] = None,
202
+ ):
203
+ super().__init__()
204
+ self.pre_norm = pre_norm
205
+ self.pre_norm = True
206
+ self.in_channels = in_channels
207
+ out_channels = in_channels if out_channels is None else out_channels
208
+ self.out_channels = out_channels
209
+ self.use_conv_shortcut = conv_shortcut
210
+ self.output_scale_factor = output_scale_factor
211
+ self.time_embedding_norm = time_embedding_norm
212
+
213
+ linear_cls = nn.Linear
214
+ conv_cls = nn.Conv3d
215
+
216
+ if groups_out is None:
217
+ groups_out = groups
218
+
219
+ if self.time_embedding_norm == "ada_group":
220
+ self.norm1 = AdaGroupNorm(temb_channels, in_channels, groups, eps=eps)
221
+ elif self.time_embedding_norm == "spatial":
222
+ self.norm1 = SpatialNorm(in_channels, temb_channels)
223
+ else:
224
+ self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
225
+
226
+ self.conv1 = conv_cls(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
227
+
228
+ if self.time_embedding_norm == "ada_group":
229
+ self.norm2 = AdaGroupNorm(temb_channels, out_channels, groups_out, eps=eps)
230
+ elif self.time_embedding_norm == "spatial":
231
+ self.norm2 = SpatialNorm(out_channels, temb_channels)
232
+ else:
233
+ self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
234
+
235
+ self.dropout = torch.nn.Dropout(dropout)
236
+ conv_2d_out_channels = conv_2d_out_channels or out_channels
237
+ self.conv2 = conv_cls(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1)
238
+
239
+ self.nonlinearity = get_activation(non_linearity)
240
+ self.upsample = self.downsample = None
241
+ self.use_in_shortcut = self.in_channels != conv_2d_out_channels if use_in_shortcut is None else use_in_shortcut
242
+
243
+ self.conv_shortcut = None
244
+ if self.use_in_shortcut:
245
+ self.conv_shortcut = conv_cls(
246
+ in_channels,
247
+ conv_2d_out_channels,
248
+ kernel_size=1,
249
+ stride=1,
250
+ padding=0,
251
+ bias=conv_shortcut_bias,
252
+ )
253
+
254
+ def forward(
255
+ self,
256
+ input_tensor: torch.FloatTensor,
257
+ temb: torch.FloatTensor = None,
258
+ scale: float = 1.0,
259
+ ) -> torch.FloatTensor:
260
+ hidden_states = input_tensor
261
+
262
+ if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
263
+ hidden_states = self.norm1(hidden_states, temb)
264
+ else:
265
+ hidden_states = self.norm1(hidden_states)
266
+
267
+ hidden_states = self.nonlinearity(hidden_states)
268
+
269
+ hidden_states = self.conv1(hidden_states)
270
+
271
+ if temb is not None and self.time_embedding_norm == "default":
272
+ hidden_states = hidden_states + temb
273
+
274
+ if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
275
+ hidden_states = self.norm2(hidden_states, temb)
276
+ else:
277
+ hidden_states = self.norm2(hidden_states)
278
+
279
+ hidden_states = self.nonlinearity(hidden_states)
280
+ hidden_states = self.dropout(hidden_states)
281
+ hidden_states = self.conv2(hidden_states)
282
+
283
+ if self.conv_shortcut is not None:
284
+ input_tensor = self.conv_shortcut(input_tensor)
285
+
286
+ output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
287
+
288
+ return output_tensor
289
+
290
+
291
+ class CausalDownsample2x(nn.Module):
292
+ """A 2D downsampling layer with an optional convolution.
293
+
294
+ Parameters:
295
+ channels (`int`):
296
+ number of channels in the inputs and outputs.
297
+ use_conv (`bool`, default `False`):
298
+ option to use a convolution.
299
+ out_channels (`int`, optional):
300
+ number of output channels. Defaults to `channels`.
301
+ padding (`int`, default `1`):
302
+ padding for the convolution.
303
+ name (`str`, default `conv`):
304
+ name of the downsampling 2D layer.
305
+ """
306
+
307
+ def __init__(
308
+ self,
309
+ channels: int,
310
+ use_conv: bool = True,
311
+ out_channels: Optional[int] = None,
312
+ name: str = "conv",
313
+ kernel_size=3,
314
+ bias=True,
315
+ ):
316
+ super().__init__()
317
+ self.channels = channels
318
+ self.out_channels = out_channels or channels
319
+ self.use_conv = use_conv
320
+ stride = (1, 2, 2)
321
+ self.name = name
322
+
323
+ if use_conv:
324
+ conv = CausalConv3d(
325
+ self.channels, self.out_channels, kernel_size=kernel_size, stride=stride, bias=bias
326
+ )
327
+ else:
328
+ assert self.channels == self.out_channels
329
+ conv = nn.AvgPool3d(kernel_size=stride, stride=stride)
330
+
331
+ self.conv = conv
332
+
333
+ def forward(self, hidden_states: torch.FloatTensor, is_init_image=True, temporal_chunk=False) -> torch.FloatTensor:
334
+ assert hidden_states.shape[1] == self.channels
335
+ hidden_states = self.conv(hidden_states, is_init_image=is_init_image, temporal_chunk=temporal_chunk)
336
+ return hidden_states
337
+
338
+
339
+ class Downsample2D(nn.Module):
+     """A spatial 2x downsampling layer with an optional convolution.
+
+     Parameters:
+         channels (`int`):
+             number of channels in the inputs and outputs.
+         use_conv (`bool`, default `True`):
+             option to use a strided convolution instead of average pooling.
+         out_channels (`int`, optional):
+             number of output channels. Defaults to `channels`.
+         padding (`int`, default `0`):
+             padding for the convolution.
+         name (`str`, default `"conv"`):
+             name of the downsampling layer.
+     """
+
+     def __init__(
+         self,
+         channels: int,
+         use_conv: bool = True,
+         out_channels: Optional[int] = None,
+         padding: int = 0,
+         name: str = "conv",
+         kernel_size=3,
+         bias=True,
+     ):
+         super().__init__()
+         self.channels = channels
+         self.out_channels = out_channels or channels
+         self.use_conv = use_conv
+         self.padding = padding
+         stride = (1, 2, 2)
+         self.name = name
+         conv_cls = nn.Conv3d
+
+         if use_conv:
+             conv = conv_cls(
+                 self.channels, self.out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias
+             )
+         else:
+             assert self.channels == self.out_channels
+             conv = nn.AvgPool3d(kernel_size=stride, stride=stride)
+
+         self.conv = conv
+
+     def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
+         assert hidden_states.shape[1] == self.channels
+
+         if self.use_conv and self.padding == 0:
+             pad = (0, 1, 0, 1, 1, 1)
+             hidden_states = F.pad(hidden_states, pad, mode="constant", value=0)
+
+         hidden_states = self.conv(hidden_states)
+
+         return hidden_states
+
+
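+ # TemporalDownsample2x halves the temporal length with a stride-(2, 1, 1) Conv3d;
+ # single-frame inputs (T == 1) are padded on both temporal sides so they remain a
+ # single frame. It must be constructed with use_conv=True, since the non-conv
+ # branch is not implemented. Sketch:
+ #
+ #     down_t = TemporalDownsample2x(channels=128, use_conv=True)
+ #     y = down_t(torch.randn(1, 128, 8, 32, 32))  # expected (1, 128, 4, 32, 32)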
+ class TemporalDownsample2x(nn.Module):
+     """A temporal 2x downsampling layer with an optional convolution.
+
+     Parameters:
+         channels (`int`):
+             number of channels in the inputs and outputs.
+         use_conv (`bool`, default `False`):
+             option to use a strided convolution.
+         out_channels (`int`, optional):
+             number of output channels. Defaults to `channels`.
+         padding (`int`, default `0`):
+             padding for the convolution.
+         kernel_size (`int`, default `3`):
+             kernel size of the convolution.
+     """
+
+     def __init__(
+         self,
+         channels: int,
+         use_conv: bool = False,
+         out_channels: Optional[int] = None,
+         padding: int = 0,
+         kernel_size=3,
+         bias=True,
+     ):
+         super().__init__()
+         self.channels = channels
+         self.out_channels = out_channels or channels
+         self.use_conv = use_conv
+         self.padding = padding
+         stride = (2, 1, 1)
+
+         conv_cls = nn.Conv3d
+
+         if use_conv:
+             conv = conv_cls(
+                 self.channels, self.out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias
+             )
+         else:
+             raise NotImplementedError("Temporal downsampling without a convolution is not implemented.")
+
+         self.conv = conv
+
+     def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
+         assert hidden_states.shape[1] == self.channels
+
+         if self.use_conv and self.padding == 0:
+             if hidden_states.shape[2] == 1:
+                 # image
+                 pad = (1, 1, 1, 1, 1, 1)
+             else:
+                 # video
+                 pad = (1, 1, 1, 1, 0, 1)
+
+             hidden_states = F.pad(hidden_states, pad, mode="constant", value=0)
+
+         hidden_states = self.conv(hidden_states)
+         return hidden_states
+
+
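+ # CausalTemporalDownsample2x is the causal counterpart: the stride-(2, 1, 1)
+ # convolution is a CausalConv3d, so the temporal padding is presumably one-sided
+ # toward past frames (CausalConv3d is defined elsewhere and not shown here).
+ # Like TemporalDownsample2x, it must be constructed with use_conv=True.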
+ class CausalTemporalDownsample2x(nn.Module):
+     """A causal temporal 2x downsampling layer with an optional convolution.
+
+     Parameters:
+         channels (`int`):
+             number of channels in the inputs and outputs.
+         use_conv (`bool`, default `False`):
+             option to use a strided causal convolution.
+         out_channels (`int`, optional):
+             number of output channels. Defaults to `channels`.
+         kernel_size (`int`, default `3`):
+             kernel size of the convolution.
+         bias (`bool`, default `True`):
+             whether the convolution uses a bias term.
+     """
+
+     def __init__(
+         self,
+         channels: int,
+         use_conv: bool = False,
+         out_channels: Optional[int] = None,
+         kernel_size=3,
+         bias=True,
+     ):
+         super().__init__()
+         self.channels = channels
+         self.out_channels = out_channels or channels
+         self.use_conv = use_conv
+         stride = (2, 1, 1)
+
+         if use_conv:
+             conv = CausalConv3d(
+                 self.channels, self.out_channels, kernel_size=kernel_size, stride=stride, bias=bias
+             )
+         else:
+             raise NotImplementedError("Temporal downsampling without a convolution is not implemented.")
+
+         self.conv = conv
+
+     def forward(self, hidden_states: torch.FloatTensor, is_init_image=True, temporal_chunk=False) -> torch.FloatTensor:
+         assert hidden_states.shape[1] == self.channels
+         hidden_states = self.conv(hidden_states, is_init_image=is_init_image, temporal_chunk=temporal_chunk)
+         return hidden_states
+
+
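+ # Upsample2D doubles H and W by depth-to-space: the Conv3d expands the channels
+ # 4x and the einops rearrange folds that factor back into a 2x2 spatial block
+ # per position (a pixel shuffle). Sketch:
+ #
+ #     up = Upsample2D(channels=128)
+ #     y = up(torch.randn(1, 128, 9, 32, 32))  # expected (1, 128, 9, 64, 64)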
+ class Upsample2D(nn.Module):
+     """A spatial 2x upsampling layer (depth-to-space) with a convolution.
+
+     Parameters:
+         channels (`int`):
+             number of channels in the inputs and outputs.
+         use_conv (`bool`, default `False`):
+             option to use a convolution.
+         out_channels (`int`, optional):
+             number of output channels. Defaults to `channels`.
+         name (`str`, default `"conv"`):
+             name of the upsampling layer.
+         kernel_size (`int`, optional, default `3`):
+             kernel size of the convolution.
+         interpolate (`bool`, default `False`):
+             interpolation-based upsampling is not implemented; only the
+             depth-to-space path is supported.
+     """
+
+     def __init__(
+         self,
+         channels: int,
+         use_conv: bool = False,
+         out_channels: Optional[int] = None,
+         name: str = "conv",
+         kernel_size: Optional[int] = None,
+         padding=1,
+         bias=True,
+         interpolate=False,
+     ):
+         super().__init__()
+         self.channels = channels
+         self.out_channels = out_channels or channels
+         self.use_conv = use_conv
+         self.name = name
+         self.interpolate = interpolate
+         conv_cls = nn.Conv3d
+         conv = None
+
+         if interpolate:
+             raise NotImplementedError("Interpolation-based spatial upsampling is not implemented.")
+         else:
+             if kernel_size is None:
+                 kernel_size = 3
+             conv = conv_cls(self.channels, self.out_channels * 4, kernel_size=kernel_size, padding=padding, bias=bias)
+
+         self.conv = conv
+         self.conv.apply(self._init_weights)
+
+     def _init_weights(self, m):
+         if isinstance(m, (nn.Linear, nn.Conv2d, nn.Conv3d)):
+             trunc_normal_(m.weight, std=0.02)
+             if m.bias is not None:
+                 nn.init.constant_(m.bias, 0)
+         elif isinstance(m, nn.LayerNorm):
+             nn.init.constant_(m.bias, 0)
+             nn.init.constant_(m.weight, 1.0)
+
+     def forward(
+         self,
+         hidden_states: torch.FloatTensor,
+     ) -> torch.FloatTensor:
+         assert hidden_states.shape[1] == self.channels
+
+         hidden_states = self.conv(hidden_states)
+         hidden_states = rearrange(hidden_states, 'b (c p1 p2) t h w -> b c t (h p1) (w p2)', p1=2, p2=2)
+
+         return hidden_states
+
+
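+ # CausalUpsample2x performs the same 2x spatial pixel shuffle, but through a
+ # CausalConv3d, with is_init_image / temporal_chunk forwarded to the convolution
+ # so chunked temporal decoding stays causal.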
+ class CausalUpsample2x(nn.Module):
+     """A causal spatial 2x upsampling layer (depth-to-space) with a convolution.
+
+     Parameters:
+         channels (`int`):
+             number of channels in the inputs and outputs.
+         use_conv (`bool`, default `False`):
+             option to use a convolution.
+         out_channels (`int`, optional):
+             number of output channels. Defaults to `channels`.
+         name (`str`, default `"conv"`):
+             name of the upsampling layer.
+         kernel_size (`int`, optional, default `3`):
+             kernel size of the convolution.
+         interpolate (`bool`, default `False`):
+             interpolation-based upsampling is not implemented.
+     """
+
+     def __init__(
+         self,
+         channels: int,
+         use_conv: bool = False,
+         out_channels: Optional[int] = None,
+         name: str = "conv",
+         kernel_size: Optional[int] = 3,
+         bias=True,
+         interpolate=False,
+     ):
+         super().__init__()
+         self.channels = channels
+         self.out_channels = out_channels or channels
+         self.use_conv = use_conv
+         self.name = name
+         self.interpolate = interpolate
+         conv = None
+
+         if interpolate:
+             raise NotImplementedError("Interpolation-based spatial upsampling is not implemented.")
+         else:
+             conv = CausalConv3d(self.channels, self.out_channels * 4, kernel_size=kernel_size, stride=1, bias=bias)
+
+         self.conv = conv
+
+     def forward(
+         self,
+         hidden_states: torch.FloatTensor,
+         is_init_image=True, temporal_chunk=False,
+     ) -> torch.FloatTensor:
+         assert hidden_states.shape[1] == self.channels
+         hidden_states = self.conv(hidden_states, is_init_image=is_init_image, temporal_chunk=temporal_chunk)
+         hidden_states = rearrange(hidden_states, 'b (c p1 p2) t h w -> b c t (h p1) (w p2)', p1=2, p2=2)
+         return hidden_states
+
+
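+ # TemporalUpsample2x doubles the temporal length: the convolution expands the
+ # channels 2x and the rearrange folds that factor into the time axis; for a
+ # single-frame image input the extra leading frame is dropped again. Sketch:
+ #
+ #     up_t = TemporalUpsample2x(channels=128)
+ #     y = up_t(torch.randn(1, 128, 4, 32, 32))  # expected (1, 128, 8, 32, 32)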
+ class TemporalUpsample2x(nn.Module):
+     """A temporal 2x upsampling layer (depth-to-space along time) with a convolution.
+
+     Parameters:
+         channels (`int`):
+             number of channels in the inputs and outputs.
+         use_conv (`bool`, default `True`):
+             option to use a convolution.
+         out_channels (`int`, optional):
+             number of output channels. Defaults to `channels`.
+         kernel_size (`int`, optional, default `3`):
+             kernel size of the convolution.
+         interpolate (`bool`, default `False`):
+             interpolation-based upsampling is not implemented.
+     """
+
+     def __init__(
+         self,
+         channels: int,
+         use_conv: bool = True,
+         out_channels: Optional[int] = None,
+         kernel_size: Optional[int] = None,
+         padding=1,
+         bias=True,
+         interpolate=False,
+     ):
+         super().__init__()
+         self.channels = channels
+         self.out_channels = out_channels or channels
+         self.use_conv = use_conv
+         self.interpolate = interpolate
+         conv_cls = nn.Conv3d
+
+         conv = None
+         if interpolate:
+             raise NotImplementedError("Interpolation-based temporal upsampling is not implemented.")
+         else:
+             # depth-to-space operator along the temporal axis
+             if kernel_size is None:
+                 kernel_size = 3
+             conv = conv_cls(self.channels, self.out_channels * 2, kernel_size=kernel_size, padding=padding, bias=bias)
+
+         self.conv = conv
+
+     def forward(
+         self,
+         hidden_states: torch.FloatTensor,
+         is_image: bool = False,
+     ) -> torch.FloatTensor:
+         assert hidden_states.shape[1] == self.channels
+         t = hidden_states.shape[2]
+         hidden_states = self.conv(hidden_states)
+         hidden_states = rearrange(hidden_states, 'b (c p) t h w -> b c (p t) h w', p=2)
+
+         if t == 1 and is_image:
+             hidden_states = hidden_states[:, :, 1:]
+
+         return hidden_states
+
+
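+ # CausalTemporalUpsample2x is the causal temporal upsampler: each input frame is
+ # expanded into two consecutive output frames ('(t p)'), and when the chunk starts
+ # a sequence (is_init_image=True) the first output frame is dropped, mapping T
+ # input frames to 2T - 1 output frames.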
+ class CausalTemporalUpsample2x(nn.Module):
+     """A causal temporal 2x upsampling layer (depth-to-space along time) with a convolution.
+
+     Parameters:
+         channels (`int`):
+             number of channels in the inputs and outputs.
+         use_conv (`bool`, default `True`):
+             option to use a convolution.
+         out_channels (`int`, optional):
+             number of output channels. Defaults to `channels`.
+         kernel_size (`int`, optional, default `3`):
+             kernel size of the convolution.
+         interpolate (`bool`, default `False`):
+             interpolation-based upsampling is not implemented.
+     """
+
+     def __init__(
+         self,
+         channels: int,
+         use_conv: bool = True,
+         out_channels: Optional[int] = None,
+         kernel_size: Optional[int] = 3,
+         bias=True,
+         interpolate=False,
+     ):
+         super().__init__()
+         self.channels = channels
+         self.out_channels = out_channels or channels
+         self.use_conv = use_conv
+         self.interpolate = interpolate
+
+         conv = None
+         if interpolate:
+             raise NotImplementedError("Interpolation-based temporal upsampling is not implemented.")
+         else:
+             # depth-to-space operator along the temporal axis
+             conv = CausalConv3d(self.channels, self.out_channels * 2, kernel_size=kernel_size, stride=1, bias=bias)
+
+         self.conv = conv
+
+     def forward(
+         self,
+         hidden_states: torch.FloatTensor,
+         is_init_image=True, temporal_chunk=False,
+     ) -> torch.FloatTensor:
+         assert hidden_states.shape[1] == self.channels
+         t = hidden_states.shape[2]
+         hidden_states = self.conv(hidden_states, is_init_image=is_init_image, temporal_chunk=temporal_chunk)
+         hidden_states = rearrange(hidden_states, 'b (c p) t h w -> b c (t p) h w', p=2)
+
+         if is_init_image:
+             hidden_states = hidden_states[:, :, 1:]
+
+         return hidden_states