MaxMilan1 committed
Commit 09339b5 · 1 Parent(s): 63f29cf

possible working changes for V3D?

This view is limited to 50 files because it contains too many changes. See raw diff

Files changed (50)
  1. app.py +60 -1
  2. requirements.txt +44 -7
  3. scripts/__init__.py +0 -0
  4. scripts/pub/V3D_512.py +317 -0
  5. scripts/pub/configs/V3D_512.yaml +161 -0
  6. scripts/tests/attention.py +319 -0
  7. scripts/util/__init__.py +0 -0
  8. scripts/util/detection/__init__.py +0 -0
  9. scripts/util/detection/nsfw_and_watermark_dectection.py +110 -0
  10. scripts/util/detection/p_head_v1.npz +3 -0
  11. scripts/util/detection/w_head_v1.npz +3 -0
  12. sgm/__init__.py +4 -0
  13. sgm/data/__init__.py +1 -0
  14. sgm/data/cam_utils.py +1253 -0
  15. sgm/data/cifar10.py +67 -0
  16. sgm/data/co3d.py +1367 -0
  17. sgm/data/colmap.py +605 -0
  18. sgm/data/dataset.py +80 -0
  19. sgm/data/joint3d.py +10 -0
  20. sgm/data/json_index_dataset.py +1080 -0
  21. sgm/data/latent_objaverse.py +52 -0
  22. sgm/data/mnist.py +85 -0
  23. sgm/data/mvimagenet.py +408 -0
  24. sgm/data/objaverse.py +882 -0
  25. sgm/inference/api.py +385 -0
  26. sgm/inference/helpers.py +305 -0
  27. sgm/lr_scheduler.py +135 -0
  28. sgm/models/__init__.py +2 -0
  29. sgm/models/autoencoder.py +615 -0
  30. sgm/models/diffusion.py +358 -0
  31. sgm/models/video3d_diffusion.py +524 -0
  32. sgm/models/video_diffusion.py +503 -0
  33. sgm/modules/__init__.py +6 -0
  34. sgm/modules/attention.py +764 -0
  35. sgm/modules/autoencoding/__init__.py +0 -0
  36. sgm/modules/autoencoding/losses/__init__.py +7 -0
  37. sgm/modules/autoencoding/losses/discriminator_loss.py +306 -0
  38. sgm/modules/autoencoding/losses/lpips.py +73 -0
  39. sgm/modules/autoencoding/lpips/__init__.py +0 -0
  40. sgm/modules/autoencoding/lpips/loss/.gitignore +1 -0
  41. sgm/modules/autoencoding/lpips/loss/LICENSE +23 -0
  42. sgm/modules/autoencoding/lpips/loss/__init__.py +0 -0
  43. sgm/modules/autoencoding/lpips/loss/lpips.py +147 -0
  44. sgm/modules/autoencoding/lpips/model/LICENSE +58 -0
  45. sgm/modules/autoencoding/lpips/model/__init__.py +0 -0
  46. sgm/modules/autoencoding/lpips/model/model.py +88 -0
  47. sgm/modules/autoencoding/lpips/util.py +128 -0
  48. sgm/modules/autoencoding/lpips/vqperceptual.py +17 -0
  49. sgm/modules/autoencoding/regularizers/__init__.py +31 -0
  50. sgm/modules/autoencoding/regularizers/base.py +40 -0
app.py CHANGED
@@ -1,5 +1,9 @@
import gradio as gr
from util.text_img import generate_image
+ from util.v3d import generate_v3d, prep
+
+ # Prepare the V3D model
+ model, clip_model, ae_model, device, num_frames, num_steps, rembg_session, output_folder = prep()

_TITLE = "Shoe Generator"
with gr.Blocks(_TITLE) as ShoeGen:
@@ -18,6 +22,61 @@ with gr.Blocks(_TITLE) as ShoeGen:
    button_gen.click(generate_image, inputs=[prompt], outputs=[image, image_nobg])

    with gr.Tab("Image to Video Generator (V3D)"):
-         pass
+         with gr.Row(equal_height=True):
+             with gr.Column():
+                 input_image = gr.Image(value=None, label="Input Image")
+
+                 border_ratio_slider = gr.Slider(
+                     value=0.3,
+                     label="Border Ratio",
+                     minimum=0.05,
+                     maximum=0.5,
+                     step=0.05,
+                 )
+                 decoding_t_slider = gr.Slider(
+                     value=1,
+                     label="Number of Decoding frames",
+                     minimum=1,
+                     maximum=num_frames,
+                     step=1,
+                 )
+                 min_guidance_slider = gr.Slider(
+                     value=3.5,
+                     label="Min CFG Value",
+                     minimum=0.05,
+                     maximum=0.5,
+                     step=0.05,
+                 )
+                 max_guidance_slider = gr.Slider(
+                     value=3.5,
+                     label="Max CFG Value",
+                     minimum=0.05,
+                     maximum=0.5,
+                     step=0.05,
+                 )
+                 run_button = gr.Button(value="Run V3D")
+
+             with gr.Column():
+                 output_video = gr.Video(value=None, label="Output Orbit Video")
+
+         run_button.click(generate_v3d,
+             inputs=[
+                 input_image,
+                 model,
+                 clip_model,
+                 ae_model,
+                 num_frames,
+                 num_steps,
+                 int(decoding_t_slider),
+                 border_ratio_slider,
+                 False,
+                 rembg_session,
+                 output_folder,
+                 min_guidance_slider,
+                 max_guidance_slider,
+                 device,
+             ],
+             outputs=[output_video],
+         )

ShoeGen.launch()
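The new tab wires the UI to `generate_v3d` and `prep` from `util.v3d`, a module that is not among the 50 files shown in this view. The sketch below is a hypothetical reading of that interface, inferred purely from the call sites in app.py above; the names, argument order, and return tuple are assumptions, not code from this commit.

```python
# util/v3d.py -- hypothetical interface sketch, inferred from app.py's call sites.
import rembg


def prep():
    """Load the V3D pieces once at startup.

    Returns the tuple app.py unpacks:
    (model, clip_model, ae_model, device, num_frames, num_steps, rembg_session, output_folder)
    """
    device = "cuda"
    num_frames, num_steps = 18, 25        # plausible defaults; V3D_512.yaml uses num_frames: 18
    rembg_session = rembg.new_session()   # reused across requests for background removal
    output_folder = "outputs/V3D_512"
    model = clip_model = ae_model = None  # real code would build these from the configs / checkpoints
    return model, clip_model, ae_model, device, num_frames, num_steps, rembg_session, output_folder


def generate_v3d(input_image, model, clip_model, ae_model, num_frames, num_steps,
                 decoding_t, border_ratio, ignore_alpha, rembg_session, output_folder,
                 min_guidance_scale, max_guidance_scale, device):
    """Run V3D on one input image and return the rendered orbit video."""
    raise NotImplementedError  # sketch only
```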
requirements.txt CHANGED
@@ -1,12 +1,49 @@
- torch
gradio
diffusers==0.26.3
- transformers==4.38.1
accelerate==0.27.2
- xformers
rembg
- Pillow
Python-IO
- numpy
- opencv-python
- huggingface-hub
+ huggingface-hub
+ black==23.7.0
+ chardet==5.1.0
+ clip @ git+https://github.com/openai/CLIP.git
+ einops>=0.6.1
+ fairscale>=0.4.13
+ fire>=0.5.0
+ fsspec>=2023.6.0
+ invisible-watermark>=0.2.0
+ kornia==0.6.9
+ matplotlib>=3.7.2
+ natsort>=8.4.0
+ ninja>=1.11.1
+ numpy>=1.24.4
+ omegaconf>=2.3.0
+ open-clip-torch>=2.20.0
+ opencv-python==4.6.0.66
+ pandas>=2.0.3
+ pillow>=9.5.0
+ pudb>=2022.1.3
+ pytorch-lightning==2.0.1
+ pyyaml>=6.0.1
+ scipy>=1.10.1
+ streamlit>=0.73.1
+ tensorboardx==2.6
+ timm>=0.9.2
+ tokenizers==0.12.1
+ torch>=2.0.1
+ torchaudio>=2.0.2
+ torchdata==0.6.1
+ torchmetrics>=1.0.1
+ torchvision>=0.15.2
+ tqdm>=4.65.0
+ transformers==4.19.1
+ triton==2.0.0
+ urllib3<1.27,>=1.25.4
+ wandb>=0.15.6
+ webdataset>=0.2.33
+ wheel>=0.41.0
+ xformers>=0.0.20
+ streamlit-keyup==0.2.0
+ mediapy
+ tyro
+ wget
scripts/__init__.py ADDED
File without changes
scripts/pub/V3D_512.py ADDED
@@ -0,0 +1,317 @@
1
+ import math
2
+ import os
3
+ from glob import glob
4
+ from pathlib import Path
5
+ from typing import Optional
6
+
7
+ import cv2
8
+ import numpy as np
9
+ import torch
10
+ from einops import rearrange, repeat
11
+ from fire import Fire
12
+ import tyro
13
+ from omegaconf import OmegaConf
14
+ from PIL import Image
15
+ from torchvision.transforms import ToTensor
16
+ from mediapy import write_video
17
+ import rembg
18
+ from kiui.op import recenter
19
+ from safetensors.torch import load_file as load_safetensors
20
+ from typing import Any
21
+
22
+ from scripts.util.detection.nsfw_and_watermark_dectection import DeepFloydDataFiltering
23
+ from sgm.inference.helpers import embed_watermark
24
+ from sgm.util import default, instantiate_from_config
25
+
26
+
27
+ def get_unique_embedder_keys_from_conditioner(conditioner):
28
+ return list(set([x.input_key for x in conditioner.embedders]))
29
+
30
+
31
+ def get_batch(keys, value_dict, N, T, device):
32
+ batch = {}
33
+ batch_uc = {}
34
+
35
+ for key in keys:
36
+ if key == "fps_id":
37
+ batch[key] = (
38
+ torch.tensor([value_dict["fps_id"]])
39
+ .to(device)
40
+ .repeat(int(math.prod(N)))
41
+ )
42
+ elif key == "motion_bucket_id":
43
+ batch[key] = (
44
+ torch.tensor([value_dict["motion_bucket_id"]])
45
+ .to(device)
46
+ .repeat(int(math.prod(N)))
47
+ )
48
+ elif key == "cond_aug":
49
+ batch[key] = repeat(
50
+ torch.tensor([value_dict["cond_aug"]]).to(device),
51
+ "1 -> b",
52
+ b=math.prod(N),
53
+ )
54
+ elif key == "cond_frames":
55
+ batch[key] = repeat(value_dict["cond_frames"], "1 ... -> b ...", b=N[0])
56
+ elif key == "cond_frames_without_noise":
57
+ batch[key] = repeat(
58
+ value_dict["cond_frames_without_noise"], "1 ... -> b ...", b=N[0]
59
+ )
60
+ else:
61
+ batch[key] = value_dict[key]
62
+
63
+ if T is not None:
64
+ batch["num_video_frames"] = T
65
+
66
+ for key in batch.keys():
67
+ if key not in batch_uc and isinstance(batch[key], torch.Tensor):
68
+ batch_uc[key] = torch.clone(batch[key])
69
+ return batch, batch_uc
70
+
71
+
72
+ def load_model(
73
+ config: str,
74
+ device: str,
75
+ num_frames: int,
76
+ num_steps: int,
77
+ ckpt_path: Optional[str] = None,
78
+ min_cfg: Optional[float] = None,
79
+ max_cfg: Optional[float] = None,
80
+ sigma_max: Optional[float] = None,
81
+ ):
82
+ config = OmegaConf.load(config)
83
+
84
+ config.model.params.sampler_config.params.num_steps = num_steps
85
+ config.model.params.sampler_config.params.guider_config.params.num_frames = (
86
+ num_frames
87
+ )
88
+ if max_cfg is not None:
89
+ config.model.params.sampler_config.params.guider_config.params.max_scale = (
90
+ max_cfg
91
+ )
92
+ if min_cfg is not None:
93
+ config.model.params.sampler_config.params.guider_config.params.min_scale = (
94
+ min_cfg
95
+ )
96
+ if sigma_max is not None:
97
+ print("Overriding sigma_max to ", sigma_max)
98
+ config.model.params.sampler_config.params.discretization_config.params.sigma_max = (
99
+ sigma_max
100
+ )
101
+
102
+ config.model.params.from_scratch = False
103
+
104
+ if ckpt_path is not None:
105
+ config.model.params.ckpt_path = str(ckpt_path)
106
+ if device == "cuda":
107
+ with torch.device(device):
108
+ model = instantiate_from_config(config.model).to(device).eval()
109
+ else:
110
+ model = instantiate_from_config(config.model).to(device).eval()
111
+
112
+ return model, None
113
+
114
+
115
+ def sample_one(
116
+ input_path: str = "assets/test_image.png", # Can either be image file or folder with image files
117
+ checkpoint_path: Optional[str] = None,
118
+ num_frames: Optional[int] = None,
119
+ num_steps: Optional[int] = None,
120
+ fps_id: int = 1,
121
+ motion_bucket_id: int = 300,
122
+ cond_aug: float = 0.02,
123
+ seed: int = 23,
124
+ decoding_t: int = 24, # Number of frames decoded at a time! This eats most VRAM. Reduce if necessary.
125
+ device: str = "cuda",
126
+ output_folder: Optional[str] = None,
127
+ noise: torch.Tensor = None,
128
+ save: bool = False,
129
+ cached_model: Any = None,
130
+ border_ratio: float = 0.3,
131
+ min_guidance_scale: float = 3.5,
132
+ max_guidance_scale: float = 3.5,
133
+ sigma_max: float = None,
134
+ ignore_alpha: bool = False,
135
+ ):
136
+ model_config = "scripts/pub/configs/V3D_512.yaml"
137
+ num_frames = OmegaConf.load(
138
+ model_config
139
+ ).model.params.sampler_config.params.guider_config.params.num_frames
140
+ print("Detected num_frames:", num_frames)
141
+ num_steps = default(num_steps, 25)
142
+ output_folder = default(output_folder, f"outputs/V3D_512")
143
+ decoding_t = min(decoding_t, num_frames)
144
+
145
+ sd = load_safetensors("./ckpts/svd_xt.safetensors")
146
+ clip_model_config = OmegaConf.load("configs/embedder/clip_image.yaml")
147
+ clip_model = instantiate_from_config(clip_model_config).eval()
148
+ clip_sd = dict()
149
+ for k, v in sd.items():
150
+ if "conditioner.embedders.0" in k:
151
+ clip_sd[k.replace("conditioner.embedders.0.", "")] = v
152
+ clip_model.load_state_dict(clip_sd)
153
+ clip_model = clip_model.to(device)
154
+
155
+ ae_model_config = OmegaConf.load("configs/ae/video.yaml")
156
+ ae_model = instantiate_from_config(ae_model_config).eval()
157
+ encoder_sd = dict()
158
+ for k, v in sd.items():
159
+ if "first_stage_model" in k:
160
+ encoder_sd[k.replace("first_stage_model.", "")] = v
161
+ ae_model.load_state_dict(encoder_sd)
162
+ ae_model = ae_model.to(device)
163
+
164
+ if cached_model is None:
165
+ model, filter = load_model(
166
+ model_config,
167
+ device,
168
+ num_frames,
169
+ num_steps,
170
+ ckpt_path=checkpoint_path,
171
+ min_cfg=min_guidance_scale,
172
+ max_cfg=max_guidance_scale,
173
+ sigma_max=sigma_max,
174
+ )
175
+ else:
176
+ model = cached_model
177
+ torch.manual_seed(seed)
178
+
179
+ need_return = True
180
+ path = Path(input_path)
181
+ if path.is_file():
182
+ if any([input_path.endswith(x) for x in ["jpg", "jpeg", "png"]]):
183
+ all_img_paths = [input_path]
184
+ else:
185
+ raise ValueError("Path is not valid image file.")
186
+ elif path.is_dir():
187
+ all_img_paths = sorted(
188
+ [
189
+ f
190
+ for f in path.iterdir()
191
+ if f.is_file() and f.suffix.lower() in [".jpg", ".jpeg", ".png"]
192
+ ]
193
+ )
194
+ need_return = False
195
+ if len(all_img_paths) == 0:
196
+ raise ValueError("Folder does not contain any images.")
197
+ else:
198
+ raise ValueError
199
+
200
+ for input_path in all_img_paths:
201
+ with Image.open(input_path) as image:
202
+ # if image.mode == "RGBA":
203
+ # image = image.convert("RGB")
204
+ w, h = image.size
205
+
206
+ if border_ratio > 0:
207
+ if image.mode != "RGBA" or ignore_alpha:
208
+ image = image.convert("RGB")
209
+ image = np.asarray(image)
210
+ carved_image = rembg.remove(image) # [H, W, 4]
211
+ else:
212
+ image = np.asarray(image)
213
+ carved_image = image
214
+ mask = carved_image[..., -1] > 0
215
+ image = recenter(carved_image, mask, border_ratio=border_ratio)
216
+ image = image.astype(np.float32) / 255.0
217
+ if image.shape[-1] == 4:
218
+ image = image[..., :3] * image[..., 3:4] + (1 - image[..., 3:4])
219
+ image = Image.fromarray((image * 255).astype(np.uint8))
220
+ else:
221
+ print("Ignore border ratio")
222
+ image = image.resize((512, 512))
223
+
224
+ image = ToTensor()(image)
225
+ image = image * 2.0 - 1.0
226
+
227
+ image = image.unsqueeze(0).to(device)
228
+ H, W = image.shape[2:]
229
+ assert image.shape[1] == 3
230
+ F = 8
231
+ C = 4
232
+ shape = (num_frames, C, H // F, W // F)
233
+
234
+ value_dict = {}
235
+ value_dict["motion_bucket_id"] = motion_bucket_id
236
+ value_dict["fps_id"] = fps_id
237
+ value_dict["cond_aug"] = cond_aug
238
+ value_dict["cond_frames_without_noise"] = clip_model(image)
239
+ value_dict["cond_frames"] = ae_model.encode(image)
240
+ value_dict["cond_frames"] += cond_aug * torch.randn_like(
241
+ value_dict["cond_frames"]
242
+ )
243
+ value_dict["cond_aug"] = cond_aug
244
+
245
+ with torch.no_grad():
246
+ with torch.autocast(device):
247
+ batch, batch_uc = get_batch(
248
+ get_unique_embedder_keys_from_conditioner(model.conditioner),
249
+ value_dict,
250
+ [1, num_frames],
251
+ T=num_frames,
252
+ device=device,
253
+ )
254
+ c, uc = model.conditioner.get_unconditional_conditioning(
255
+ batch,
256
+ batch_uc=batch_uc,
257
+ force_uc_zero_embeddings=[
258
+ "cond_frames",
259
+ "cond_frames_without_noise",
260
+ ],
261
+ )
262
+
263
+ for k in ["crossattn", "concat"]:
264
+ uc[k] = repeat(uc[k], "b ... -> b t ...", t=num_frames)
265
+ uc[k] = rearrange(uc[k], "b t ... -> (b t) ...", t=num_frames)
266
+ c[k] = repeat(c[k], "b ... -> b t ...", t=num_frames)
267
+ c[k] = rearrange(c[k], "b t ... -> (b t) ...", t=num_frames)
268
+
269
+ randn = torch.randn(shape, device=device) if noise is None else noise
270
+ randn = randn.to(device)
271
+
272
+ additional_model_inputs = {}
273
+ additional_model_inputs["image_only_indicator"] = torch.zeros(
274
+ 2, num_frames
275
+ ).to(device)
276
+ additional_model_inputs["num_video_frames"] = batch["num_video_frames"]
277
+
278
+ def denoiser(input, sigma, c):
279
+ return model.denoiser(
280
+ model.model, input, sigma, c, **additional_model_inputs
281
+ )
282
+
283
+ samples_z = model.sampler(denoiser, randn, cond=c, uc=uc)
284
+ model.en_and_decode_n_samples_a_time = decoding_t
285
+ samples_x = model.decode_first_stage(samples_z)
286
+ samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
287
+
288
+ os.makedirs(output_folder, exist_ok=True)
289
+ base_count = len(glob(os.path.join(output_folder, "*.mp4")))
290
+ video_path = os.path.join(output_folder, f"{base_count:06d}.mp4")
291
+ # writer = cv2.VideoWriter(
292
+ # video_path,
293
+ # cv2.VideoWriter_fourcc(*"MP4V"),
294
+ # fps_id + 1,
295
+ # (samples.shape[-1], samples.shape[-2]),
296
+ # )
297
+
298
+ frames = (
299
+ (rearrange(samples, "t c h w -> t h w c") * 255)
300
+ .cpu()
301
+ .numpy()
302
+ .astype(np.uint8)
303
+ )
304
+
305
+ if save:
306
+ write_video(video_path, frames, fps=3)
307
+
308
+ images = []
309
+ for frame in frames:
310
+ images.append(Image.fromarray(frame))
311
+
312
+ if need_return:
313
+ return images, model
314
+
315
+
316
+ if __name__ == "__main__":
317
+ tyro.cli(sample_one)
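Because the script hands `sample_one` to `tyro.cli`, each keyword argument above doubles as a command-line flag; it can also be driven directly from Python. Below is a minimal, hypothetical driver for a direct call, assuming the checkpoints and configs the function hard-codes (`ckpts/svd_xt.safetensors`, `scripts/pub/configs/V3D_512.yaml`, `configs/embedder/clip_image.yaml`, `configs/ae/video.yaml`) are already in place.

```python
# Hypothetical driver for sample_one(); argument names come from the signature above,
# and the paths are the ones the script itself hard-codes.
from scripts.pub.V3D_512 import sample_one

frames, model = sample_one(
    input_path="assets/test_image.png",  # a single image file or a folder of images
    num_steps=25,
    decoding_t=4,                        # frames decoded per VAE pass; lower if VRAM is tight
    border_ratio=0.3,
    min_guidance_scale=3.5,
    max_guidance_scale=3.5,
    device="cuda",
    save=True,                           # also writes an .mp4 into outputs/V3D_512 via mediapy
)
# frames is a list of PIL images; pass model back in as cached_model=model
# on later calls to skip re-loading the UNet weights.
```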
scripts/pub/configs/V3D_512.yaml ADDED
@@ -0,0 +1,161 @@
1
+ model:
2
+ base_learning_rate: 1.0e-04
3
+ target: sgm.models.video_diffusion.DiffusionEngine
4
+ params:
5
+ ckpt_path: ckpts/V3D_512.ckpt
6
+ scale_factor: 0.18215
7
+ disable_first_stage_autocast: true
8
+ input_key: latents
9
+ log_keys: []
10
+ scheduler_config:
11
+ target: sgm.lr_scheduler.LambdaLinearScheduler
12
+ params:
13
+ warm_up_steps:
14
+ - 1
15
+ cycle_lengths:
16
+ - 10000000000000
17
+ f_start:
18
+ - 1.0e-06
19
+ f_max:
20
+ - 1.0
21
+ f_min:
22
+ - 1.0
23
+ denoiser_config:
24
+ target: sgm.modules.diffusionmodules.denoiser.Denoiser
25
+ params:
26
+ scaling_config:
27
+ target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise
28
+ network_config:
29
+ target: sgm.modules.diffusionmodules.video_model.VideoUNet
30
+ params:
31
+ adm_in_channels: 768
32
+ num_classes: sequential
33
+ use_checkpoint: true
34
+ in_channels: 8
35
+ out_channels: 4
36
+ model_channels: 320
37
+ attention_resolutions:
38
+ - 4
39
+ - 2
40
+ - 1
41
+ num_res_blocks: 2
42
+ channel_mult:
43
+ - 1
44
+ - 2
45
+ - 4
46
+ - 4
47
+ num_head_channels: 64
48
+ use_linear_in_transformer: true
49
+ transformer_depth: 1
50
+ context_dim: 1024
51
+ spatial_transformer_attn_type: softmax-xformers
52
+ extra_ff_mix_layer: true
53
+ use_spatial_context: true
54
+ merge_strategy: learned_with_images
55
+ video_kernel_size:
56
+ - 3
57
+ - 1
58
+ - 1
59
+ conditioner_config:
60
+ target: sgm.modules.GeneralConditioner
61
+ params:
62
+ emb_models:
63
+ - is_trainable: false
64
+ ucg_rate: 0.2
65
+ input_key: cond_frames_without_noise
66
+ target: sgm.modules.encoders.modules.IdentityEncoder
67
+ - input_key: fps_id
68
+ is_trainable: true
69
+ target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
70
+ params:
71
+ outdim: 256
72
+ - input_key: motion_bucket_id
73
+ is_trainable: true
74
+ target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
75
+ params:
76
+ outdim: 256
77
+ - input_key: cond_frames
78
+ is_trainable: false
79
+ ucg_rate: 0.2
80
+ target: sgm.modules.encoders.modules.IdentityEncoder
81
+ - input_key: cond_aug
82
+ is_trainable: true
83
+ target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
84
+ params:
85
+ outdim: 256
86
+ first_stage_config:
87
+ target: sgm.models.autoencoder.AutoencodingEngine
88
+ params:
89
+ loss_config:
90
+ target: torch.nn.Identity
91
+ regularizer_config:
92
+ target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer
93
+ encoder_config:
94
+ target: sgm.modules.diffusionmodules.model.Encoder
95
+ params:
96
+ attn_type: vanilla
97
+ double_z: true
98
+ z_channels: 4
99
+ resolution: 256
100
+ in_channels: 3
101
+ out_ch: 3
102
+ ch: 128
103
+ ch_mult:
104
+ - 1
105
+ - 2
106
+ - 4
107
+ - 4
108
+ num_res_blocks: 2
109
+ attn_resolutions: []
110
+ dropout: 0.0
111
+ decoder_config:
112
+ target: sgm.modules.autoencoding.temporal_ae.VideoDecoder
113
+ params:
114
+ attn_type: vanilla
115
+ double_z: true
116
+ z_channels: 4
117
+ resolution: 256
118
+ in_channels: 3
119
+ out_ch: 3
120
+ ch: 128
121
+ ch_mult:
122
+ - 1
123
+ - 2
124
+ - 4
125
+ - 4
126
+ num_res_blocks: 2
127
+ attn_resolutions: []
128
+ dropout: 0.0
129
+ video_kernel_size:
130
+ - 3
131
+ - 1
132
+ - 1
133
+ sampler_config:
134
+ target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
135
+ params:
136
+ num_steps: 30
137
+ discretization_config:
138
+ target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization
139
+ params:
140
+ sigma_max: 700.0
141
+ guider_config:
142
+ target: sgm.modules.diffusionmodules.guiders.LinearPredictionGuider
143
+ params:
144
+ max_scale: 3.5
145
+ min_scale: 3.5
146
+ num_frames: 18
147
+ loss_fn_config:
148
+ target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
149
+ params:
150
+ batch2model_keys:
151
+ - num_video_frames
152
+ - image_only_indicator
153
+ loss_weighting_config:
154
+ target: sgm.modules.diffusionmodules.loss_weighting.EDMWeighting
155
+ params:
156
+ sigma_data: 1.0
157
+ sigma_sampler_config:
158
+ target: sgm.modules.diffusionmodules.sigma_sampling.EDMSampling
159
+ params:
160
+ p_mean: 1.5
161
+ p_std: 2.0
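This config is what `load_model` in `scripts/pub/V3D_512.py` consumes: the YAML is read with OmegaConf, a few sampler fields are overridden, and every `target`/`params` pair is resolved by `sgm.util.instantiate_from_config`. The following is a condensed sketch of that pattern, mirroring the script rather than adding new behavior.

```python
# Minimal sketch of turning V3D_512.yaml into a model, following load_model()
# in scripts/pub/V3D_512.py.
import torch
from omegaconf import OmegaConf
from sgm.util import instantiate_from_config

config = OmegaConf.load("scripts/pub/configs/V3D_512.yaml")

# Overrides the script applies before instantiation.
config.model.params.sampler_config.params.num_steps = 25
config.model.params.sampler_config.params.guider_config.params.num_frames = 18
config.model.params.from_scratch = False
config.model.params.ckpt_path = "ckpts/V3D_512.ckpt"  # same path the config ships with

with torch.device("cuda"):
    model = instantiate_from_config(config.model).to("cuda").eval()
```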
scripts/tests/attention.py ADDED
@@ -0,0 +1,319 @@
1
+ import einops
2
+ import torch
3
+ import torch.nn.functional as F
4
+ import torch.utils.benchmark as benchmark
5
+ from torch.backends.cuda import SDPBackend
6
+
7
+ from sgm.modules.attention import BasicTransformerBlock, SpatialTransformer
8
+
9
+
10
+ def benchmark_attn():
11
+ # Let's define a helpful benchmarking function:
12
+ # https://pytorch.org/tutorials/intermediate/scaled_dot_product_attention_tutorial.html
13
+ device = "cuda" if torch.cuda.is_available() else "cpu"
14
+
15
+ def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
16
+ t0 = benchmark.Timer(
17
+ stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
18
+ )
19
+ return t0.blocked_autorange().mean * 1e6
20
+
21
+ # Let's define the hyper-parameters of our input
22
+ batch_size = 32
23
+ max_sequence_len = 1024
24
+ num_heads = 32
25
+ embed_dimension = 32
26
+
27
+ dtype = torch.float16
28
+
29
+ query = torch.rand(
30
+ batch_size,
31
+ num_heads,
32
+ max_sequence_len,
33
+ embed_dimension,
34
+ device=device,
35
+ dtype=dtype,
36
+ )
37
+ key = torch.rand(
38
+ batch_size,
39
+ num_heads,
40
+ max_sequence_len,
41
+ embed_dimension,
42
+ device=device,
43
+ dtype=dtype,
44
+ )
45
+ value = torch.rand(
46
+ batch_size,
47
+ num_heads,
48
+ max_sequence_len,
49
+ embed_dimension,
50
+ device=device,
51
+ dtype=dtype,
52
+ )
53
+
54
+ print(f"q/k/v shape:", query.shape, key.shape, value.shape)
55
+
56
+ # Let's explore the speed of each of the 3 implementations
57
+ from torch.backends.cuda import SDPBackend, sdp_kernel
58
+
59
+ # Helpful arguments mapper
60
+ backend_map = {
61
+ SDPBackend.MATH: {
62
+ "enable_math": True,
63
+ "enable_flash": False,
64
+ "enable_mem_efficient": False,
65
+ },
66
+ SDPBackend.FLASH_ATTENTION: {
67
+ "enable_math": False,
68
+ "enable_flash": True,
69
+ "enable_mem_efficient": False,
70
+ },
71
+ SDPBackend.EFFICIENT_ATTENTION: {
72
+ "enable_math": False,
73
+ "enable_flash": False,
74
+ "enable_mem_efficient": True,
75
+ },
76
+ }
77
+
78
+ from torch.profiler import ProfilerActivity, profile, record_function
79
+
80
+ activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA]
81
+
82
+ print(
83
+ f"The default implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds"
84
+ )
85
+ with profile(
86
+ activities=activities, record_shapes=False, profile_memory=True
87
+ ) as prof:
88
+ with record_function("Default detailed stats"):
89
+ for _ in range(25):
90
+ o = F.scaled_dot_product_attention(query, key, value)
91
+ print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
92
+
93
+ print(
94
+ f"The math implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds"
95
+ )
96
+ with sdp_kernel(**backend_map[SDPBackend.MATH]):
97
+ with profile(
98
+ activities=activities, record_shapes=False, profile_memory=True
99
+ ) as prof:
100
+ with record_function("Math implementation stats"):
101
+ for _ in range(25):
102
+ o = F.scaled_dot_product_attention(query, key, value)
103
+ print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
104
+
105
+ with sdp_kernel(**backend_map[SDPBackend.FLASH_ATTENTION]):
106
+ try:
107
+ print(
108
+ f"The flash attention implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds"
109
+ )
110
+ except RuntimeError:
111
+ print("FlashAttention is not supported. See warnings for reasons.")
112
+ with profile(
113
+ activities=activities, record_shapes=False, profile_memory=True
114
+ ) as prof:
115
+ with record_function("FlashAttention stats"):
116
+ for _ in range(25):
117
+ o = F.scaled_dot_product_attention(query, key, value)
118
+ print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
119
+
120
+ with sdp_kernel(**backend_map[SDPBackend.EFFICIENT_ATTENTION]):
121
+ try:
122
+ print(
123
+ f"The memory efficient implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds"
124
+ )
125
+ except RuntimeError:
126
+ print("EfficientAttention is not supported. See warnings for reasons.")
127
+ with profile(
128
+ activities=activities, record_shapes=False, profile_memory=True
129
+ ) as prof:
130
+ with record_function("EfficientAttention stats"):
131
+ for _ in range(25):
132
+ o = F.scaled_dot_product_attention(query, key, value)
133
+ print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
134
+
135
+
136
+ def run_model(model, x, context):
137
+ return model(x, context)
138
+
139
+
140
+ def benchmark_transformer_blocks():
141
+ device = "cuda" if torch.cuda.is_available() else "cpu"
142
+ import torch.utils.benchmark as benchmark
143
+
144
+ def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
145
+ t0 = benchmark.Timer(
146
+ stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
147
+ )
148
+ return t0.blocked_autorange().mean * 1e6
149
+
150
+ checkpoint = True
151
+ compile = False
152
+
153
+ batch_size = 32
154
+ h, w = 64, 64
155
+ context_len = 77
156
+ embed_dimension = 1024
157
+ context_dim = 1024
158
+ d_head = 64
159
+
160
+ transformer_depth = 4
161
+
162
+ n_heads = embed_dimension // d_head
163
+
164
+ dtype = torch.float16
165
+
166
+ model_native = SpatialTransformer(
167
+ embed_dimension,
168
+ n_heads,
169
+ d_head,
170
+ context_dim=context_dim,
171
+ use_linear=True,
172
+ use_checkpoint=checkpoint,
173
+ attn_type="softmax",
174
+ depth=transformer_depth,
175
+ sdp_backend=SDPBackend.FLASH_ATTENTION,
176
+ ).to(device)
177
+ model_efficient_attn = SpatialTransformer(
178
+ embed_dimension,
179
+ n_heads,
180
+ d_head,
181
+ context_dim=context_dim,
182
+ use_linear=True,
183
+ depth=transformer_depth,
184
+ use_checkpoint=checkpoint,
185
+ attn_type="softmax-xformers",
186
+ ).to(device)
187
+ if not checkpoint and compile:
188
+ print("compiling models")
189
+ model_native = torch.compile(model_native)
190
+ model_efficient_attn = torch.compile(model_efficient_attn)
191
+
192
+ x = torch.rand(batch_size, embed_dimension, h, w, device=device, dtype=dtype)
193
+ c = torch.rand(batch_size, context_len, context_dim, device=device, dtype=dtype)
194
+
195
+ from torch.profiler import ProfilerActivity, profile, record_function
196
+
197
+ activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA]
198
+
199
+ with torch.autocast("cuda"):
200
+ print(
201
+ f"The native model runs in {benchmark_torch_function_in_microseconds(model_native.forward, x, c):.3f} microseconds"
202
+ )
203
+ print(
204
+ f"The efficientattn model runs in {benchmark_torch_function_in_microseconds(model_efficient_attn.forward, x, c):.3f} microseconds"
205
+ )
206
+
207
+ print(75 * "+")
208
+ print("NATIVE")
209
+ print(75 * "+")
210
+ torch.cuda.reset_peak_memory_stats()
211
+ with profile(
212
+ activities=activities, record_shapes=False, profile_memory=True
213
+ ) as prof:
214
+ with record_function("NativeAttention stats"):
215
+ for _ in range(25):
216
+ model_native(x, c)
217
+ print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
218
+ print(torch.cuda.max_memory_allocated() * 1e-9, "GB used by native block")
219
+
220
+ print(75 * "+")
221
+ print("Xformers")
222
+ print(75 * "+")
223
+ torch.cuda.reset_peak_memory_stats()
224
+ with profile(
225
+ activities=activities, record_shapes=False, profile_memory=True
226
+ ) as prof:
227
+ with record_function("xformers stats"):
228
+ for _ in range(25):
229
+ model_efficient_attn(x, c)
230
+ print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
231
+ print(torch.cuda.max_memory_allocated() * 1e-9, "GB used by xformers block")
232
+
233
+
234
+ def test01():
235
+ # conv1x1 vs linear
236
+ from sgm.util import count_params
237
+
238
+ conv = torch.nn.Conv2d(3, 32, kernel_size=1).cuda()
239
+ print(count_params(conv))
240
+ linear = torch.nn.Linear(3, 32).cuda()
241
+ print(count_params(linear))
242
+
243
+ print(conv.weight.shape)
244
+
245
+ # use same initialization
246
+ linear.weight = torch.nn.Parameter(conv.weight.squeeze(-1).squeeze(-1))
247
+ linear.bias = torch.nn.Parameter(conv.bias)
248
+
249
+ print(linear.weight.shape)
250
+
251
+ x = torch.randn(11, 3, 64, 64).cuda()
252
+
253
+ xr = einops.rearrange(x, "b c h w -> b (h w) c").contiguous()
254
+ print(xr.shape)
255
+ out_linear = linear(xr)
256
+ print(out_linear.mean(), out_linear.shape)
257
+
258
+ out_conv = conv(x)
259
+ print(out_conv.mean(), out_conv.shape)
260
+ print("done with test01.\n")
261
+
262
+
263
+ def test02():
264
+ # try cosine flash attention
265
+ import time
266
+
267
+ torch.backends.cuda.matmul.allow_tf32 = True
268
+ torch.backends.cudnn.allow_tf32 = True
269
+ torch.backends.cudnn.benchmark = True
270
+ print("testing cosine flash attention...")
271
+ DIM = 1024
272
+ SEQLEN = 4096
273
+ BS = 16
274
+
275
+ print(" softmax (vanilla) first...")
276
+ model = BasicTransformerBlock(
277
+ dim=DIM,
278
+ n_heads=16,
279
+ d_head=64,
280
+ dropout=0.0,
281
+ context_dim=None,
282
+ attn_mode="softmax",
283
+ ).cuda()
284
+ try:
285
+ x = torch.randn(BS, SEQLEN, DIM).cuda()
286
+ tic = time.time()
287
+ y = model(x)
288
+ toc = time.time()
289
+ print(y.shape, toc - tic)
290
+ except RuntimeError as e:
291
+ # likely oom
292
+ print(str(e))
293
+
294
+ print("\n now flash-cosine...")
295
+ model = BasicTransformerBlock(
296
+ dim=DIM,
297
+ n_heads=16,
298
+ d_head=64,
299
+ dropout=0.0,
300
+ context_dim=None,
301
+ attn_mode="flash-cosine",
302
+ ).cuda()
303
+ x = torch.randn(BS, SEQLEN, DIM).cuda()
304
+ tic = time.time()
305
+ y = model(x)
306
+ toc = time.time()
307
+ print(y.shape, toc - tic)
308
+ print("done with test02.\n")
309
+
310
+
311
+ if __name__ == "__main__":
312
+ # test01()
313
+ # test02()
314
+ # test03()
315
+
316
+ # benchmark_attn()
317
+ benchmark_transformer_blocks()
318
+
319
+ print("done.")
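The core pattern in `benchmark_attn` is selecting a scaled-dot-product-attention backend with `torch.backends.cuda.sdp_kernel` and timing `F.scaled_dot_product_attention` under it. The sketch below condenses that pattern into a standalone form (PyTorch 2.0-era API, matching the `torch>=2.0.1` pin in requirements.txt); shapes are arbitrary example values.

```python
# Standalone sketch of the backend-selection pattern used in benchmark_attn().
import torch
import torch.nn.functional as F
import torch.utils.benchmark as benchmark
from torch.backends.cuda import sdp_kernel

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32
q, k, v = (torch.rand(8, 16, 1024, 64, device=device, dtype=dtype) for _ in range(3))


def time_us(f, *args):
    # Same Timer-based helper the script defines inline.
    timer = benchmark.Timer(stmt="f(*args)", globals={"f": f, "args": args})
    return timer.blocked_autorange().mean * 1e6


with sdp_kernel(enable_math=True, enable_flash=False, enable_mem_efficient=False):
    print(f"math backend:  {time_us(F.scaled_dot_product_attention, q, k, v):.1f} us")

if device == "cuda":
    with sdp_kernel(enable_math=False, enable_flash=True, enable_mem_efficient=False):
        print(f"flash backend: {time_us(F.scaled_dot_product_attention, q, k, v):.1f} us")
```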
scripts/util/__init__.py ADDED
File without changes
scripts/util/detection/__init__.py ADDED
File without changes
scripts/util/detection/nsfw_and_watermark_dectection.py ADDED
@@ -0,0 +1,110 @@
1
+ import os
2
+
3
+ import clip
4
+ import numpy as np
5
+ import torch
6
+ import torchvision.transforms as T
7
+ from PIL import Image
8
+
9
+ RESOURCES_ROOT = "scripts/util/detection/"
10
+
11
+
12
+ def predict_proba(X, weights, biases):
13
+ logits = X @ weights.T + biases
14
+ proba = np.where(
15
+ logits >= 0, 1 / (1 + np.exp(-logits)), np.exp(logits) / (1 + np.exp(logits))
16
+ )
17
+ return proba.T
18
+
19
+
20
+ def load_model_weights(path: str):
21
+ model_weights = np.load(path)
22
+ return model_weights["weights"], model_weights["biases"]
23
+
24
+
25
+ def clip_process_images(images: torch.Tensor) -> torch.Tensor:
26
+ min_size = min(images.shape[-2:])
27
+ return T.Compose(
28
+ [
29
+ T.CenterCrop(min_size), # TODO: this might affect the watermark, check this
30
+ T.Resize(224, interpolation=T.InterpolationMode.BICUBIC, antialias=True),
31
+ T.Normalize(
32
+ (0.48145466, 0.4578275, 0.40821073),
33
+ (0.26862954, 0.26130258, 0.27577711),
34
+ ),
35
+ ]
36
+ )(images)
37
+
38
+
39
+ class DeepFloydDataFiltering(object):
40
+ def __init__(
41
+ self, verbose: bool = False, device: torch.device = torch.device("cpu")
42
+ ):
43
+ super().__init__()
44
+ self.verbose = verbose
45
+ self._device = None
46
+ self.clip_model, _ = clip.load("ViT-L/14", device=device)
47
+ self.clip_model.eval()
48
+
49
+ self.cpu_w_weights, self.cpu_w_biases = load_model_weights(
50
+ os.path.join(RESOURCES_ROOT, "w_head_v1.npz")
51
+ )
52
+ self.cpu_p_weights, self.cpu_p_biases = load_model_weights(
53
+ os.path.join(RESOURCES_ROOT, "p_head_v1.npz")
54
+ )
55
+ self.w_threshold, self.p_threshold = 0.5, 0.5
56
+
57
+ @torch.inference_mode()
58
+ def __call__(self, images: torch.Tensor) -> torch.Tensor:
59
+ imgs = clip_process_images(images)
60
+ if self._device is None:
61
+ self._device = next(p for p in self.clip_model.parameters()).device
62
+ image_features = self.clip_model.encode_image(imgs.to(self._device))
63
+ image_features = image_features.detach().cpu().numpy().astype(np.float16)
64
+ p_pred = predict_proba(image_features, self.cpu_p_weights, self.cpu_p_biases)
65
+ w_pred = predict_proba(image_features, self.cpu_w_weights, self.cpu_w_biases)
66
+ print(f"p_pred = {p_pred}, w_pred = {w_pred}") if self.verbose else None
67
+ query = p_pred > self.p_threshold
68
+ if query.sum() > 0:
69
+ print(f"Hit for p_threshold: {p_pred}") if self.verbose else None
70
+ images[query] = T.GaussianBlur(99, sigma=(100.0, 100.0))(images[query])
71
+ query = w_pred > self.w_threshold
72
+ if query.sum() > 0:
73
+ print(f"Hit for w_threshold: {w_pred}") if self.verbose else None
74
+ images[query] = T.GaussianBlur(99, sigma=(100.0, 100.0))(images[query])
75
+ return images
76
+
77
+
78
+ def load_img(path: str) -> torch.Tensor:
79
+ image = Image.open(path)
80
+ if not image.mode == "RGB":
81
+ image = image.convert("RGB")
82
+ image_transforms = T.Compose(
83
+ [
84
+ T.ToTensor(),
85
+ ]
86
+ )
87
+ return image_transforms(image)[None, ...]
88
+
89
+
90
+ def test(root):
91
+ from einops import rearrange
92
+
93
+ filter = DeepFloydDataFiltering(verbose=True)
94
+ for p in os.listdir((root)):
95
+ print(f"running on {p}...")
96
+ img = load_img(os.path.join(root, p))
97
+ filtered_img = filter(img)
98
+ filtered_img = rearrange(
99
+ 255.0 * (filtered_img.numpy())[0], "c h w -> h w c"
100
+ ).astype(np.uint8)
101
+ Image.fromarray(filtered_img).save(
102
+ os.path.join(root, f"{os.path.splitext(p)[0]}-filtered.jpg")
103
+ )
104
+
105
+
106
+ if __name__ == "__main__":
107
+ import fire
108
+
109
+ fire.Fire(test)
110
+ print("done.")
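`DeepFloydDataFiltering` is the filter class that `scripts/pub/V3D_512.py` imports from this module. A short usage sketch that mirrors the `test()` helper above; the input path is a placeholder.

```python
# Sketch: blur NSFW / watermarked content in one image, following test() above.
import numpy as np
from PIL import Image
from einops import rearrange

from scripts.util.detection.nsfw_and_watermark_dectection import (
    DeepFloydDataFiltering,
    load_img,
)

data_filter = DeepFloydDataFiltering(verbose=True)  # loads ViT-L/14 CLIP plus the two .npz heads
img = load_img("assets/test_image.png")             # placeholder path; tensor of shape (1, 3, H, W)
filtered = data_filter(img)                         # flagged images come back Gaussian-blurred
out = rearrange(255.0 * filtered.numpy()[0], "c h w -> h w c").astype(np.uint8)
Image.fromarray(out).save("filtered.jpg")
```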
scripts/util/detection/p_head_v1.npz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b4653a64d5f85d8d4c5f6c5ec175f1c5c5e37db8f38d39b2ed8b5979da7fdc76
+ size 3588
scripts/util/detection/w_head_v1.npz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b6af23687aa347073e692025f405ccc48c14aadc5dbe775b3312041006d496d1
+ size 3588
sgm/__init__.py ADDED
@@ -0,0 +1,4 @@
+ from .models import AutoencodingEngine, DiffusionEngine
+ from .util import get_configs_path, instantiate_from_config
+
+ __version__ = "0.1.0"
sgm/data/__init__.py ADDED
@@ -0,0 +1 @@
+ from .dataset import StableDataModuleFromConfig
sgm/data/cam_utils.py ADDED
@@ -0,0 +1,1253 @@
1
+ '''
2
+ Common camera utilities
3
+ '''
4
+
5
+ import math
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn as nn
9
+ from pytorch3d.renderer import PerspectiveCameras
10
+ from pytorch3d.renderer.cameras import look_at_view_transform
11
+ from pytorch3d.renderer.implicit.raysampling import _xy_to_ray_bundle
12
+
13
+ class RelativeCameraLoader(nn.Module):
14
+ def __init__(self,
15
+ query_batch_size=1,
16
+ rand_query=True,
17
+ relative=True,
18
+ center_at_origin=False,
19
+ ):
20
+ super().__init__()
21
+
22
+ self.query_batch_size = query_batch_size
23
+ self.rand_query = rand_query
24
+ self.relative = relative
25
+ self.center_at_origin = center_at_origin
26
+
27
+ def plot_cameras(self, cameras_1, cameras_2):
28
+ '''
29
+ Helper function to plot cameras
30
+
31
+ Args:
32
+ cameras_1 (PyTorch3D camera): cameras object to plot
33
+ cameras_2 (PyTorch3D camera): cameras object to plot
34
+ '''
35
+ from pytorch3d.vis.plotly_vis import AxisArgs, plot_batch_individually, plot_scene
36
+ import plotly.graph_objects as go
37
+ plotlyplot = plot_scene(
38
+ {
39
+ 'scene_batch': {
40
+ 'cameras': cameras_1.to('cpu'),
41
+ 'rel_cameras': cameras_2.to('cpu'),
42
+ }
43
+ },
44
+ camera_scale=.5,#0.05,
45
+ pointcloud_max_points=10000,
46
+ pointcloud_marker_size=1.0,
47
+ raybundle_max_rays=100
48
+ )
49
+ plotlyplot.show()
50
+
51
+ def concat_cameras(self, camera_list):
52
+ '''
53
+ Returns a concatenation of a list of cameras
54
+
55
+ Args:
56
+ camera_list (List[PyTorch3D camera]): a list of PyTorch3D cameras
57
+ '''
58
+ R_list, T_list, f_list, c_list, size_list = [], [], [], [], []
59
+ for cameras in camera_list:
60
+ R_list.append(cameras.R)
61
+ T_list.append(cameras.T)
62
+ f_list.append(cameras.focal_length)
63
+ c_list.append(cameras.principal_point)
64
+ size_list.append(cameras.image_size)
65
+
66
+ camera_slice = PerspectiveCameras(
67
+ R = torch.cat(R_list),
68
+ T = torch.cat(T_list),
69
+ focal_length = torch.cat(f_list),
70
+ principal_point = torch.cat(c_list),
71
+ image_size = torch.cat(size_list),
72
+ device = camera_list[0].device,
73
+ )
74
+ return camera_slice
75
+
76
+ def get_camera_slice(self, scene_cameras, indices):
77
+ '''
78
+ Return a subset of cameras from a super set given indices
79
+
80
+ Args:
81
+ scene_cameras (PyTorch3D Camera): cameras object
82
+ indices (tensor or List): a flat list or tensor of indices
83
+
84
+ Returns:
85
+ camera_slice (PyTorch3D Camera) - cameras subset
86
+ '''
87
+ camera_slice = PerspectiveCameras(
88
+ R = scene_cameras.R[indices],
89
+ T = scene_cameras.T[indices],
90
+ focal_length = scene_cameras.focal_length[indices],
91
+ principal_point = scene_cameras.principal_point[indices],
92
+ image_size = scene_cameras.image_size[indices],
93
+ device = scene_cameras.device,
94
+ )
95
+ return camera_slice
96
+
97
+
98
+ def get_relative_camera(self, scene_cameras:PerspectiveCameras, query_idx, center_at_origin=False):
99
+ """
100
+ Transform context cameras relative to a base query camera
101
+
102
+ Args:
103
+ scene_cameras (PyTorch3D Camera): cameras object
104
+ query_idx (tensor or List): a length 1 list defining query idx
105
+
106
+ Returns:
107
+ cams_relative (PyTorch3D Camera): cameras object relative to query camera
108
+ """
109
+
110
+ query_camera = self.get_camera_slice(scene_cameras, query_idx)
111
+ query_world2view = query_camera.get_world_to_view_transform()
112
+ all_world2view = scene_cameras.get_world_to_view_transform()
113
+
114
+ if center_at_origin:
115
+ identity_cam = PerspectiveCameras(device=scene_cameras.device, R=query_camera.R, T=query_camera.T)
116
+ else:
117
+ T = torch.zeros((1, 3))
118
+ identity_cam = PerspectiveCameras(device=scene_cameras.device, R=query_camera.R, T=T)
119
+
120
+ identity_world2view = identity_cam.get_world_to_view_transform()
121
+
122
+ # compose the relative transformation as g_i^{-1} g_j
123
+ relative_world2view = identity_world2view.inverse().compose(all_world2view)
124
+
125
+ # generate a camera from the relative transform
126
+ relative_matrix = relative_world2view.get_matrix()
127
+ cams_relative = PerspectiveCameras(
128
+ R = relative_matrix[:, :3, :3],
129
+ T = relative_matrix[:, 3, :3],
130
+ focal_length = scene_cameras.focal_length,
131
+ principal_point = scene_cameras.principal_point,
132
+ image_size = scene_cameras.image_size,
133
+ device = scene_cameras.device,
134
+ )
135
+ return cams_relative
136
+
137
+ def forward(self, scene_cameras, scene_rgb=None, scene_masks=None, query_idx=None, context_size=3, context_idx=None, return_context=False):
138
+ '''
139
+ Return a sampled batch of query and context cameras (used in training)
140
+
141
+ Args:
142
+ scene_cameras (PyTorch3D Camera): a batch of PyTorch3D cameras
143
+ scene_rgb (Tensor): a batch of rgb
144
+ scene_masks (Tensor): a batch of masks (optional)
145
+ query_idx (List or Tensor): desired query idx (optional)
146
+ context_size (int): number of views for context
147
+
148
+ Returns:
149
+ query_cameras, query_rgb, query_masks: random query view
150
+ context_cameras, context_rgb, context_masks: context views
151
+ '''
152
+
153
+ if query_idx is None:
154
+ query_idx = [0]
155
+ if self.rand_query:
156
+ rand = torch.randperm(len(scene_cameras))
157
+ query_idx = rand[:1]
158
+
159
+ if context_idx is None:
160
+ rand = torch.randperm(len(scene_cameras))
161
+ context_idx = rand[:context_size]
162
+
163
+
164
+ if self.relative:
165
+ rel_cameras = self.get_relative_camera(scene_cameras, query_idx, center_at_origin=self.center_at_origin)
166
+ else:
167
+ rel_cameras = scene_cameras
168
+
169
+ query_cameras = self.get_camera_slice(rel_cameras, query_idx)
170
+ query_rgb = None
171
+ if scene_rgb is not None:
172
+ query_rgb = scene_rgb[query_idx]
173
+ query_masks = None
174
+ if scene_masks is not None:
175
+ query_masks = scene_masks[query_idx]
176
+
177
+ context_cameras = self.get_camera_slice(rel_cameras, context_idx)
178
+ context_rgb = None
179
+ if scene_rgb is not None:
180
+ context_rgb = scene_rgb[context_idx]
181
+ context_masks = None
182
+ if scene_masks is not None:
183
+ context_masks = scene_masks[context_idx]
184
+
185
+ if return_context:
186
+ return query_cameras, query_rgb, query_masks, context_cameras, context_rgb, context_masks, context_idx
187
+ return query_cameras, query_rgb, query_masks, context_cameras, context_rgb, context_masks
188
+
189
+
190
+ def get_interpolated_path(cameras: PerspectiveCameras, n=50, method='circle', theta_offset_max=0.0):
191
+ '''
192
+ Given a camera object containing a set of cameras, fit a circle and get
193
+ interpolated cameras
194
+
195
+ Args:
196
+ cameras (PyTorch3D Camera): input camera object
197
+ n (int): length of cameras in new path
198
+ method (str): 'circle'
199
+ theta_offset_max (int): max camera jitter in radians
200
+
201
+ Returns:
202
+ path_cameras (PyTorch3D Camera): interpolated cameras
203
+ '''
204
+ device = cameras.device
205
+ cameras = cameras.cpu()
206
+
207
+ if method == 'circle':
208
+
209
+ #@ https://meshlogic.github.io/posts/jupyter/curve-fitting/fitting-a-circle-to-cluster-of-3d-points/
210
+ #@ Fit plane
211
+ P = cameras.get_camera_center().cpu()
212
+ P_mean = P.mean(axis=0)
213
+ P_centered = P - P_mean
214
+ U,s,V = torch.linalg.svd(P_centered)
215
+ normal = V[2,:]
216
+ if (normal*2 - P_mean).norm() < (normal - P_mean).norm():
217
+ normal = - normal
218
+ d = -torch.dot(P_mean, normal) # d = -<p,n>
219
+
220
+ #@ Project pts to plane
221
+ P_xy = rodrigues_rot(P_centered, normal, torch.tensor([0.0,0.0,1.0]))
222
+
223
+ #@ Fit circle in 2D
224
+ xc, yc, r = fit_circle_2d(P_xy[:,0], P_xy[:,1])
225
+ t = torch.linspace(0, 2*math.pi, 100)
226
+ xx = xc + r*torch.cos(t)
227
+ yy = yc + r*torch.sin(t)
228
+
229
+ #@ Project circle to 3D
230
+ C = rodrigues_rot(torch.tensor([xc,yc,0.0]), torch.tensor([0.0,0.0,1.0]), normal) + P_mean
231
+ C = C.flatten()
232
+
233
+ #@ Get pts in 3D
234
+ t = torch.linspace(0, 2*math.pi, n)
235
+ u = P[0] - C
236
+ new_camera_centers = generate_circle_by_vectors(t, C, r, normal, u)
237
+
238
+ #@ OPTIONAL THETA OFFSET
239
+ if theta_offset_max > 0.0:
240
+ aug_theta = (torch.rand((new_camera_centers.shape[0])) * (2*theta_offset_max)) - theta_offset_max
241
+ new_camera_centers = rodrigues_rot2(new_camera_centers, normal, aug_theta)
242
+
243
+ #@ Get camera look at
244
+ new_camera_look_at = get_nearest_centroid(cameras)
245
+
246
+ #@ Get R T
247
+ up_vec = -normal
248
+ R, T = look_at_view_transform(eye=new_camera_centers, at=new_camera_look_at.unsqueeze(0), up=up_vec.unsqueeze(0), device=cameras.device)
249
+ else:
250
+ raise NotImplementedError
251
+
252
+ c = (cameras.principal_point).mean(dim=0, keepdim=True).expand(R.shape[0],-1)
253
+ f = (cameras.focal_length).mean(dim=0, keepdim=True).expand(R.shape[0],-1)
254
+ image_size = cameras.image_size[:1].expand(R.shape[0],-1)
255
+
256
+
257
+ path_cameras = PerspectiveCameras(R=R,T=T,focal_length=f,principal_point=c,image_size=image_size, device=device)
258
+ cameras = cameras.to(device)
259
+ return path_cameras
260
+
261
+ def np_normalize(vec, axis=-1):
262
+ vec = vec / (np.linalg.norm(vec, axis=axis, keepdims=True) + 1e-9)
263
+ return vec
264
+
265
+
266
+ #@ https://meshlogic.github.io/posts/jupyter/curve-fitting/fitting-a-circle-to-cluster-of-3d-points/
267
+ #-------------------------------------------------------------------------------
268
+ # Generate points on circle
269
+ # P(t) = r*cos(t)*u + r*sin(t)*(n x u) + C
270
+ #-------------------------------------------------------------------------------
271
+ def generate_circle_by_vectors(t, C, r, n, u):
272
+ n = n/torch.linalg.norm(n)
273
+ u = u/torch.linalg.norm(u)
274
+ P_circle = r*torch.cos(t)[:,None]*u + r*torch.sin(t)[:,None]*torch.cross(n,u) + C
275
+ return P_circle
276
+
277
+ #@ https://meshlogic.github.io/posts/jupyter/curve-fitting/fitting-a-circle-to-cluster-of-3d-points/
278
+ #-------------------------------------------------------------------------------
279
+ # FIT CIRCLE 2D
280
+ # - Find center [xc, yc] and radius r of circle fitting to set of 2D points
281
+ # - Optionally specify weights for points
282
+ #
283
+ # - Implicit circle function:
284
+ # (x-xc)^2 + (y-yc)^2 = r^2
285
+ # (2*xc)*x + (2*yc)*y + (r^2-xc^2-yc^2) = x^2+y^2
286
+ # c[0]*x + c[1]*y + c[2] = x^2+y^2
287
+ #
288
+ # - Solution by method of least squares:
289
+ # A*c = b, c' = argmin(||A*c - b||^2)
290
+ # A = [x y 1], b = [x^2+y^2]
291
+ #-------------------------------------------------------------------------------
292
+ def fit_circle_2d(x, y, w=[]):
293
+
294
+ A = torch.stack([x, y, torch.ones(len(x))]).T
295
+ b = x**2 + y**2
296
+
297
+ # Modify A,b for weighted least squares
298
+ if len(w) == len(x):
299
+ W = torch.diag(w)
300
+ A = torch.dot(W,A)
301
+ b = torch.dot(W,b)
302
+
303
+ # Solve by method of least squares
304
+ c = torch.linalg.lstsq(A,b,rcond=None)[0]
305
+
306
+ # Get circle parameters from solution c
307
+ xc = c[0]/2
308
+ yc = c[1]/2
309
+ r = torch.sqrt(c[2] + xc**2 + yc**2)
310
+ return xc, yc, r
311
+
312
+ #@ https://meshlogic.github.io/posts/jupyter/curve-fitting/fitting-a-circle-to-cluster-of-3d-points/
313
+ #-------------------------------------------------------------------------------
314
+ # RODRIGUES ROTATION
315
+ # - Rotate given points based on a starting and ending vector
316
+ # - Axis k and angle of rotation theta given by vectors n0,n1
317
+ # P_rot = P*cos(theta) + (k x P)*sin(theta) + k*<k,P>*(1-cos(theta))
318
+ #-------------------------------------------------------------------------------
319
+ def rodrigues_rot(P, n0, n1):
320
+
321
+ # If P is only 1d array (coords of single point), fix it to be matrix
322
+ if P.ndim == 1:
323
+ P = P[None,...]
324
+
325
+ # Get vector of rotation k and angle theta
326
+ n0 = n0/torch.linalg.norm(n0)
327
+ n1 = n1/torch.linalg.norm(n1)
328
+ k = torch.cross(n0,n1)
329
+ k = k/torch.linalg.norm(k)
330
+ theta = torch.arccos(torch.dot(n0,n1))
331
+
332
+ # Compute rotated points
333
+ P_rot = torch.zeros((len(P),3))
334
+ for i in range(len(P)):
335
+ P_rot[i] = P[i]*torch.cos(theta) + torch.cross(k,P[i])*torch.sin(theta) + k*torch.dot(k,P[i])*(1-torch.cos(theta))
336
+
337
+ return P_rot
338
+
339
+ def rodrigues_rot2(P, n1, theta):
340
+ '''
341
+ Rotate points P wrt axis k by theta radians
342
+ '''
343
+
344
+ # If P is only 1d array (coords of single point), fix it to be matrix
345
+ if P.ndim == 1:
346
+ P = P[None,...]
347
+
348
+ k = torch.cross(P, n1.unsqueeze(0))
349
+ k = k/torch.linalg.norm(k)
350
+
351
+ # Compute rotated points
352
+ P_rot = torch.zeros((len(P),3))
353
+ for i in range(len(P)):
354
+ P_rot[i] = P[i]*torch.cos(theta[i]) + torch.cross(k[i],P[i])*torch.sin(theta[i]) + k[i]*torch.dot(k[i],P[i])*(1-torch.cos(theta[i]))
355
+
356
+ return P_rot
357
+
358
+ #@ https://meshlogic.github.io/posts/jupyter/curve-fitting/fitting-a-circle-to-cluster-of-3d-points/
359
+ #-------------------------------------------------------------------------------
360
+ # ANGLE BETWEEN
361
+ # - Get angle between vectors u,v with sign based on plane with unit normal n
362
+ #-------------------------------------------------------------------------------
363
+ def angle_between(u, v, n=None):
364
+ if n is None:
365
+ return torch.arctan2(torch.linalg.norm(torch.cross(u,v)), torch.dot(u,v))
366
+ else:
367
+ return torch.arctan2(torch.dot(n,torch.cross(u,v)), torch.dot(u,v))
368
+
369
+ #@ https://www.crewes.org/Documents/ResearchReports/2010/CRR201032.pdf
370
+ def get_nearest_centroid(cameras: PerspectiveCameras):
371
+ '''
372
+ Given PyTorch3D cameras, find the nearest point along their principal ray
373
+ '''
374
+
375
+ #@ GET CAMERA CENTERS AND DIRECTIONS
376
+ camera_centers = cameras.get_camera_center()
377
+
378
+ c_mean = (cameras.principal_point).mean(dim=0)
379
+ xy_grid = c_mean.unsqueeze(0).unsqueeze(0)
380
+ ray_vis = _xy_to_ray_bundle(cameras, xy_grid.expand(len(cameras),-1,-1), 1.0, 15.0, 20, True)
381
+ camera_directions = ray_vis.directions
382
+
383
+ #@ CONSTRUCT MATRICES
384
+ A = torch.zeros((3*len(cameras)), len(cameras)+3)
385
+ b = torch.zeros((3*len(cameras), 1))
386
+ A[:,:3] = torch.eye(3).repeat(len(cameras),1)
387
+ for ci in range(len(camera_directions)):
388
+ A[3*ci:3*ci+3, ci+3] = -camera_directions[ci]
389
+ b[3*ci:3*ci+3, 0] = camera_centers[ci]
390
+ #' A (3*N, 3*N+3) b (3*N, 1)
391
+
392
+ #@ SVD
393
+ U, s, VT = torch.linalg.svd(A)
394
+ Sinv = torch.diag(1/s)
395
+ if len(s) < 3*len(cameras):
396
+ Sinv = torch.cat((Sinv, torch.zeros((Sinv.shape[0], 3*len(cameras) - Sinv.shape[1]), device=Sinv.device)), dim=1)
397
+ x = torch.matmul(VT.T, torch.matmul(Sinv,torch.matmul(U.T, b)))
398
+
399
+ centroid = x[:3,0]
400
+ return centroid
401
+
402
+
403
+ def get_angles(target_camera: PerspectiveCameras, context_cameras: PerspectiveCameras, centroid=None):
404
+ '''
405
+ Get angles between cameras wrt a centroid
406
+
407
+ Args:
408
+ target_camera (Pytorch3D Camera): a camera object with a single camera
409
+ context_cameras (PyTorch3D Camera): a camera object
410
+
411
+ Returns:
412
+ theta_deg (Tensor): a tensor containing angles in degrees
413
+ '''
414
+ a1 = target_camera.get_camera_center()
415
+ b1 = context_cameras.get_camera_center()
416
+
417
+ a = a1 - centroid.unsqueeze(0)
418
+ a = a.expand(len(context_cameras), -1)
419
+ b = b1 - centroid.unsqueeze(0)
420
+
421
+ ab_dot = (a*b).sum(dim=-1)
422
+ theta = torch.acos((ab_dot)/(torch.linalg.norm(a, dim=-1) * torch.linalg.norm(b, dim=-1)))
423
+ theta_deg = theta * 180 / math.pi
424
+
425
+ return theta_deg
426
+
427
+
428
+ import math
429
+ from typing import List, Literal, Optional, Tuple
430
+
431
+ import numpy as np
432
+ import torch
433
+ from jaxtyping import Float
434
+ from numpy.typing import NDArray
435
+ from torch import Tensor
436
+
437
+ _EPS = np.finfo(float).eps * 4.0
438
+
439
+
440
+ def unit_vector(data: NDArray, axis: Optional[int] = None) -> np.ndarray:
441
+ """Return ndarray normalized by length, i.e. Euclidean norm, along axis.
442
+
443
+ Args:
444
+ axis: the axis along which to normalize into unit vector
445
+ out: where to write out the data to. If None, returns a new np ndarray
446
+ """
447
+ data = np.array(data, dtype=np.float64, copy=True)
448
+ if data.ndim == 1:
449
+ data /= math.sqrt(np.dot(data, data))
450
+ return data
451
+ length = np.atleast_1d(np.sum(data * data, axis))
452
+ np.sqrt(length, length)
453
+ if axis is not None:
454
+ length = np.expand_dims(length, axis)
455
+ data /= length
456
+ return data
457
+
458
+
459
+ def quaternion_from_matrix(matrix: NDArray, isprecise: bool = False) -> np.ndarray:
460
+ """Return quaternion from rotation matrix.
461
+
462
+ Args:
463
+ matrix: rotation matrix to obtain quaternion
464
+ isprecise: if True, input matrix is assumed to be precise rotation matrix and a faster algorithm is used.
465
+ """
466
+ M = np.array(matrix, dtype=np.float64, copy=False)[:4, :4]
467
+ if isprecise:
468
+ q = np.empty((4,))
469
+ t = np.trace(M)
470
+ if t > M[3, 3]:
471
+ q[0] = t
472
+ q[3] = M[1, 0] - M[0, 1]
473
+ q[2] = M[0, 2] - M[2, 0]
474
+ q[1] = M[2, 1] - M[1, 2]
475
+ else:
476
+ i, j, k = 1, 2, 3
477
+ if M[1, 1] > M[0, 0]:
478
+ i, j, k = 2, 3, 1
479
+ if M[2, 2] > M[i, i]:
480
+ i, j, k = 3, 1, 2
481
+ t = M[i, i] - (M[j, j] + M[k, k]) + M[3, 3]
482
+ q[i] = t
483
+ q[j] = M[i, j] + M[j, i]
484
+ q[k] = M[k, i] + M[i, k]
485
+ q[3] = M[k, j] - M[j, k]
486
+ q *= 0.5 / math.sqrt(t * M[3, 3])
487
+ else:
488
+ m00 = M[0, 0]
489
+ m01 = M[0, 1]
490
+ m02 = M[0, 2]
491
+ m10 = M[1, 0]
492
+ m11 = M[1, 1]
493
+ m12 = M[1, 2]
494
+ m20 = M[2, 0]
495
+ m21 = M[2, 1]
496
+ m22 = M[2, 2]
497
+ # symmetric matrix K
498
+ K = [
499
+ [m00 - m11 - m22, 0.0, 0.0, 0.0],
500
+ [m01 + m10, m11 - m00 - m22, 0.0, 0.0],
501
+ [m02 + m20, m12 + m21, m22 - m00 - m11, 0.0],
502
+ [m21 - m12, m02 - m20, m10 - m01, m00 + m11 + m22],
503
+ ]
504
+ K = np.array(K)
505
+ K /= 3.0
506
+ # quaternion is eigenvector of K that corresponds to largest eigenvalue
507
+ w, V = np.linalg.eigh(K)
508
+ q = V[np.array([3, 0, 1, 2]), np.argmax(w)]
509
+ if q[0] < 0.0:
510
+ np.negative(q, q)
511
+ return q
512
+
513
+
514
+ def quaternion_slerp(
515
+ quat0: NDArray, quat1: NDArray, fraction: float, spin: int = 0, shortestpath: bool = True
516
+ ) -> np.ndarray:
517
+ """Return spherical linear interpolation between two quaternions.
518
+ Args:
519
+ quat0: first quaternion
520
+ quat1: second quaternion
521
+ fraction: how much to interpolate between quat0 vs quat1 (if 0, closer to quat0; if 1, closer to quat1)
522
+ spin: how much of an additional spin to place on the interpolation
523
+ shortestpath: whether to return the short or long path to rotation
524
+ """
525
+ q0 = unit_vector(quat0[:4])
526
+ q1 = unit_vector(quat1[:4])
527
+ if q0 is None or q1 is None:
528
+ raise ValueError("Input quaternions invalid.")
529
+ if fraction == 0.0:
530
+ return q0
531
+ if fraction == 1.0:
532
+ return q1
533
+ d = np.dot(q0, q1)
534
+ if abs(abs(d) - 1.0) < _EPS:
535
+ return q0
536
+ if shortestpath and d < 0.0:
537
+ # invert rotation
538
+ d = -d
539
+ np.negative(q1, q1)
540
+ angle = math.acos(d) + spin * math.pi
541
+ if abs(angle) < _EPS:
542
+ return q0
543
+ isin = 1.0 / math.sin(angle)
544
+ q0 *= math.sin((1.0 - fraction) * angle) * isin
545
+ q1 *= math.sin(fraction * angle) * isin
546
+ q0 += q1
547
+ return q0
548
+
549
+
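A small numeric check of quaternion_slerp (not part of the commit), assuming the [w, x, y, z] ordering returned by quaternion_from_matrix above: halfway between the identity and a 90-degree z-rotation should be a 45-degree z-rotation.

import math
import numpy as np

q_id = np.array([1.0, 0.0, 0.0, 0.0])                                        # identity, [w, x, y, z]
q_z90 = np.array([math.cos(math.pi / 4), 0.0, 0.0, math.sin(math.pi / 4)])   # 90 deg about z
q_mid = quaternion_slerp(q_id, q_z90, 0.5)
expected = np.array([math.cos(math.pi / 8), 0.0, 0.0, math.sin(math.pi / 8)])
assert np.allclose(q_mid, expected, atol=1e-6)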
550
+ def quaternion_matrix(quaternion: NDArray) -> np.ndarray:
551
+ """Return homogeneous rotation matrix from quaternion.
552
+
553
+ Args:
554
+ quaternion: value to convert to matrix
555
+ """
556
+ q = np.array(quaternion, dtype=np.float64, copy=True)
557
+ n = np.dot(q, q)
558
+ if n < _EPS:
559
+ return np.identity(4)
560
+ q *= math.sqrt(2.0 / n)
561
+ q = np.outer(q, q)
562
+ return np.array(
563
+ [
564
+ [1.0 - q[2, 2] - q[3, 3], q[1, 2] - q[3, 0], q[1, 3] + q[2, 0], 0.0],
565
+ [q[1, 2] + q[3, 0], 1.0 - q[1, 1] - q[3, 3], q[2, 3] - q[1, 0], 0.0],
566
+ [q[1, 3] - q[2, 0], q[2, 3] + q[1, 0], 1.0 - q[1, 1] - q[2, 2], 0.0],
567
+ [0.0, 0.0, 0.0, 1.0],
568
+ ]
569
+ )
570
+
571
+
572
+ def get_interpolated_poses(pose_a: NDArray, pose_b: NDArray, steps: int = 10) -> List[np.ndarray]:
573
+ """Return interpolation of poses with specified number of steps.
574
+ Args:
575
+ pose_a: first pose
576
+ pose_b: second pose
577
+ steps: number of steps the interpolated pose path should contain
578
+ """
579
+
580
+ quat_a = quaternion_from_matrix(pose_a[:3, :3])
581
+ quat_b = quaternion_from_matrix(pose_b[:3, :3])
582
+
583
+ ts = np.linspace(0, 1, steps)
584
+ quats = [quaternion_slerp(quat_a, quat_b, t) for t in ts]
585
+ trans = [(1 - t) * pose_a[:3, 3] + t * pose_b[:3, 3] for t in ts]
586
+
587
+ poses_ab = []
588
+ for quat, tran in zip(quats, trans):
589
+ pose = np.identity(4)
590
+ pose[:3, :3] = quaternion_matrix(quat)[:3, :3]
591
+ pose[:3, 3] = tran
592
+ poses_ab.append(pose[:3])
593
+ return poses_ab
594
+
595
+
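Hypothetical usage of get_interpolated_poses (not part of the commit; the poses are made up): interpolate five poses between the identity and a pose rotated 90 degrees about z and shifted one unit along x.

import math
import numpy as np

pose_a = np.eye(4)[:3]                          # 3x4 [R | t]
pose_b = np.eye(4)[:3].copy()
c, s = math.cos(math.pi / 2), math.sin(math.pi / 2)
pose_b[:3, :3] = np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])
pose_b[:3, 3] = np.array([1.0, 0.0, 0.0])
path = get_interpolated_poses(pose_a, pose_b, steps=5)
assert len(path) == 5 and path[0].shape == (3, 4)   # each entry is a 3x4 pose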
596
+ def get_interpolated_k(
597
+ k_a: Float[Tensor, "3 3"], k_b: Float[Tensor, "3 3"], steps: int = 10
598
+ ) -> List[Float[Tensor, "3 4"]]:
599
+ """
600
+ Returns interpolated intrinsic matrices between two cameras over the specified number of steps.
601
+
602
+ Args:
603
+ k_a: camera matrix 1
604
+ k_b: camera matrix 2
605
+ steps: number of steps the interpolated pose path should contain
606
+
607
+ Returns:
608
+ List of interpolated camera intrinsics
609
+ """
610
+ Ks: List[Float[Tensor, "3 3"]] = []
611
+ ts = np.linspace(0, 1, steps)
612
+ for t in ts:
613
+ new_k = k_a * (1.0 - t) + k_b * t
614
+ Ks.append(new_k)
615
+ return Ks
616
+
617
+
618
+ def get_ordered_poses_and_k(
619
+ poses: Float[Tensor, "num_poses 3 4"],
620
+ Ks: Float[Tensor, "num_poses 3 3"],
621
+ ) -> Tuple[Float[Tensor, "num_poses 3 4"], Float[Tensor, "num_poses 3 3"]]:
622
+ """
623
+ Returns ordered poses and intrinsics by Euclidean distance between poses.
624
+
625
+ Args:
626
+ poses: list of camera poses
627
+ Ks: list of camera intrinsics
628
+
629
+ Returns:
630
+ tuple of ordered poses and intrinsics
631
+
632
+ """
633
+
634
+ poses_num = len(poses)
635
+
636
+ ordered_poses = torch.unsqueeze(poses[0], 0)
637
+ ordered_ks = torch.unsqueeze(Ks[0], 0)
638
+
639
+ # remove the first pose from poses
640
+ poses = poses[1:]
641
+ Ks = Ks[1:]
642
+
643
+ for _ in range(poses_num - 1):
644
+ distances = torch.norm(ordered_poses[-1][:, 3] - poses[:, :, 3], dim=1)
645
+ idx = torch.argmin(distances)
646
+ ordered_poses = torch.cat((ordered_poses, torch.unsqueeze(poses[idx], 0)), dim=0)
647
+ ordered_ks = torch.cat((ordered_ks, torch.unsqueeze(Ks[idx], 0)), dim=0)
648
+ poses = torch.cat((poses[0:idx], poses[idx + 1 :]), dim=0)
649
+ Ks = torch.cat((Ks[0:idx], Ks[idx + 1 :]), dim=0)
650
+
651
+ return ordered_poses, ordered_ks
652
+
653
+
654
+ def get_interpolated_poses_many(
655
+ poses: Float[Tensor, "num_poses 3 4"],
656
+ Ks: Float[Tensor, "num_poses 3 3"],
657
+ steps_per_transition: int = 10,
658
+ order_poses: bool = False,
659
+ ) -> Tuple[Float[Tensor, "num_poses 3 4"], Float[Tensor, "num_poses 3 3"]]:
660
+ """Return interpolated poses for many camera poses.
661
+
662
+ Args:
663
+ poses: list of camera poses
664
+ Ks: list of camera intrinsics
665
+ steps_per_transition: number of steps per transition
666
+ order_poses: whether to order poses by Euclidean distance
667
+
668
+ Returns:
669
+ tuple of new poses and intrinsics
670
+ """
671
+ traj = []
672
+ k_interp = []
673
+
674
+ if order_poses:
675
+ poses, Ks = get_ordered_poses_and_k(poses, Ks)
676
+
677
+ for idx in range(poses.shape[0] - 1):
678
+ pose_a = poses[idx].cpu().numpy()
679
+ pose_b = poses[idx + 1].cpu().numpy()
680
+ poses_ab = get_interpolated_poses(pose_a, pose_b, steps=steps_per_transition)
681
+ traj += poses_ab
682
+ k_interp += get_interpolated_k(Ks[idx], Ks[idx + 1], steps=steps_per_transition)
683
+
684
+ traj = np.stack(traj, axis=0)
685
+ k_interp = torch.stack(k_interp, dim=0)
686
+
687
+ return torch.tensor(traj, dtype=torch.float32), torch.tensor(k_interp, dtype=torch.float32)
688
+
689
+
690
+ def normalize(x: torch.Tensor) -> Float[Tensor, "*batch"]:
691
+ """Returns a normalized vector."""
692
+ return x / torch.linalg.norm(x)
693
+
694
+
695
+ def normalize_with_norm(x: torch.Tensor, dim: int) -> Tuple[torch.Tensor, torch.Tensor]:
696
+ """Normalize tensor along axis and return normalized value with norms.
697
+
698
+ Args:
699
+ x: tensor to normalize.
700
+ dim: axis along which to normalize.
701
+
702
+ Returns:
703
+ Tuple of normalized tensor and corresponding norm.
704
+ """
705
+
706
+ norm = torch.maximum(torch.linalg.vector_norm(x, dim=dim, keepdims=True), torch.tensor([_EPS]).to(x))
707
+ return x / norm, norm
708
+
709
+
710
+ def viewmatrix(lookat: torch.Tensor, up: torch.Tensor, pos: torch.Tensor) -> Float[Tensor, "*batch"]:
711
+ """Returns a camera transformation matrix.
712
+
713
+ Args:
714
+ lookat: The direction the camera is looking.
715
+ up: The upward direction of the camera.
716
+ pos: The position of the camera.
717
+
718
+ Returns:
719
+ A camera transformation matrix.
720
+ """
721
+ vec2 = normalize(lookat)
722
+ vec1_avg = normalize(up)
723
+ vec0 = normalize(torch.cross(vec1_avg, vec2))
724
+ vec1 = normalize(torch.cross(vec2, vec0))
725
+ m = torch.stack([vec0, vec1, vec2, pos], 1)
726
+ return m
727
+
728
+
729
+ def get_distortion_params(
730
+ k1: float = 0.0,
731
+ k2: float = 0.0,
732
+ k3: float = 0.0,
733
+ k4: float = 0.0,
734
+ p1: float = 0.0,
735
+ p2: float = 0.0,
736
+ ) -> Float[Tensor, "*batch"]:
737
+ """Returns a distortion parameters matrix.
738
+
739
+ Args:
740
+ k1: The first radial distortion parameter.
741
+ k2: The second radial distortion parameter.
742
+ k3: The third radial distortion parameter.
743
+ k4: The fourth radial distortion parameter.
744
+ p1: The first tangential distortion parameter.
745
+ p2: The second tangential distortion parameter.
746
+ Returns:
747
+ torch.Tensor: A distortion parameters matrix.
748
+ """
749
+ return torch.Tensor([k1, k2, k3, k4, p1, p2])
750
+
751
+
752
+ def _compute_residual_and_jacobian(
753
+ x: torch.Tensor,
754
+ y: torch.Tensor,
755
+ xd: torch.Tensor,
756
+ yd: torch.Tensor,
757
+ distortion_params: torch.Tensor,
758
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
759
+ """Auxiliary function of radial_and_tangential_undistort() that computes residuals and jacobians.
760
+ Adapted from MultiNeRF:
761
+ https://github.com/google-research/multinerf/blob/b02228160d3179300c7d499dca28cb9ca3677f32/internal/camera_utils.py#L427-L474
762
+
763
+ Args:
764
+ x: The updated x coordinates.
765
+ y: The updated y coordinates.
766
+ xd: The distorted x coordinates.
767
+ yd: The distorted y coordinates.
768
+ distortion_params: The distortion parameters [k1, k2, k3, k4, p1, p2].
769
+
770
+ Returns:
771
+ The residuals (fx, fy) and jacobians (fx_x, fx_y, fy_x, fy_y).
772
+ """
773
+
774
+ k1 = distortion_params[..., 0]
775
+ k2 = distortion_params[..., 1]
776
+ k3 = distortion_params[..., 2]
777
+ k4 = distortion_params[..., 3]
778
+ p1 = distortion_params[..., 4]
779
+ p2 = distortion_params[..., 5]
780
+
781
+ # let r(x, y) = x^2 + y^2;
782
+ # d(x, y) = 1 + k1 * r(x, y) + k2 * r(x, y) ^2 + k3 * r(x, y)^3 +
783
+ # k4 * r(x, y)^4;
784
+ r = x * x + y * y
785
+ d = 1.0 + r * (k1 + r * (k2 + r * (k3 + r * k4)))
786
+
787
+ # The perfect projection is:
788
+ # xd = x * d(x, y) + 2 * p1 * x * y + p2 * (r(x, y) + 2 * x^2);
789
+ # yd = y * d(x, y) + 2 * p2 * x * y + p1 * (r(x, y) + 2 * y^2);
790
+ #
791
+ # Let's define
792
+ #
793
+ # fx(x, y) = x * d(x, y) + 2 * p1 * x * y + p2 * (r(x, y) + 2 * x^2) - xd;
794
+ # fy(x, y) = y * d(x, y) + 2 * p2 * x * y + p1 * (r(x, y) + 2 * y^2) - yd;
795
+ #
796
+ # We are looking for a solution that satisfies
797
+ # fx(x, y) = fy(x, y) = 0;
798
+ fx = d * x + 2 * p1 * x * y + p2 * (r + 2 * x * x) - xd
799
+ fy = d * y + 2 * p2 * x * y + p1 * (r + 2 * y * y) - yd
800
+
801
+ # Compute derivative of d over [x, y]
802
+ d_r = k1 + r * (2.0 * k2 + r * (3.0 * k3 + r * 4.0 * k4))
803
+ d_x = 2.0 * x * d_r
804
+ d_y = 2.0 * y * d_r
805
+
806
+ # Compute derivative of fx over x and y.
807
+ fx_x = d + d_x * x + 2.0 * p1 * y + 6.0 * p2 * x
808
+ fx_y = d_y * x + 2.0 * p1 * x + 2.0 * p2 * y
809
+
810
+ # Compute derivative of fy over x and y.
811
+ fy_x = d_x * y + 2.0 * p2 * y + 2.0 * p1 * x
812
+ fy_y = d + d_y * y + 2.0 * p2 * x + 6.0 * p1 * y
813
+
814
+ return fx, fy, fx_x, fx_y, fy_x, fy_y
815
+
816
+
817
+ # @torch_compile(dynamic=True, mode="reduce-overhead", backend="eager")
818
+ def radial_and_tangential_undistort(
819
+ coords: torch.Tensor,
820
+ distortion_params: torch.Tensor,
821
+ eps: float = 1e-3,
822
+ max_iterations: int = 10,
823
+ ) -> torch.Tensor:
824
+ """Computes undistorted coords given opencv distortion parameters.
825
+ Adapted from MultiNeRF
826
+ https://github.com/google-research/multinerf/blob/b02228160d3179300c7d499dca28cb9ca3677f32/internal/camera_utils.py#L477-L509
827
+
828
+ Args:
829
+ coords: The distorted coordinates.
830
+ distortion_params: The distortion parameters [k1, k2, k3, k4, p1, p2].
831
+ eps: The epsilon for the convergence.
832
+ max_iterations: The maximum number of iterations to perform.
833
+
834
+ Returns:
835
+ The undistorted coordinates.
836
+ """
837
+
838
+ # Initialize from the distorted point.
839
+ x = coords[..., 0]
840
+ y = coords[..., 1]
841
+
842
+ for _ in range(max_iterations):
843
+ fx, fy, fx_x, fx_y, fy_x, fy_y = _compute_residual_and_jacobian(
844
+ x=x, y=y, xd=coords[..., 0], yd=coords[..., 1], distortion_params=distortion_params
845
+ )
846
+ denominator = fy_x * fx_y - fx_x * fy_y
847
+ x_numerator = fx * fy_y - fy * fx_y
848
+ y_numerator = fy * fx_x - fx * fy_x
849
+ step_x = torch.where(torch.abs(denominator) > eps, x_numerator / denominator, torch.zeros_like(denominator))
850
+ step_y = torch.where(torch.abs(denominator) > eps, y_numerator / denominator, torch.zeros_like(denominator))
851
+
852
+ x = x + step_x
853
+ y = y + step_y
854
+
855
+ return torch.stack([x, y], dim=-1)
856
+
857
+
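A round-trip sketch for radial_and_tangential_undistort (not part of the commit; the distortion coefficients are made up): apply the forward model written out in _compute_residual_and_jacobian, then recover the original point with the Newton solver above.

import torch

params = torch.tensor([0.1, 0.01, 0.0, 0.0, 0.001, 0.001])   # [k1, k2, k3, k4, p1, p2]
x, y = torch.tensor(0.2), torch.tensor(-0.1)
r = x * x + y * y
d = 1.0 + r * (params[0] + r * (params[1] + r * (params[2] + r * params[3])))
xd = x * d + 2 * params[4] * x * y + params[5] * (r + 2 * x * x)
yd = y * d + 2 * params[5] * x * y + params[4] * (r + 2 * y * y)
undistorted = radial_and_tangential_undistort(torch.stack([xd, yd])[None], params)
assert torch.allclose(undistorted[0], torch.stack([x, y]), atol=1e-5)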
858
+ def rotation_matrix(a: Float[Tensor, "3"], b: Float[Tensor, "3"]) -> Float[Tensor, "3 3"]:
859
+ """Compute the rotation matrix that rotates vector a to vector b.
860
+
861
+ Args:
862
+ a: The vector to rotate.
863
+ b: The vector to rotate to.
864
+ Returns:
865
+ The rotation matrix.
866
+ """
867
+ a = a / torch.linalg.norm(a)
868
+ b = b / torch.linalg.norm(b)
869
+ v = torch.cross(a, b)
870
+ c = torch.dot(a, b)
871
+ # If vectors are exactly opposite, we add a little noise to one of them
872
+ if c < -1 + 1e-8:
873
+ eps = (torch.rand(3) - 0.5) * 0.01
874
+ return rotation_matrix(a + eps, b)
875
+ s = torch.linalg.norm(v)
876
+ skew_sym_mat = torch.Tensor(
877
+ [
878
+ [0, -v[2], v[1]],
879
+ [v[2], 0, -v[0]],
880
+ [-v[1], v[0], 0],
881
+ ]
882
+ )
883
+ return torch.eye(3) + skew_sym_mat + skew_sym_mat @ skew_sym_mat * ((1 - c) / (s**2 + 1e-8))
884
+
885
+
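A quick check that rotation_matrix really maps a onto b (not part of the commit):

import torch

a = torch.tensor([1.0, 0.0, 0.0])
b = torch.tensor([0.0, 1.0, 0.0])
R = rotation_matrix(a, b)
assert torch.allclose(R @ a, b, atol=1e-6)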
886
+ def focus_of_attention(poses: Float[Tensor, "*num_poses 4 4"], initial_focus: Float[Tensor, "3"]) -> Float[Tensor, "3"]:
887
+ """Compute the focus of attention of a set of cameras. Only cameras
888
+ that have the focus of attention in front of them are considered.
889
+
890
+ Args:
891
+ poses: The poses to orient.
892
+ initial_focus: The 3D point views to decide which cameras are initially activated.
893
+
894
+ Returns:
895
+ The 3D position of the focus of attention.
896
+ """
897
+ # References to the same method in third-party code:
898
+ # https://github.com/google-research/multinerf/blob/1c8b1c552133cdb2de1c1f3c871b2813f6662265/internal/camera_utils.py#L145
899
+ # https://github.com/bmild/nerf/blob/18b8aebda6700ed659cb27a0c348b737a5f6ab60/load_llff.py#L197
900
+ active_directions = -poses[:, :3, 2:3]
901
+ active_origins = poses[:, :3, 3:4]
902
+ # initial value for testing if the focus_pt is in front or behind
903
+ focus_pt = initial_focus
904
+ # Prune cameras that currently have the focus_pt behind them.
905
+ active = torch.sum(active_directions.squeeze(-1) * (focus_pt - active_origins.squeeze(-1)), dim=-1) > 0
906
+ done = False
907
+ # We need at least two active cameras, else fallback on the previous solution.
908
+ # This may be the "poses" solution if no cameras are active on first iteration, e.g.
909
+ # they are in an outward-looking configuration.
910
+ while torch.sum(active.int()) > 1 and not done:
911
+ active_directions = active_directions[active]
912
+ active_origins = active_origins[active]
913
+ # https://en.wikipedia.org/wiki/Line–line_intersection#In_more_than_two_dimensions
914
+ m = torch.eye(3) - active_directions * torch.transpose(active_directions, -2, -1)
915
+ mt_m = torch.transpose(m, -2, -1) @ m
916
+ focus_pt = torch.linalg.inv(mt_m.mean(0)) @ (mt_m @ active_origins).mean(0)[:, 0]
917
+ active = torch.sum(active_directions.squeeze(-1) * (focus_pt - active_origins.squeeze(-1)), dim=-1) > 0
918
+ if active.all():
919
+ # the set of active cameras did not change, so we're done.
920
+ done = True
921
+ return focus_pt
922
+
923
+
924
+ def auto_orient_and_center_poses(
925
+ poses: Float[Tensor, "*num_poses 4 4"],
926
+ method: Literal["pca", "up", "vertical", "none"] = "up",
927
+ center_method: Literal["poses", "focus", "none"] = "poses",
928
+ ) -> Tuple[Float[Tensor, "*num_poses 3 4"], Float[Tensor, "3 4"]]:
929
+ """Orients and centers the poses.
930
+
931
+ We provide three methods for orientation:
932
+
933
+ - pca: Orient the poses so that the principal directions of the camera centers are aligned
934
+ with the axes, Z corresponding to the smallest principal component.
935
+ This method works well when all of the cameras are in the same plane, for example when
936
+ images are taken using a mobile robot.
937
+ - up: Orient the poses so that the average up vector is aligned with the z axis.
938
+ This method works well when images are not at arbitrary angles.
939
+ - vertical: Orient the poses so that the Z 3D direction projects close to the
940
+ y axis in images. This method works better if cameras are not all
941
+ looking in the same 3D direction, which may happen in camera arrays or in LLFF.
942
+
943
+ There are two centering methods:
944
+
945
+ - poses: The poses are centered around the origin.
946
+ - focus: The origin is set to the focus of attention of all cameras (the
947
+ closest point to cameras optical axes). Recommended for inward-looking
948
+ camera configurations.
949
+
950
+ Args:
951
+ poses: The poses to orient.
952
+ method: The method to use for orientation.
953
+ center_method: The method to use to center the poses.
954
+
955
+ Returns:
956
+ Tuple of the oriented poses and the transform matrix.
957
+ """
958
+
959
+ origins = poses[..., :3, 3]
960
+
961
+ mean_origin = torch.mean(origins, dim=0)
962
+ translation_diff = origins - mean_origin
963
+
964
+ if center_method == "poses":
965
+ translation = mean_origin
966
+ elif center_method == "focus":
967
+ translation = focus_of_attention(poses, mean_origin)
968
+ elif center_method == "none":
969
+ translation = torch.zeros_like(mean_origin)
970
+ else:
971
+ raise ValueError(f"Unknown value for center_method: {center_method}")
972
+
973
+ if method == "pca":
974
+ _, eigvec = torch.linalg.eigh(translation_diff.T @ translation_diff)
975
+ eigvec = torch.flip(eigvec, dims=(-1,))
976
+
977
+ if torch.linalg.det(eigvec) < 0:
978
+ eigvec[:, 2] = -eigvec[:, 2]
979
+
980
+ transform = torch.cat([eigvec, eigvec @ -translation[..., None]], dim=-1)
981
+ oriented_poses = transform @ poses
982
+
983
+ if oriented_poses.mean(dim=0)[2, 1] < 0:
984
+ oriented_poses[:, 1:3] = -1 * oriented_poses[:, 1:3]
985
+ elif method in ("up", "vertical"):
986
+ up = torch.mean(poses[:, :3, 1], dim=0)
987
+ up = up / torch.linalg.norm(up)
988
+ if method == "vertical":
989
+ # If cameras are not all parallel (e.g. not in an LLFF configuration),
990
+ # we can find the 3D direction that most projects vertically in all
991
+ # cameras by minimizing ||Xu|| s.t. ||u||=1. This total least squares
992
+ # problem is solved by SVD.
993
+ x_axis_matrix = poses[:, :3, 0]
994
+ _, S, Vh = torch.linalg.svd(x_axis_matrix, full_matrices=False)
995
+ # Singular values are S_i=||Xv_i|| for each right singular vector v_i.
996
+ # ||S|| = sqrt(n) because lines of X are all unit vectors and the v_i
997
+ # are an orthonormal basis.
998
+ # ||Xv_i|| = sqrt(sum(dot(x_axis_j,v_i)^2)), thus S_i/sqrt(n) is the
999
+ # RMS of cosines between x axes and v_i. If the second smallest singular
1000
+ # value corresponds to an angle error less than 10° (cos(80°)=0.17),
1001
+ # this is probably a degenerate camera configuration (typical values
1002
+ # are around 5° average error for the true vertical). In this case,
1003
+ # rather than taking the vector corresponding to the smallest singular
1004
+ # value, we project the "up" vector on the plane spanned by the two
1005
+ # best singular vectors. We could also just fallback to the "up"
1006
+ # solution.
1007
+ if S[1] > 0.17 * math.sqrt(poses.shape[0]):
1008
+ # regular non-degenerate configuration
1009
+ up_vertical = Vh[2, :]
1010
+ # It may be pointing up or down. Use "up" to disambiguate the sign.
1011
+ up = up_vertical if torch.dot(up_vertical, up) > 0 else -up_vertical
1012
+ else:
1013
+ # Degenerate configuration: project "up" on the plane spanned by
1014
+ # the last two right singular vectors (which are orthogonal to the
1015
+ # first). v_0 is a unit vector, no need to divide by its norm when
1016
+ # projecting.
1017
+ up = up - Vh[0, :] * torch.dot(up, Vh[0, :])
1018
+ # re-normalize
1019
+ up = up / torch.linalg.norm(up)
1020
+
1021
+ rotation = rotation_matrix(up, torch.Tensor([0, 0, 1]))
1022
+ transform = torch.cat([rotation, rotation @ -translation[..., None]], dim=-1)
1023
+ oriented_poses = transform @ poses
1024
+ elif method == "none":
1025
+ transform = torch.eye(4)
1026
+ transform[:3, 3] = -translation
1027
+ transform = transform[:3, :]
1028
+ oriented_poses = transform @ poses
1029
+ else:
1030
+ raise ValueError(f"Unknown value for method: {method}")
1031
+
1032
+ return oriented_poses, transform
1033
+
1034
+
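Hypothetical usage of auto_orient_and_center_poses (not part of the commit; the poses are random): center a handful of camera-to-world matrices and align their average up vector with +z.

import torch

c2w = torch.eye(4)[None].repeat(8, 1, 1)
c2w[:, :3, 3] = torch.randn(8, 3) + torch.tensor([0.0, 0.0, 5.0])   # scatter the centers around (0, 0, 5)
oriented, transform = auto_orient_and_center_poses(c2w, method="up", center_method="poses")
# oriented: (8, 3, 4) poses whose centers now average to the origin; transform: the applied (3, 4) map
assert torch.allclose(oriented[:, :3, 3].mean(dim=0), torch.zeros(3), atol=1e-5)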
1035
+ @torch.jit.script
1036
+ def fisheye624_project(xyz, params):
1037
+ """
1038
+ Batched implementation of the FisheyeRadTanThinPrism (aka Fisheye624) camera
1039
+ model project() function.
1040
+ Inputs:
1041
+ xyz: BxNx3 tensor of 3D points to be projected
1042
+ params: Bx16 tensor of Fisheye624 parameters formatted like this:
1043
+ [f_u f_v c_u c_v {k_0 ... k_5} {p_0 p_1} {s_0 s_1 s_2 s_3}]
1044
+ or Bx15 tensor of Fisheye624 parameters formatted like this:
1045
+ [f c_u c_v {k_0 ... k_5} {p_0 p_1} {s_0 s_1 s_2 s_3}]
1046
+ Outputs:
1047
+ uv: BxNx2 tensor of 2D projections of xyz in image plane
1048
+ Model for fisheye cameras with radial, tangential, and thin-prism distortion.
1049
+ This model allows fu != fv.
1050
+ Specifically, the model is:
1051
+ uvDistorted = [x_r] + tangentialDistortion + thinPrismDistortion
1052
+ [y_r]
1053
+ proj = diag(fu,fv) * uvDistorted + [cu;cv];
1054
+ where:
1055
+ a = x/z, b = y/z, r = (a^2+b^2)^(1/2)
1056
+ th = atan(r)
1057
+ cosPhi = a/r, sinPhi = b/r
1058
+ [x_r] = (th+ k0 * th^3 + k1* th^5 + ...) [cosPhi]
1059
+ [y_r] [sinPhi]
1060
+ the number of terms in the series is determined by the template parameter numK.
1061
+ tangentialDistortion = [(2 x_r^2 + rd^2)*p_0 + 2*x_r*y_r*p_1]
1062
+ [(2 y_r^2 + rd^2)*p_1 + 2*x_r*y_r*p_0]
1063
+ where rd^2 = x_r^2 + y_r^2
1064
+ thinPrismDistortion = [s0 * rd^2 + s1 rd^4]
1065
+ [s2 * rd^2 + s3 rd^4]
1066
+ Author: Daniel DeTone ([email protected])
1067
+ """
1068
+
1069
+ assert xyz.ndim == 3
1070
+ assert params.ndim == 2
1071
+ assert params.shape[-1] == 16 or params.shape[-1] == 15, "This model allows fx != fy"
1072
+ eps = 1e-9
1073
+ B, N = xyz.shape[0], xyz.shape[1]
1074
+
1075
+ # Radial correction.
1076
+ z = xyz[:, :, 2].reshape(B, N, 1)
1077
+ z = torch.where(torch.abs(z) < eps, eps * torch.sign(z), z)
1078
+ ab = xyz[:, :, :2] / z
1079
+ r = torch.norm(ab, dim=-1, p=2, keepdim=True)
1080
+ th = torch.atan(r)
1081
+ th_divr = torch.where(r < eps, torch.ones_like(ab), ab / r)
1082
+ th_k = th.reshape(B, N, 1).clone()
1083
+ for i in range(6):
1084
+ th_k = th_k + params[:, -12 + i].reshape(B, 1, 1) * torch.pow(th, 3 + i * 2)
1085
+ xr_yr = th_k * th_divr
1086
+ uv_dist = xr_yr
1087
+
1088
+ # Tangential correction.
1089
+ p0 = params[:, -6].reshape(B, 1)
1090
+ p1 = params[:, -5].reshape(B, 1)
1091
+ xr = xr_yr[:, :, 0].reshape(B, N)
1092
+ yr = xr_yr[:, :, 1].reshape(B, N)
1093
+ xr_yr_sq = torch.square(xr_yr)
1094
+ xr_sq = xr_yr_sq[:, :, 0].reshape(B, N)
1095
+ yr_sq = xr_yr_sq[:, :, 1].reshape(B, N)
1096
+ rd_sq = xr_sq + yr_sq
1097
+ uv_dist_tu = uv_dist[:, :, 0] + ((2.0 * xr_sq + rd_sq) * p0 + 2.0 * xr * yr * p1)
1098
+ uv_dist_tv = uv_dist[:, :, 1] + ((2.0 * yr_sq + rd_sq) * p1 + 2.0 * xr * yr * p0)
1099
+ uv_dist = torch.stack([uv_dist_tu, uv_dist_tv], dim=-1) # Avoids in-place complaint.
1100
+
1101
+ # Thin Prism correction.
1102
+ s0 = params[:, -4].reshape(B, 1)
1103
+ s1 = params[:, -3].reshape(B, 1)
1104
+ s2 = params[:, -2].reshape(B, 1)
1105
+ s3 = params[:, -1].reshape(B, 1)
1106
+ rd_4 = torch.square(rd_sq)
1107
+ uv_dist[:, :, 0] = uv_dist[:, :, 0] + (s0 * rd_sq + s1 * rd_4)
1108
+ uv_dist[:, :, 1] = uv_dist[:, :, 1] + (s2 * rd_sq + s3 * rd_4)
1109
+
1110
+ # Finally, apply standard terms: focal length and camera centers.
1111
+ if params.shape[-1] == 15:
1112
+ fx_fy = params[:, 0].reshape(B, 1, 1)
1113
+ cx_cy = params[:, 1:3].reshape(B, 1, 2)
1114
+ else:
1115
+ fx_fy = params[:, 0:2].reshape(B, 1, 2)
1116
+ cx_cy = params[:, 2:4].reshape(B, 1, 2)
1117
+ result = uv_dist * fx_fy + cx_cy
1118
+
1119
+ return result
1120
+
1121
+
1122
+ # Core implementation of fisheye 624 unprojection. More details are documented here:
1123
+ # https://facebookresearch.github.io/projectaria_tools/docs/tech_insights/camera_intrinsic_models#the-fisheye62-model
1124
+ @torch.jit.script
1125
+ def fisheye624_unproject_helper(uv, params, max_iters: int = 5):
1126
+ """
1127
+ Batched implementation of the FisheyeRadTanThinPrism (aka Fisheye624) camera
1128
+ model. There is no analytical solution for the inverse of the project()
1129
+ function so this solves an optimization problem using Newton's method to get
1130
+ the inverse.
1131
+ Inputs:
1132
+ uv: BxNx2 tensor of 2D pixels to be unprojected
1133
+ params: Bx16 tensor of Fisheye624 parameters formatted like this:
1134
+ [f_u f_v c_u c_v {k_0 ... k_5} {p_0 p_1} {s_0 s_1 s_2 s_3}]
1135
+ or Bx15 tensor of Fisheye624 parameters formatted like this:
1136
+ [f c_u c_v {k_0 ... k_5} {p_0 p_1} {s_0 s_1 s_2 s_3}]
1137
+ Outputs:
1138
+ xyz: BxNx3 tensor of 3D rays of uv points with z = 1.
1139
+ Model for fisheye cameras with radial, tangential, and thin-prism distortion.
1140
+ This model assumes fu=fv. This unproject function holds that:
1141
+ X = unproject(project(X)) [for X=(x,y,z) in R^3, z>0]
1142
+ and
1143
+ x = project(unproject(s*x)) [for s!=0 and x=(u,v) in R^2]
1144
+ Author: Daniel DeTone ([email protected])
1145
+ """
1146
+
1147
+ assert uv.ndim == 3, "Expected batched input shaped BxNx3"
1148
+ assert params.ndim == 2
1149
+ assert params.shape[-1] == 16 or params.shape[-1] == 15, "This model allows fx != fy"
1150
+ eps = 1e-6
1151
+ B, N = uv.shape[0], uv.shape[1]
1152
+
1153
+ if params.shape[-1] == 15:
1154
+ fx_fy = params[:, 0].reshape(B, 1, 1)
1155
+ cx_cy = params[:, 1:3].reshape(B, 1, 2)
1156
+ else:
1157
+ fx_fy = params[:, 0:2].reshape(B, 1, 2)
1158
+ cx_cy = params[:, 2:4].reshape(B, 1, 2)
1159
+
1160
+ uv_dist = (uv - cx_cy) / fx_fy
1161
+
1162
+ # Compute xr_yr using Newton's method.
1163
+ xr_yr = uv_dist.clone() # Initial guess.
1164
+ for _ in range(max_iters):
1165
+ uv_dist_est = xr_yr.clone()
1166
+ # Tangential terms.
1167
+ p0 = params[:, -6].reshape(B, 1)
1168
+ p1 = params[:, -5].reshape(B, 1)
1169
+ xr = xr_yr[:, :, 0].reshape(B, N)
1170
+ yr = xr_yr[:, :, 1].reshape(B, N)
1171
+ xr_yr_sq = torch.square(xr_yr)
1172
+ xr_sq = xr_yr_sq[:, :, 0].reshape(B, N)
1173
+ yr_sq = xr_yr_sq[:, :, 1].reshape(B, N)
1174
+ rd_sq = xr_sq + yr_sq
1175
+ uv_dist_est[:, :, 0] = uv_dist_est[:, :, 0] + ((2.0 * xr_sq + rd_sq) * p0 + 2.0 * xr * yr * p1)
1176
+ uv_dist_est[:, :, 1] = uv_dist_est[:, :, 1] + ((2.0 * yr_sq + rd_sq) * p1 + 2.0 * xr * yr * p0)
1177
+ # Thin Prism terms.
1178
+ s0 = params[:, -4].reshape(B, 1)
1179
+ s1 = params[:, -3].reshape(B, 1)
1180
+ s2 = params[:, -2].reshape(B, 1)
1181
+ s3 = params[:, -1].reshape(B, 1)
1182
+ rd_4 = torch.square(rd_sq)
1183
+ uv_dist_est[:, :, 0] = uv_dist_est[:, :, 0] + (s0 * rd_sq + s1 * rd_4)
1184
+ uv_dist_est[:, :, 1] = uv_dist_est[:, :, 1] + (s2 * rd_sq + s3 * rd_4)
1185
+ # Compute the derivative of uv_dist w.r.t. xr_yr.
1186
+ duv_dist_dxr_yr = uv.new_ones(B, N, 2, 2)
1187
+ duv_dist_dxr_yr[:, :, 0, 0] = 1.0 + 6.0 * xr_yr[:, :, 0] * p0 + 2.0 * xr_yr[:, :, 1] * p1
1188
+ offdiag = 2.0 * (xr_yr[:, :, 0] * p1 + xr_yr[:, :, 1] * p0)
1189
+ duv_dist_dxr_yr[:, :, 0, 1] = offdiag
1190
+ duv_dist_dxr_yr[:, :, 1, 0] = offdiag
1191
+ duv_dist_dxr_yr[:, :, 1, 1] = 1.0 + 6.0 * xr_yr[:, :, 1] * p1 + 2.0 * xr_yr[:, :, 0] * p0
1192
+ xr_yr_sq_norm = xr_yr_sq[:, :, 0] + xr_yr_sq[:, :, 1]
1193
+ temp1 = 2.0 * (s0 + 2.0 * s1 * xr_yr_sq_norm)
1194
+ duv_dist_dxr_yr[:, :, 0, 0] = duv_dist_dxr_yr[:, :, 0, 0] + (xr_yr[:, :, 0] * temp1)
1195
+ duv_dist_dxr_yr[:, :, 0, 1] = duv_dist_dxr_yr[:, :, 0, 1] + (xr_yr[:, :, 1] * temp1)
1196
+ temp2 = 2.0 * (s2 + 2.0 * s3 * xr_yr_sq_norm)
1197
+ duv_dist_dxr_yr[:, :, 1, 0] = duv_dist_dxr_yr[:, :, 1, 0] + (xr_yr[:, :, 0] * temp2)
1198
+ duv_dist_dxr_yr[:, :, 1, 1] = duv_dist_dxr_yr[:, :, 1, 1] + (xr_yr[:, :, 1] * temp2)
1199
+ # Compute 2x2 inverse manually here since torch.inverse() is very slow.
1200
+ # Because this is slow: inv = duv_dist_dxr_yr.inverse()
1201
+ # About a 10x reduction in speed with above line.
1202
+ mat = duv_dist_dxr_yr.reshape(-1, 2, 2)
1203
+ a = mat[:, 0, 0].reshape(-1, 1, 1)
1204
+ b = mat[:, 0, 1].reshape(-1, 1, 1)
1205
+ c = mat[:, 1, 0].reshape(-1, 1, 1)
1206
+ d = mat[:, 1, 1].reshape(-1, 1, 1)
1207
+ det = 1.0 / ((a * d) - (b * c))
1208
+ top = torch.cat([d, -b], dim=2)
1209
+ bot = torch.cat([-c, a], dim=2)
1210
+ inv = det * torch.cat([top, bot], dim=1)
1211
+ inv = inv.reshape(B, N, 2, 2)
1212
+ # Manually compute 2x2 @ 2x1 matrix multiply.
1213
+ # Because this is slow: step = (inv @ (uv_dist - uv_dist_est)[..., None])[..., 0]
1214
+ diff = uv_dist - uv_dist_est
1215
+ a = inv[:, :, 0, 0]
1216
+ b = inv[:, :, 0, 1]
1217
+ c = inv[:, :, 1, 0]
1218
+ d = inv[:, :, 1, 1]
1219
+ e = diff[:, :, 0]
1220
+ f = diff[:, :, 1]
1221
+ step = torch.stack([a * e + b * f, c * e + d * f], dim=-1)
1222
+ # Newton step.
1223
+ xr_yr = xr_yr + step
1224
+
1225
+ # Compute theta using Newton's method.
1226
+ xr_yr_norm = xr_yr.norm(p=2, dim=2).reshape(B, N, 1)
1227
+ th = xr_yr_norm.clone()
1228
+ for _ in range(max_iters):
1229
+ th_radial = uv.new_ones(B, N, 1)
1230
+ dthd_th = uv.new_ones(B, N, 1)
1231
+ for k in range(6):
1232
+ r_k = params[:, -12 + k].reshape(B, 1, 1)
1233
+ th_radial = th_radial + (r_k * torch.pow(th, 2 + k * 2))
1234
+ dthd_th = dthd_th + ((3.0 + 2.0 * k) * r_k * torch.pow(th, 2 + k * 2))
1235
+ th_radial = th_radial * th
1236
+ step = (xr_yr_norm - th_radial) / dthd_th
1237
+ # handle dthd_th close to 0.
1238
+ step = torch.where(dthd_th.abs() > eps, step, torch.sign(step) * eps * 10.0)
1239
+ th = th + step
1240
+ # Compute the ray direction using theta and xr_yr.
1241
+ close_to_zero = torch.logical_and(th.abs() < eps, xr_yr_norm.abs() < eps)
1242
+ ray_dir = torch.where(close_to_zero, xr_yr, torch.tan(th) / xr_yr_norm * xr_yr)
1243
+ ray = torch.cat([ray_dir, uv.new_ones(B, N, 1)], dim=2)
1244
+ return ray
1245
+
1246
+
1247
+ # unproject 2D point to 3D with fisheye624 model
1248
+ def fisheye624_unproject(coords: torch.Tensor, distortion_params: torch.Tensor) -> torch.Tensor:
1249
+ dirs = fisheye624_unproject_helper(coords.unsqueeze(0), distortion_params[0].unsqueeze(0))
1250
+ # correct for camera space differences:
1251
+ dirs[..., 1] = -dirs[..., 1]
1252
+ dirs[..., 2] = -dirs[..., 2]
1253
+ return dirs
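A round-trip sketch for the two scripted fisheye functions above (not part of the commit; the intrinsics are made up and all distortion terms are zero). The helper is used directly because fisheye624_unproject additionally flips the y and z axes for the dataset's camera convention.

import torch

params = torch.zeros(1, 16)                  # [f_u f_v c_u c_v k_0..k_5 p_0 p_1 s_0..s_3]
params[0, 0] = params[0, 1] = 300.0          # focal lengths
params[0, 2] = params[0, 3] = 256.0          # principal point
uv = torch.tensor([[[256.0, 256.0], [300.0, 220.0], [180.0, 310.0]]])   # 1x3x2 pixel coordinates
rays = fisheye624_unproject_helper(uv, params)                          # 1x3x3 rays with z = 1
uv_reprojected = fisheye624_project(rays, params)
assert torch.allclose(uv_reprojected, uv, atol=1e-3)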
sgm/data/cifar10.py ADDED
@@ -0,0 +1,67 @@
1
+ import pytorch_lightning as pl
2
+ import torchvision
3
+ from torch.utils.data import DataLoader, Dataset
4
+ from torchvision import transforms
5
+
6
+
7
+ class CIFAR10DataDictWrapper(Dataset):
8
+ def __init__(self, dset):
9
+ super().__init__()
10
+ self.dset = dset
11
+
12
+ def __getitem__(self, i):
13
+ x, y = self.dset[i]
14
+ return {"jpg": x, "cls": y}
15
+
16
+ def __len__(self):
17
+ return len(self.dset)
18
+
19
+
20
+ class CIFAR10Loader(pl.LightningDataModule):
21
+ def __init__(self, batch_size, num_workers=0, shuffle=True):
22
+ super().__init__()
23
+
24
+ transform = transforms.Compose(
25
+ [transforms.ToTensor(), transforms.Lambda(lambda x: x * 2.0 - 1.0)]
26
+ )
27
+
28
+ self.batch_size = batch_size
29
+ self.num_workers = num_workers
30
+ self.shuffle = shuffle
31
+ self.train_dataset = CIFAR10DataDictWrapper(
32
+ torchvision.datasets.CIFAR10(
33
+ root=".data/", train=True, download=True, transform=transform
34
+ )
35
+ )
36
+ self.test_dataset = CIFAR10DataDictWrapper(
37
+ torchvision.datasets.CIFAR10(
38
+ root=".data/", train=False, download=True, transform=transform
39
+ )
40
+ )
41
+
42
+ def prepare_data(self):
43
+ pass
44
+
45
+ def train_dataloader(self):
46
+ return DataLoader(
47
+ self.train_dataset,
48
+ batch_size=self.batch_size,
49
+ shuffle=self.shuffle,
50
+ num_workers=self.num_workers,
51
+ )
52
+
53
+ def test_dataloader(self):
54
+ return DataLoader(
55
+ self.test_dataset,
56
+ batch_size=self.batch_size,
57
+ shuffle=self.shuffle,
58
+ num_workers=self.num_workers,
59
+ )
60
+
61
+ def val_dataloader(self):
62
+ return DataLoader(
63
+ self.test_dataset,
64
+ batch_size=self.batch_size,
65
+ shuffle=self.shuffle,
66
+ num_workers=self.num_workers,
67
+ )
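Hypothetical usage of the loader above (not part of the commit); the first run downloads CIFAR-10 into .data/.

loader = CIFAR10Loader(batch_size=16, num_workers=0, shuffle=True)
batch = next(iter(loader.train_dataloader()))
print(batch["jpg"].shape)   # torch.Size([16, 3, 32, 32]), values rescaled to roughly [-1, 1]
print(batch["cls"][:4])     # integer class labels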
sgm/data/co3d.py ADDED
@@ -0,0 +1,1367 @@
1
+ """
2
+ adapted from SparseFusion
3
+ Wrapper for the full CO3Dv2 dataset
4
+ #@ Modified from https://github.com/facebookresearch/pytorch3d
5
+ """
6
+
7
+ import json
8
+ import logging
9
+ import math
10
+ import os
11
+ import random
12
+ import time
13
+ import warnings
14
+ from collections import defaultdict
15
+ from itertools import islice
16
+ from typing import (
17
+ Any,
18
+ ClassVar,
19
+ List,
20
+ Mapping,
21
+ Optional,
22
+ Sequence,
23
+ Tuple,
24
+ Type,
25
+ TypedDict,
26
+ Union,
27
+ )
28
+ from einops import rearrange, repeat
29
+
30
+ import numpy as np
31
+ import torch
32
+ import torch.nn.functional as F
33
+ import torchvision.transforms.functional as TF
34
+ from pytorch3d.utils import opencv_from_cameras_projection
35
+ from pytorch3d.implicitron.dataset import types
36
+ from pytorch3d.implicitron.dataset.dataset_base import DatasetBase
37
+ from sgm.data.json_index_dataset import (
38
+ FrameAnnotsEntry,
39
+ _bbox_xywh_to_xyxy,
40
+ _bbox_xyxy_to_xywh,
41
+ _clamp_box_to_image_bounds_and_round,
42
+ _crop_around_box,
43
+ _get_1d_bounds,
44
+ _get_bbox_from_mask,
45
+ _get_clamp_bbox,
46
+ _load_1bit_png_mask,
47
+ _load_16big_png_depth,
48
+ _load_depth,
49
+ _load_depth_mask,
50
+ _load_image,
51
+ _load_mask,
52
+ _load_pointcloud,
53
+ _rescale_bbox,
54
+ _safe_as_tensor,
55
+ _seq_name_to_seed,
56
+ )
57
+ from sgm.data.objaverse import video_collate_fn
58
+ from pytorch3d.implicitron.dataset.json_index_dataset_map_provider_v2 import (
59
+ get_available_subset_names,
60
+ )
61
+ from pytorch3d.renderer.cameras import PerspectiveCameras
62
+
63
+ logger = logging.getLogger(__name__)
64
+
65
+
66
+ from dataclasses import dataclass, field, fields
67
+
68
+ from pytorch3d.renderer.camera_utils import join_cameras_as_batch
69
+ from pytorch3d.renderer.cameras import CamerasBase, PerspectiveCameras
70
+ from pytorch3d.structures.pointclouds import Pointclouds, join_pointclouds_as_batch
71
+ from pytorch_lightning import LightningDataModule
72
+ from torch.utils.data import DataLoader
73
+
74
+ CO3D_ALL_CATEGORIES = list(
75
+ reversed(
76
+ [
77
+ "baseballbat",
78
+ "banana",
79
+ "bicycle",
80
+ "microwave",
81
+ "tv",
82
+ "cellphone",
83
+ "toilet",
84
+ "hairdryer",
85
+ "couch",
86
+ "kite",
87
+ "pizza",
88
+ "umbrella",
89
+ "wineglass",
90
+ "laptop",
91
+ "hotdog",
92
+ "stopsign",
93
+ "frisbee",
94
+ "baseballglove",
95
+ "cup",
96
+ "parkingmeter",
97
+ "backpack",
98
+ "toyplane",
99
+ "toybus",
100
+ "handbag",
101
+ "chair",
102
+ "keyboard",
103
+ "car",
104
+ "motorcycle",
105
+ "carrot",
106
+ "bottle",
107
+ "sandwich",
108
+ "remote",
109
+ "bowl",
110
+ "skateboard",
111
+ "toaster",
112
+ "mouse",
113
+ "toytrain",
114
+ "book",
115
+ "toytruck",
116
+ "orange",
117
+ "broccoli",
118
+ "plant",
119
+ "teddybear",
120
+ "suitcase",
121
+ "bench",
122
+ "ball",
123
+ "cake",
124
+ "vase",
125
+ "hydrant",
126
+ "apple",
127
+ "donut",
128
+ ]
129
+ )
130
+ )
131
+
132
+ CO3D_ALL_TEN = [
133
+ "donut",
134
+ "apple",
135
+ "hydrant",
136
+ "vase",
137
+ "cake",
138
+ "ball",
139
+ "bench",
140
+ "suitcase",
141
+ "teddybear",
142
+ "plant",
143
+ ]
144
+
145
+
146
+ # @ FROM https://github.com/facebookresearch/pytorch3d
147
+ @dataclass
148
+ class FrameData(Mapping[str, Any]):
149
+ """
150
+ A type of the elements returned by indexing the dataset object.
151
+ It can represent both individual frames and batches thereof;
152
+ in this documentation, the sizes of tensors refer to single frames;
153
+ add the first batch dimension for the collation result.
154
+ Args:
155
+ frame_number: The number of the frame within its sequence.
156
+ 0-based continuous integers.
157
+ sequence_name: The unique name of the frame's sequence.
158
+ sequence_category: The object category of the sequence.
159
+ frame_timestamp: The time elapsed since the start of a sequence in sec.
160
+ image_size_hw: The size of the image in pixels; (height, width) tensor
161
+ of shape (2,).
162
+ image_path: The qualified path to the loaded image (with dataset_root).
163
+ image_rgb: A Tensor of shape `(3, H, W)` holding the RGB image
164
+ of the frame; elements are floats in [0, 1].
165
+ mask_crop: A binary mask of shape `(1, H, W)` denoting the valid image
166
+ regions. Regions can be invalid (mask_crop[i,j]=0) in case they
167
+ are a result of zero-padding of the image after cropping around
168
+ the object bounding box; elements are floats in {0.0, 1.0}.
169
+ depth_path: The qualified path to the frame's depth map.
170
+ depth_map: A float Tensor of shape `(1, H, W)` holding the depth map
171
+ of the frame; values correspond to distances from the camera;
172
+ use `depth_mask` and `mask_crop` to filter for valid pixels.
173
+ depth_mask: A binary mask of shape `(1, H, W)` denoting pixels of the
174
+ depth map that are valid for evaluation, they have been checked for
175
+ consistency across views; elements are floats in {0.0, 1.0}.
176
+ mask_path: A qualified path to the foreground probability mask.
177
+ fg_probability: A Tensor of `(1, H, W)` denoting the probability of the
178
+ pixels belonging to the captured object; elements are floats
179
+ in [0, 1].
180
+ bbox_xywh: The bounding box tightly enclosing the foreground object in the
181
+ format (x0, y0, width, height). The convention assumes that
182
+ `x0+width` and `y0+height` includes the boundary of the box.
183
+ I.e., to slice out the corresponding crop from an image tensor `I`
184
+ we execute `crop = I[..., y0:y0+height, x0:x0+width]`
185
+ crop_bbox_xywh: The bounding box denoting the boundaries of `image_rgb`
186
+ in the original image coordinates in the format (x0, y0, width, height).
187
+ The convention is the same as for `bbox_xywh`. `crop_bbox_xywh` differs
188
+ from `bbox_xywh` due to padding (which can happen e.g. due to
189
+ setting `JsonIndexDataset.box_crop_context > 0`)
190
+ camera: A PyTorch3D camera object corresponding the frame's viewpoint,
191
+ corrected for cropping if it happened.
192
+ camera_quality_score: The score proportional to the confidence of the
193
+ frame's camera estimation (the higher the more accurate).
194
+ point_cloud_quality_score: The score proportional to the accuracy of the
195
+ frame's sequence point cloud (the higher the more accurate).
196
+ sequence_point_cloud_path: The path to the sequence's point cloud.
197
+ sequence_point_cloud: A PyTorch3D Pointclouds object holding the
198
+ point cloud corresponding to the frame's sequence. When the object
199
+ represents a batch of frames, point clouds may be deduplicated;
200
+ see `sequence_point_cloud_idx`.
201
+ sequence_point_cloud_idx: Integer indices mapping frame indices to the
202
+ corresponding point clouds in `sequence_point_cloud`; to get the
203
+ corresponding point cloud to `image_rgb[i]`, use
204
+ `sequence_point_cloud[sequence_point_cloud_idx[i]]`.
205
+ frame_type: The type of the loaded frame specified in
206
+ `subset_lists_file`, if provided.
207
+ meta: A dict for storing additional frame information.
208
+ """
209
+
210
+ frame_number: Optional[torch.LongTensor]
211
+ sequence_name: Union[str, List[str]]
212
+ sequence_category: Union[str, List[str]]
213
+ frame_timestamp: Optional[torch.Tensor] = None
214
+ image_size_hw: Optional[torch.Tensor] = None
215
+ image_path: Union[str, List[str], None] = None
216
+ image_rgb: Optional[torch.Tensor] = None
217
+ # masks out padding added due to cropping the square bit
218
+ mask_crop: Optional[torch.Tensor] = None
219
+ depth_path: Union[str, List[str], None] = ""
220
+ depth_map: Optional[torch.Tensor] = torch.zeros(1)
221
+ depth_mask: Optional[torch.Tensor] = torch.zeros(1)
222
+ mask_path: Union[str, List[str], None] = None
223
+ fg_probability: Optional[torch.Tensor] = None
224
+ bbox_xywh: Optional[torch.Tensor] = None
225
+ crop_bbox_xywh: Optional[torch.Tensor] = None
226
+ camera: Optional[PerspectiveCameras] = None
227
+ camera_quality_score: Optional[torch.Tensor] = None
228
+ point_cloud_quality_score: Optional[torch.Tensor] = None
229
+ sequence_point_cloud_path: Union[str, List[str], None] = ""
230
+ sequence_point_cloud: Optional[Pointclouds] = torch.zeros(1)
231
+ sequence_point_cloud_idx: Optional[torch.Tensor] = torch.zeros(1)
232
+ frame_type: Union[str, List[str], None] = "" # known | unseen
233
+ meta: dict = field(default_factory=lambda: {})
234
+ valid_region: Optional[torch.Tensor] = None
235
+ category_one_hot: Optional[torch.Tensor] = None
236
+
237
+ def to(self, *args, **kwargs):
238
+ new_params = {}
239
+ for f in fields(self):
240
+ value = getattr(self, f.name)
241
+ if isinstance(value, (torch.Tensor, Pointclouds, CamerasBase)):
242
+ new_params[f.name] = value.to(*args, **kwargs)
243
+ else:
244
+ new_params[f.name] = value
245
+ return type(self)(**new_params)
246
+
247
+ def cpu(self):
248
+ return self.to(device=torch.device("cpu"))
249
+
250
+ def cuda(self):
251
+ return self.to(device=torch.device("cuda"))
252
+
253
+ # the following functions make sure **frame_data can be passed to functions
254
+ def __iter__(self):
255
+ for f in fields(self):
256
+ yield f.name
257
+
258
+ def __getitem__(self, key):
259
+ return getattr(self, key)
260
+
261
+ def __len__(self):
262
+ return len(fields(self))
263
+
264
+ @classmethod
265
+ def collate(cls, batch):
266
+ """
267
+ Given a list objects `batch` of class `cls`, collates them into a batched
268
+ representation suitable for processing with deep networks.
269
+ """
270
+
271
+ elem = batch[0]
272
+
273
+ if isinstance(elem, cls):
274
+ pointcloud_ids = [id(el.sequence_point_cloud) for el in batch]
275
+ id_to_idx = defaultdict(list)
276
+ for i, pc_id in enumerate(pointcloud_ids):
277
+ id_to_idx[pc_id].append(i)
278
+
279
+ sequence_point_cloud = []
280
+ sequence_point_cloud_idx = -np.ones((len(batch),))
281
+ for i, ind in enumerate(id_to_idx.values()):
282
+ sequence_point_cloud_idx[ind] = i
283
+ sequence_point_cloud.append(batch[ind[0]].sequence_point_cloud)
284
+ assert (sequence_point_cloud_idx >= 0).all()
285
+
286
+ override_fields = {
287
+ "sequence_point_cloud": sequence_point_cloud,
288
+ "sequence_point_cloud_idx": sequence_point_cloud_idx.tolist(),
289
+ }
290
+ # note that the pre-collate value of sequence_point_cloud_idx is unused
291
+
292
+ collated = {}
293
+ for f in fields(elem):
294
+ list_values = override_fields.get(
295
+ f.name, [getattr(d, f.name) for d in batch]
296
+ )
297
+ collated[f.name] = (
298
+ cls.collate(list_values)
299
+ if all(list_value is not None for list_value in list_values)
300
+ else None
301
+ )
302
+ return cls(**collated)
303
+
304
+ elif isinstance(elem, Pointclouds):
305
+ return join_pointclouds_as_batch(batch)
306
+
307
+ elif isinstance(elem, CamerasBase):
308
+ # TODO: don't store K; enforce working in NDC space
309
+ return join_cameras_as_batch(batch)
310
+ else:
311
+ return torch.utils.data._utils.collate.default_collate(batch)
312
+
313
+
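A rough sketch of how FrameData batches frames (not part of the commit; every field value below is made up): tensors are stacked, cameras are joined into one batched camera, and strings become lists.

import torch
from pytorch3d.renderer.cameras import PerspectiveCameras

def _toy_frame(i):
    return FrameData(
        frame_number=torch.tensor(i, dtype=torch.long),
        sequence_name=f"toy_seq_{i}",
        sequence_category="hydrant",
        image_rgb=torch.rand(3, 64, 64),
        camera=PerspectiveCameras(R=torch.eye(3)[None], T=torch.zeros(1, 3)),
    )

batch = FrameData.collate([_toy_frame(0), _toy_frame(1)])
# batch.image_rgb: (2, 3, 64, 64); batch.frame_number: (2,); batch.camera: a 2-element camera batch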
314
+ # @ MODIFIED FROM https://github.com/facebookresearch/pytorch3d
315
+ class CO3Dv2Wrapper(torch.utils.data.Dataset):
316
+ def __init__(
317
+ self,
318
+ root_dir="/drive/datasets/co3d/",
319
+ category="hydrant",
320
+ subset="fewview_train",
321
+ stage="train",
322
+ sample_batch_size=20,
323
+ image_size=256,
324
+ masked=False,
325
+ deprecated_val_region=False,
326
+ return_frame_data_list=False,
327
+ reso: int = 256,
328
+ mask_type: str = "random",
329
+ cond_aug_mean=-3.0,
330
+ cond_aug_std=0.5,
331
+ condition_on_elevation=False,
332
+ fps_id=0.0,
333
+ motion_bucket_id=300.0,
334
+ num_frames: int = 20,
335
+ use_mask: bool = True,
336
+ load_pixelnerf: bool = True,
337
+ scale_pose: bool = True,
338
+ max_n_cond: int = 5,
339
+ min_n_cond: int = 2,
340
+ cond_on_multi: bool = False,
341
+ ):
342
+ root = root_dir
343
+ from typing import List
344
+
345
+ from co3d.dataset.data_types import (
346
+ FrameAnnotation,
347
+ SequenceAnnotation,
348
+ load_dataclass_jgzip,
349
+ )
350
+
351
+ self.dataset_root = root
352
+ self.path_manager = None
353
+ self.subset = subset
354
+ self.stage = stage
355
+ self.subset_lists_file: List[str] = [
356
+ f"{self.dataset_root}/{category}/set_lists/set_lists_{subset}.json"
357
+ ]
358
+ self.subsets: Optional[List[str]] = [subset]
359
+ self.sample_batch_size = sample_batch_size
360
+ self.limit_to: int = 0
361
+ self.limit_sequences_to: int = 0
362
+ self.pick_sequence: Tuple[str, ...] = ()
363
+ self.exclude_sequence: Tuple[str, ...] = ()
364
+ self.limit_category_to: Tuple[int, ...] = ()
365
+ self.load_images: bool = True
366
+ self.load_depths: bool = False
367
+ self.load_depth_masks: bool = False
368
+ self.load_masks: bool = True
369
+ self.load_point_clouds: bool = False
370
+ self.max_points: int = 0
371
+ self.mask_images: bool = False
372
+ self.mask_depths: bool = False
373
+ self.image_height: Optional[int] = image_size
374
+ self.image_width: Optional[int] = image_size
375
+ self.box_crop: bool = True
376
+ self.box_crop_mask_thr: float = 0.4
377
+ self.box_crop_context: float = 0.3
378
+ self.remove_empty_masks: bool = True
379
+ self.n_frames_per_sequence: int = -1
380
+ self.seed: int = 0
381
+ self.sort_frames: bool = False
382
+ self.eval_batches: Any = None
383
+
384
+ self.img_h = self.image_height
385
+ self.img_w = self.image_width
386
+ self.masked = masked
387
+ self.deprecated_val_region = deprecated_val_region
388
+ self.return_frame_data_list = return_frame_data_list
389
+
390
+ self.reso = reso
391
+ self.num_frames = num_frames
392
+ self.cond_aug_mean = cond_aug_mean
393
+ self.cond_aug_std = cond_aug_std
394
+ self.condition_on_elevation = condition_on_elevation
395
+ self.fps_id = fps_id
396
+ self.motion_bucket_id = motion_bucket_id
397
+ self.mask_type = mask_type
398
+ self.use_mask = use_mask
399
+ self.load_pixelnerf = load_pixelnerf
400
+ self.scale_pose = scale_pose
401
+ self.max_n_cond = max_n_cond
402
+ self.min_n_cond = min_n_cond
403
+ self.cond_on_multi = cond_on_multi
404
+
405
+ if self.cond_on_multi:
406
+ assert self.min_n_cond == self.max_n_cond
407
+
408
+ start_time = time.time()
409
+ if "all_" in category or category == "all":
410
+ self.category_frame_annotations = []
411
+ self.category_sequence_annotations = []
412
+ self.subset_lists_file = []
413
+
414
+ if category == "all":
415
+ cats = CO3D_ALL_CATEGORIES
416
+ elif category == "all_four":
417
+ cats = ["hydrant", "teddybear", "motorcycle", "bench"]
418
+ elif category == "all_ten":
419
+ cats = [
420
+ "donut",
421
+ "apple",
422
+ "hydrant",
423
+ "vase",
424
+ "cake",
425
+ "ball",
426
+ "bench",
427
+ "suitcase",
428
+ "teddybear",
429
+ "plant",
430
+ ]
431
+ elif category == "all_15":
432
+ cats = [
433
+ "hydrant",
434
+ "teddybear",
435
+ "motorcycle",
436
+ "bench",
437
+ "hotdog",
438
+ "remote",
439
+ "suitcase",
440
+ "donut",
441
+ "plant",
442
+ "toaster",
443
+ "keyboard",
444
+ "handbag",
445
+ "toyplane",
446
+ "tv",
447
+ "orange",
448
+ ]
449
+ else:
450
+ print("UNSPECIFIED CATEGORY SUBSET")
451
+ cats = ["hydrant", "teddybear"]
452
+ print("loading", cats)
453
+ for cat in cats:
454
+ self.category_frame_annotations.extend(
455
+ load_dataclass_jgzip(
456
+ f"{self.dataset_root}/{cat}/frame_annotations.jgz",
457
+ List[FrameAnnotation],
458
+ )
459
+ )
460
+ self.category_sequence_annotations.extend(
461
+ load_dataclass_jgzip(
462
+ f"{self.dataset_root}/{cat}/sequence_annotations.jgz",
463
+ List[SequenceAnnotation],
464
+ )
465
+ )
466
+ self.subset_lists_file.append(
467
+ f"{self.dataset_root}/{cat}/set_lists/set_lists_{subset}.json"
468
+ )
469
+
470
+ else:
471
+ self.category_frame_annotations = load_dataclass_jgzip(
472
+ f"{self.dataset_root}/{category}/frame_annotations.jgz",
473
+ List[FrameAnnotation],
474
+ )
475
+ self.category_sequence_annotations = load_dataclass_jgzip(
476
+ f"{self.dataset_root}/{category}/sequence_annotations.jgz",
477
+ List[SequenceAnnotation],
478
+ )
479
+
480
+ self.subset_to_image_path = None
481
+ self._load_frames()
482
+ self._load_sequences()
483
+ self._sort_frames()
484
+ self._load_subset_lists()
485
+ self._filter_db() # also computes sequence indices
486
+ # self._extract_and_set_eval_batches()
487
+ # print(self.eval_batches)
488
+ logger.info(str(self))
489
+
490
+ self.seq_to_frames = {}
491
+ for fi, item in enumerate(self.frame_annots):
492
+ if item["frame_annotation"].sequence_name in self.seq_to_frames:
493
+ self.seq_to_frames[item["frame_annotation"].sequence_name].append(fi)
494
+ else:
495
+ self.seq_to_frames[item["frame_annotation"].sequence_name] = [fi]
496
+
497
+ if self.stage != "test" or self.subset != "fewview_test":
498
+ count = 0
499
+ new_seq_to_frames = {}
500
+ for item in self.seq_to_frames:
501
+ if len(self.seq_to_frames[item]) > 10:
502
+ count += 1
503
+ new_seq_to_frames[item] = self.seq_to_frames[item]
504
+ self.seq_to_frames = new_seq_to_frames
505
+
506
+ self.seq_list = list(self.seq_to_frames.keys())
507
+
508
+ # @ REMOVE A FEW TRAINING SEQS THAT CAUSE BUGS
509
+ remove_list = ["411_55952_107659", "376_42884_85882"]
510
+ for remove_idx in remove_list:
511
+ if remove_idx in self.seq_to_frames:
512
+ self.seq_list.remove(remove_idx)
513
+ print("removing", remove_idx)
514
+
515
+ print("total training seq", len(self.seq_to_frames))
516
+ print("data loading took", time.time() - start_time, "seconds")
517
+
518
+ self.all_category_list = list(CO3D_ALL_CATEGORIES)
519
+ self.all_category_list.sort()
520
+ self.cat_to_idx = {}
521
+ for ci, cname in enumerate(self.all_category_list):
522
+ self.cat_to_idx[cname] = ci
523
+
524
+ def __len__(self):
525
+ return len(self.seq_list)
526
+
527
+ def __getitem__(self, index):
528
+ seq_index = self.seq_list[index]
529
+
530
+ if self.subset == "fewview_test" and self.stage == "test":
531
+ batch_idx = torch.arange(len(self.seq_to_frames[seq_index]))
532
+
533
+ elif self.stage == "test":
534
+ batch_idx = (
535
+ torch.linspace(
536
+ 0, len(self.seq_to_frames[seq_index]) - 1, self.sample_batch_size
537
+ )
538
+ .long()
539
+ .tolist()
540
+ )
541
+ else:
542
+ rand = torch.randperm(len(self.seq_to_frames[seq_index]))
543
+ batch_idx = rand[: min(len(rand), self.sample_batch_size)]
544
+
545
+ frame_data_list = []
546
+ idx_list = []
547
+ timestamp_list = []
548
+ for idx in batch_idx:
549
+ idx_list.append(self.seq_to_frames[seq_index][idx])
550
+ timestamp_list.append(
551
+ self.frame_annots[self.seq_to_frames[seq_index][idx]][
552
+ "frame_annotation"
553
+ ].frame_timestamp
554
+ )
555
+ frame_data_list.append(
556
+ self._get_frame(int(self.seq_to_frames[seq_index][idx]))
557
+ )
558
+
559
+ time_order = torch.argsort(torch.tensor(timestamp_list))
560
+ frame_data_list = [frame_data_list[i] for i in time_order]
561
+
562
+ frame_data = FrameData.collate(frame_data_list)
563
+ image_size = torch.Tensor([self.image_height]).repeat(
564
+ frame_data.camera.R.shape[0], 2
565
+ )
566
+ frame_dict = {
567
+ "R": frame_data.camera.R,
568
+ "T": frame_data.camera.T,
569
+ "f": frame_data.camera.focal_length,
570
+ "c": frame_data.camera.principal_point,
571
+ "images": frame_data.image_rgb * frame_data.fg_probability
572
+ + (1 - frame_data.fg_probability),
573
+ "valid_region": frame_data.mask_crop,
574
+ "bbox": frame_data.valid_region,
575
+ "image_size": image_size,
576
+ "frame_type": frame_data.frame_type,
577
+ "idx": seq_index,
578
+ "category": frame_data.category_one_hot,
579
+ }
580
+ if not self.masked:
581
+ frame_dict["images_full"] = frame_data.image_rgb
582
+ frame_dict["masks"] = frame_data.fg_probability
583
+ frame_dict["mask_crop"] = frame_data.mask_crop
584
+
585
+ cond_aug = np.exp(
586
+ np.random.randn(1)[0] * self.cond_aug_std + self.cond_aug_mean
587
+ )
588
+
589
+ def _pad(input):
590
+ return torch.cat([input, torch.flip(input, dims=[0])], dim=0)[
591
+ : self.num_frames
592
+ ]
593
+
594
+ if len(frame_dict["images"]) < self.num_frames:
595
+ for k in frame_dict:
596
+ if isinstance(frame_dict[k], torch.Tensor):
597
+ frame_dict[k] = _pad(frame_dict[k])
598
+
599
+ data = dict()
600
+ if "images_full" in frame_dict:
601
+ frames = frame_dict["images_full"] * 2 - 1
602
+ else:
603
+ frames = frame_dict["images"] * 2 - 1
604
+ data["frames"] = frames
605
+ cond = frames[0]
606
+ data["cond_frames_without_noise"] = cond
607
+ data["cond_aug"] = torch.as_tensor([cond_aug] * self.num_frames)
608
+ data["cond_frames"] = cond + cond_aug * torch.randn_like(cond)
609
+ data["fps_id"] = torch.as_tensor([self.fps_id] * self.num_frames)
610
+ data["motion_bucket_id"] = torch.as_tensor(
611
+ [self.motion_bucket_id] * self.num_frames
612
+ )
613
+ data["num_video_frames"] = self.num_frames
614
+ data["image_only_indicator"] = torch.as_tensor([0.0] * self.num_frames)
615
+
616
+ if self.load_pixelnerf:
617
+ data["pixelnerf_input"] = dict()
618
+ # Rs = frame_dict["R"].transpose(-1, -2)
619
+ # Ts = frame_dict["T"]
620
+ # Rs[:, :, 2] *= -1
621
+ # Rs[:, :, 0] *= -1
622
+ # Ts[:, 2] *= -1
623
+ # Ts[:, 0] *= -1
624
+ # c2ws = torch.zeros(Rs.shape[0], 4, 4)
625
+ # c2ws[:, :3, :3] = Rs
626
+ # c2ws[:, :3, 3] = Ts
627
+ # c2ws[:, 3, 3] = 1
628
+ # c2ws = c2ws.inverse()
629
+ # # c2ws[..., 0] *= -1
630
+ # # c2ws[..., 2] *= -1
631
+ # cx = frame_dict["c"][:, 0]
632
+ # cy = frame_dict["c"][:, 1]
633
+ # fx = frame_dict["f"][:, 0]
634
+ # fy = frame_dict["f"][:, 1]
635
+ # intrinsics = torch.zeros(cx.shape[0], 3, 3)
636
+ # intrinsics[:, 2, 2] = 1
637
+ # intrinsics[:, 0, 0] = fx
638
+ # intrinsics[:, 1, 1] = fy
639
+ # intrinsics[:, 0, 2] = cx
640
+ # intrinsics[:, 1, 2] = cy
641
+
642
+ scene_cameras = PerspectiveCameras(
643
+ R=frame_dict["R"],
644
+ T=frame_dict["T"],
645
+ focal_length=frame_dict["f"],
646
+ principal_point=frame_dict["c"],
647
+ image_size=frame_dict["image_size"],
648
+ )
649
+ R, T, intrinsics = opencv_from_cameras_projection(
650
+ scene_cameras, frame_dict["image_size"]
651
+ )
652
+ c2ws = torch.zeros(R.shape[0], 4, 4)
653
+ c2ws[:, :3, :3] = R
654
+ c2ws[:, :3, 3] = T
655
+ c2ws[:, 3, 3] = 1.0
656
+ c2ws = c2ws.inverse()
657
+ c2ws[..., 1:3] *= -1
658
+ intrinsics[:, :2] /= 256
659
+
660
+ cameras = torch.zeros(c2ws.shape[0], 25)
661
+ cameras[..., :16] = c2ws.reshape(-1, 16)
662
+ cameras[..., 16:] = intrinsics.reshape(-1, 9)
663
+ if self.scale_pose:
664
+ c2ws = cameras[..., :16].reshape(-1, 4, 4)
665
+ center = c2ws[:, :3, 3].mean(0)
666
+ radius = (c2ws[:, :3, 3] - center).norm(dim=-1).max()
667
+ scale = 1.5 / radius
668
+ c2ws[..., :3, 3] = (c2ws[..., :3, 3] - center) * scale
669
+ cameras[..., :16] = c2ws.reshape(-1, 16)
670
+
671
+ data["pixelnerf_input"]["frames"] = frames
672
+ data["pixelnerf_input"]["cameras"] = cameras
673
+ data["pixelnerf_input"]["rgb"] = (
674
+ F.interpolate(
675
+ frames,
676
+ (self.image_width // 8, self.image_height // 8),
677
+ mode="bilinear",
678
+ align_corners=False,
679
+ )
680
+ + 1
681
+ ) * 0.5
682
+
683
+ return data
684
+ # if self.return_frame_data_list:
685
+ # return (frame_dict, frame_data_list)
686
+ # return frame_dict
687
+
688
+ def collate_fn(self, batch):
689
+ # a hack to add a source index and keep it consistent within a batch
690
+ if self.max_n_cond > 1:
691
+ # TODO implement this
692
+ n_cond = np.random.randint(self.min_n_cond, self.max_n_cond + 1)
693
+ # debug
694
+ # source_index = [0]
695
+ if n_cond > 1:
696
+ for b in batch:
697
+ source_index = [0] + np.random.choice(
698
+ np.arange(1, self.num_frames),
699
+ self.max_n_cond - 1,
700
+ replace=False,
701
+ ).tolist()
702
+ b["pixelnerf_input"]["source_index"] = torch.as_tensor(source_index)
703
+ b["pixelnerf_input"]["n_cond"] = n_cond
704
+ b["pixelnerf_input"]["source_images"] = b["frames"][source_index]
705
+ b["pixelnerf_input"]["source_cameras"] = b["pixelnerf_input"][
706
+ "cameras"
707
+ ][source_index]
708
+
709
+ if self.cond_on_multi:
710
+ b["cond_frames_without_noise"] = b["frames"][source_index]
711
+
712
+ ret = video_collate_fn(batch)
713
+
714
+ if self.cond_on_multi:
715
+ ret["cond_frames_without_noise"] = rearrange(
716
+ ret["cond_frames_without_noise"], "b t ... -> (b t) ..."
717
+ )
718
+
719
+ return ret
720
+
721
+ def _get_frame(self, index):
722
+ # if index >= len(self.frame_annots):
723
+ # raise IndexError(f"index {index} out of range {len(self.frame_annots)}")
724
+
725
+ entry = self.frame_annots[index]["frame_annotation"]
726
+ # pyre-ignore[16]
727
+ point_cloud = self.seq_annots[entry.sequence_name].point_cloud
728
+ frame_data = FrameData(
729
+ frame_number=_safe_as_tensor(entry.frame_number, torch.long),
730
+ frame_timestamp=_safe_as_tensor(entry.frame_timestamp, torch.float),
731
+ sequence_name=entry.sequence_name,
732
+ sequence_category=self.seq_annots[entry.sequence_name].category,
733
+ camera_quality_score=_safe_as_tensor(
734
+ self.seq_annots[entry.sequence_name].viewpoint_quality_score,
735
+ torch.float,
736
+ ),
737
+ point_cloud_quality_score=_safe_as_tensor(
738
+ point_cloud.quality_score, torch.float
739
+ )
740
+ if point_cloud is not None
741
+ else None,
742
+ )
743
+
744
+ # The rest of the fields are optional
745
+ frame_data.frame_type = self._get_frame_type(self.frame_annots[index])
746
+
747
+ (
748
+ frame_data.fg_probability,
749
+ frame_data.mask_path,
750
+ frame_data.bbox_xywh,
751
+ clamp_bbox_xyxy,
752
+ frame_data.crop_bbox_xywh,
753
+ ) = self._load_crop_fg_probability(entry)
754
+
755
+ scale = 1.0
756
+ if self.load_images and entry.image is not None:
757
+ # original image size
758
+ frame_data.image_size_hw = _safe_as_tensor(entry.image.size, torch.long)
759
+
760
+ (
761
+ frame_data.image_rgb,
762
+ frame_data.image_path,
763
+ frame_data.mask_crop,
764
+ scale,
765
+ ) = self._load_crop_images(
766
+ entry, frame_data.fg_probability, clamp_bbox_xyxy
767
+ )
768
+ # print(frame_data.fg_probability.sum())
769
+ # print('scale', scale)
770
+
771
+ #! INSERT
772
+ if self.deprecated_val_region:
773
+ # print(frame_data.crop_bbox_xywh)
774
+ valid_bbox = _bbox_xywh_to_xyxy(frame_data.crop_bbox_xywh).float()
775
+ # print(valid_bbox, frame_data.image_size_hw)
776
+ valid_bbox[0] = torch.clip(
777
+ (
778
+ valid_bbox[0]
779
+ - torch.div(frame_data.image_size_hw[1], 2, rounding_mode="floor")
780
+ )
781
+ / torch.div(frame_data.image_size_hw[1], 2, rounding_mode="floor"),
782
+ -1.0,
783
+ 1.0,
784
+ )
785
+ valid_bbox[1] = torch.clip(
786
+ (
787
+ valid_bbox[1]
788
+ - torch.div(frame_data.image_size_hw[0], 2, rounding_mode="floor")
789
+ )
790
+ / torch.div(frame_data.image_size_hw[0], 2, rounding_mode="floor"),
791
+ -1.0,
792
+ 1.0,
793
+ )
794
+ valid_bbox[2] = torch.clip(
795
+ (
796
+ valid_bbox[2]
797
+ - torch.div(frame_data.image_size_hw[1], 2, rounding_mode="floor")
798
+ )
799
+ / torch.div(frame_data.image_size_hw[1], 2, rounding_mode="floor"),
800
+ -1.0,
801
+ 1.0,
802
+ )
803
+ valid_bbox[3] = torch.clip(
804
+ (
805
+ valid_bbox[3]
806
+ - torch.div(frame_data.image_size_hw[0], 2, rounding_mode="floor")
807
+ )
808
+ / torch.div(frame_data.image_size_hw[0], 2, rounding_mode="floor"),
809
+ -1.0,
810
+ 1.0,
811
+ )
812
+ # print(valid_bbox)
813
+ frame_data.valid_region = valid_bbox
814
+ else:
815
+ #! UPDATED VALID BBOX
816
+ if self.stage == "train":
817
+ assert self.image_height == 256 and self.image_width == 256
818
+ valid = torch.nonzero(frame_data.mask_crop[0])
819
+ min_y = valid[:, 0].min()
820
+ min_x = valid[:, 1].min()
821
+ max_y = valid[:, 0].max()
822
+ max_x = valid[:, 1].max()
823
+ valid_bbox = torch.tensor(
824
+ [min_y, min_x, max_y, max_x], device=frame_data.image_rgb.device
825
+ ).unsqueeze(0)
826
+ valid_bbox = torch.clip(
827
+ (valid_bbox - (256 // 2)) / (256 // 2), -1.0, 1.0
828
+ )
829
+ frame_data.valid_region = valid_bbox[0]
830
+ else:
831
+ valid = torch.nonzero(frame_data.mask_crop[0])
832
+ min_y = valid[:, 0].min()
833
+ min_x = valid[:, 1].min()
834
+ max_y = valid[:, 0].max()
835
+ max_x = valid[:, 1].max()
836
+ valid_bbox = torch.tensor(
837
+ [min_y, min_x, max_y, max_x], device=frame_data.image_rgb.device
838
+ ).unsqueeze(0)
839
+ valid_bbox = torch.clip(
840
+ (valid_bbox - (self.image_height // 2)) / (self.image_height // 2),
841
+ -1.0,
842
+ 1.0,
843
+ )
844
+ frame_data.valid_region = valid_bbox[0]
845
+
846
+ #! SET CLASS ONEHOT
847
+ frame_data.category_one_hot = torch.zeros(
848
+ (len(self.all_category_list)), device=frame_data.image_rgb.device
849
+ )
850
+ frame_data.category_one_hot[self.cat_to_idx[frame_data.sequence_category]] = 1
851
+
852
+ if self.load_depths and entry.depth is not None:
853
+ (
854
+ frame_data.depth_map,
855
+ frame_data.depth_path,
856
+ frame_data.depth_mask,
857
+ ) = self._load_mask_depth(entry, clamp_bbox_xyxy, frame_data.fg_probability)
858
+
859
+ if entry.viewpoint is not None:
860
+ frame_data.camera = self._get_pytorch3d_camera(
861
+ entry,
862
+ scale,
863
+ clamp_bbox_xyxy,
864
+ )
865
+
866
+ if self.load_point_clouds and point_cloud is not None:
867
+ frame_data.sequence_point_cloud_path = pcl_path = os.path.join(
868
+ self.dataset_root, point_cloud.path
869
+ )
870
+ frame_data.sequence_point_cloud = _load_pointcloud(
871
+ self._local_path(pcl_path), max_points=self.max_points
872
+ )
873
+
874
+ # for key in frame_data:
875
+ # if frame_data[key] == None:
876
+ # print(key)
877
+ return frame_data
878
+
879
+ def _extract_and_set_eval_batches(self):
880
+ """
881
+ Sets eval_batches based on input eval_batch_index.
882
+ """
883
+ if self.eval_batch_index is not None:
884
+ if self.eval_batches is not None:
885
+ raise ValueError(
886
+ "Cannot define both eval_batch_index and eval_batches."
887
+ )
888
+ self.eval_batches = self.seq_frame_index_to_dataset_index(
889
+ self.eval_batch_index
890
+ )
891
+
892
+ def _load_crop_fg_probability(
893
+ self, entry: types.FrameAnnotation
894
+ ) -> Tuple[
895
+ Optional[torch.Tensor],
896
+ Optional[str],
897
+ Optional[torch.Tensor],
898
+ Optional[torch.Tensor],
899
+ Optional[torch.Tensor],
900
+ ]:
901
+ fg_probability = None
902
+ full_path = None
903
+ bbox_xywh = None
904
+ clamp_bbox_xyxy = None
905
+ crop_box_xywh = None
906
+
907
+ if (self.load_masks or self.box_crop) and entry.mask is not None:
908
+ full_path = os.path.join(self.dataset_root, entry.mask.path)
909
+ mask = _load_mask(self._local_path(full_path))
910
+
911
+ if mask.shape[-2:] != entry.image.size:
912
+ raise ValueError(
913
+ f"bad mask size: {mask.shape[-2:]} vs {entry.image.size}!"
914
+ )
915
+
916
+ bbox_xywh = torch.tensor(_get_bbox_from_mask(mask, self.box_crop_mask_thr))
917
+
918
+ if self.box_crop:
919
+ clamp_bbox_xyxy = _clamp_box_to_image_bounds_and_round(
920
+ _get_clamp_bbox(
921
+ bbox_xywh,
922
+ image_path=entry.image.path,
923
+ box_crop_context=self.box_crop_context,
924
+ ),
925
+ image_size_hw=tuple(mask.shape[-2:]),
926
+ )
927
+ crop_box_xywh = _bbox_xyxy_to_xywh(clamp_bbox_xyxy)
928
+
929
+ mask = _crop_around_box(mask, clamp_bbox_xyxy, full_path)
930
+
931
+ fg_probability, _, _ = self._resize_image(mask, mode="nearest")
932
+
933
+ return fg_probability, full_path, bbox_xywh, clamp_bbox_xyxy, crop_box_xywh
934
+
935
+ def _load_crop_images(
936
+ self,
937
+ entry: types.FrameAnnotation,
938
+ fg_probability: Optional[torch.Tensor],
939
+ clamp_bbox_xyxy: Optional[torch.Tensor],
940
+ ) -> Tuple[torch.Tensor, str, torch.Tensor, float]:
941
+ assert self.dataset_root is not None and entry.image is not None
942
+ path = os.path.join(self.dataset_root, entry.image.path)
943
+ image_rgb = _load_image(self._local_path(path))
944
+
945
+ if image_rgb.shape[-2:] != entry.image.size:
946
+ raise ValueError(
947
+ f"bad image size: {image_rgb.shape[-2:]} vs {entry.image.size}!"
948
+ )
949
+
950
+ if self.box_crop:
951
+ assert clamp_bbox_xyxy is not None
952
+ image_rgb = _crop_around_box(image_rgb, clamp_bbox_xyxy, path)
953
+
954
+ image_rgb, scale, mask_crop = self._resize_image(image_rgb)
955
+
956
+ if self.mask_images:
957
+ assert fg_probability is not None
958
+ image_rgb *= fg_probability
959
+
960
+ return image_rgb, path, mask_crop, scale
961
+
962
+ def _load_mask_depth(
963
+ self,
964
+ entry: types.FrameAnnotation,
965
+ clamp_bbox_xyxy: Optional[torch.Tensor],
966
+ fg_probability: Optional[torch.Tensor],
967
+ ) -> Tuple[torch.Tensor, str, torch.Tensor]:
968
+ entry_depth = entry.depth
969
+ assert entry_depth is not None
970
+ path = os.path.join(self.dataset_root, entry_depth.path)
971
+ depth_map = _load_depth(self._local_path(path), entry_depth.scale_adjustment)
972
+
973
+ if self.box_crop:
974
+ assert clamp_bbox_xyxy is not None
975
+ depth_bbox_xyxy = _rescale_bbox(
976
+ clamp_bbox_xyxy, entry.image.size, depth_map.shape[-2:]
977
+ )
978
+ depth_map = _crop_around_box(depth_map, depth_bbox_xyxy, path)
979
+
980
+ depth_map, _, _ = self._resize_image(depth_map, mode="nearest")
981
+
982
+ if self.mask_depths:
983
+ assert fg_probability is not None
984
+ depth_map *= fg_probability
985
+
986
+ if self.load_depth_masks:
987
+ assert entry_depth.mask_path is not None
988
+ mask_path = os.path.join(self.dataset_root, entry_depth.mask_path)
989
+ depth_mask = _load_depth_mask(self._local_path(mask_path))
990
+
991
+ if self.box_crop:
992
+ assert clamp_bbox_xyxy is not None
993
+ depth_mask_bbox_xyxy = _rescale_bbox(
994
+ clamp_bbox_xyxy, entry.image.size, depth_mask.shape[-2:]
995
+ )
996
+ depth_mask = _crop_around_box(
997
+ depth_mask, depth_mask_bbox_xyxy, mask_path
998
+ )
999
+
1000
+ depth_mask, _, _ = self._resize_image(depth_mask, mode="nearest")
1001
+ else:
1002
+ depth_mask = torch.ones_like(depth_map)
1003
+
1004
+ return depth_map, path, depth_mask
1005
+
1006
+ def _get_pytorch3d_camera(
1007
+ self,
1008
+ entry: types.FrameAnnotation,
1009
+ scale: float,
1010
+ clamp_bbox_xyxy: Optional[torch.Tensor],
1011
+ ) -> PerspectiveCameras:
1012
+ entry_viewpoint = entry.viewpoint
1013
+ assert entry_viewpoint is not None
1014
+ # principal point and focal length
1015
+ principal_point = torch.tensor(
1016
+ entry_viewpoint.principal_point, dtype=torch.float
1017
+ )
1018
+ focal_length = torch.tensor(entry_viewpoint.focal_length, dtype=torch.float)
1019
+
1020
+ half_image_size_wh_orig = (
1021
+ torch.tensor(list(reversed(entry.image.size)), dtype=torch.float) / 2.0
1022
+ )
1023
+
1024
+ # first, we convert from the dataset's NDC convention to pixels
1025
+ format = entry_viewpoint.intrinsics_format
1026
+ if format.lower() == "ndc_norm_image_bounds":
1027
+ # this is e.g. currently used in CO3D for storing intrinsics
1028
+ rescale = half_image_size_wh_orig
1029
+ elif format.lower() == "ndc_isotropic":
1030
+ rescale = half_image_size_wh_orig.min()
1031
+ else:
1032
+ raise ValueError(f"Unknown intrinsics format: {format}")
1033
+
1034
+ # principal point and focal length in pixels
1035
+ principal_point_px = half_image_size_wh_orig - principal_point * rescale
1036
+ focal_length_px = focal_length * rescale
1037
+ if self.box_crop:
1038
+ assert clamp_bbox_xyxy is not None
1039
+ principal_point_px -= clamp_bbox_xyxy[:2]
1040
+
1041
+ # now, convert from pixels to PyTorch3D v0.5+ NDC convention
1042
+ if self.image_height is None or self.image_width is None:
1043
+ out_size = list(reversed(entry.image.size))
1044
+ else:
1045
+ out_size = [self.image_width, self.image_height]
1046
+
1047
+ half_image_size_output = torch.tensor(out_size, dtype=torch.float) / 2.0
1048
+ half_min_image_size_output = half_image_size_output.min()
1049
+
1050
+ # rescaled principal point and focal length in ndc
1051
+ principal_point = (
1052
+ half_image_size_output - principal_point_px * scale
1053
+ ) / half_min_image_size_output
1054
+ focal_length = focal_length_px * scale / half_min_image_size_output
1055
+
1056
+ return PerspectiveCameras(
1057
+ focal_length=focal_length[None],
1058
+ principal_point=principal_point[None],
1059
+ R=torch.tensor(entry_viewpoint.R, dtype=torch.float)[None],
1060
+ T=torch.tensor(entry_viewpoint.T, dtype=torch.float)[None],
1061
+ )
1062
+
1063
+ def _load_frames(self) -> None:
1064
+ self.frame_annots = [
1065
+ FrameAnnotsEntry(frame_annotation=a, subset=None)
1066
+ for a in self.category_frame_annotations
1067
+ ]
1068
+
1069
+ def _load_sequences(self) -> None:
1070
+ self.seq_annots = {
1071
+ entry.sequence_name: entry for entry in self.category_sequence_annotations
1072
+ }
1073
+
1074
+ def _load_subset_lists(self) -> None:
1075
+ logger.info(f"Loading Co3D subset lists from {self.subset_lists_file}.")
1076
+ if not self.subset_lists_file:
1077
+ return
1078
+
1079
+ frame_path_to_subset = {}
1080
+
1081
+ for subset_list_file in self.subset_lists_file:
1082
+ with open(self._local_path(subset_list_file), "r") as f:
1083
+ subset_to_seq_frame = json.load(f)
1084
+
1085
+ #! PRINT SUBSET_LIST STATS
1086
+ # if len(self.subset_lists_file) == 1:
1087
+ # print('train frames', len(subset_to_seq_frame['train']))
1088
+ # print('val frames', len(subset_to_seq_frame['val']))
1089
+ # print('test frames', len(subset_to_seq_frame['test']))
1090
+
1091
+ for set_ in subset_to_seq_frame:
1092
+ for _, _, path in subset_to_seq_frame[set_]:
1093
+ if path in frame_path_to_subset:
1094
+ frame_path_to_subset[path].add(set_)
1095
+ else:
1096
+ frame_path_to_subset[path] = {set_}
1097
+
1098
+ # pyre-ignore[16]
1099
+ for frame in self.frame_annots:
1100
+ frame["subset"] = frame_path_to_subset.get(
1101
+ frame["frame_annotation"].image.path, None
1102
+ )
1103
+
1104
+ if frame["subset"] is None:
1105
+ continue
1106
+ warnings.warn(
1107
+ "Subset lists are given but don't include "
1108
+ + frame["frame_annotation"].image.path
1109
+ )
1110
+
1111
+ def _sort_frames(self) -> None:
1112
+ # Sort frames to have them grouped by sequence, ordered by timestamp
1113
+ # pyre-ignore[16]
1114
+ self.frame_annots = sorted(
1115
+ self.frame_annots,
1116
+ key=lambda f: (
1117
+ f["frame_annotation"].sequence_name,
1118
+ f["frame_annotation"].frame_timestamp or 0,
1119
+ ),
1120
+ )
1121
+
1122
+ def _filter_db(self) -> None:
1123
+ if self.remove_empty_masks:
1124
+ logger.info("Removing images with empty masks.")
1125
+ # pyre-ignore[16]
1126
+ old_len = len(self.frame_annots)
1127
+
1128
+ msg = "remove_empty_masks needs every MaskAnnotation.mass to be set."
1129
+
1130
+ def positive_mass(frame_annot: types.FrameAnnotation) -> bool:
1131
+ mask = frame_annot.mask
1132
+ if mask is None:
1133
+ return False
1134
+ if mask.mass is None:
1135
+ raise ValueError(msg)
1136
+ return mask.mass > 1
1137
+
1138
+ self.frame_annots = [
1139
+ frame
1140
+ for frame in self.frame_annots
1141
+ if positive_mass(frame["frame_annotation"])
1142
+ ]
1143
+ logger.info("... filtered %d -> %d" % (old_len, len(self.frame_annots)))
1144
+
1145
+ # this has to be called after joining with categories!!
1146
+ subsets = self.subsets
1147
+ if subsets:
1148
+ if not self.subset_lists_file:
1149
+ raise ValueError(
1150
+ "Subset filter is on but subset_lists_file was not given"
1151
+ )
1152
+
1153
+ logger.info(f"Limiting Co3D dataset to the '{subsets}' subsets.")
1154
+
1155
+ # truncate the list of subsets to the valid one
1156
+ self.frame_annots = [
1157
+ entry
1158
+ for entry in self.frame_annots
1159
+ if (entry["subset"] is not None and self.stage in entry["subset"])
1160
+ ]
1161
+
1162
+ if len(self.frame_annots) == 0:
1163
+ raise ValueError(f"There are no frames in the '{subsets}' subsets!")
1164
+
1165
+ self._invalidate_indexes(filter_seq_annots=True)
1166
+
1167
+ if len(self.limit_category_to) > 0:
1168
+ logger.info(f"Limiting dataset to categories: {self.limit_category_to}")
1169
+ # pyre-ignore[16]
1170
+ self.seq_annots = {
1171
+ name: entry
1172
+ for name, entry in self.seq_annots.items()
1173
+ if entry.category in self.limit_category_to
1174
+ }
1175
+
1176
+ # sequence filters
1177
+ for prefix in ("pick", "exclude"):
1178
+ orig_len = len(self.seq_annots)
1179
+ attr = f"{prefix}_sequence"
1180
+ arr = getattr(self, attr)
1181
+ if len(arr) > 0:
1182
+ logger.info(f"{attr}: {str(arr)}")
1183
+ self.seq_annots = {
1184
+ name: entry
1185
+ for name, entry in self.seq_annots.items()
1186
+ if (name in arr) == (prefix == "pick")
1187
+ }
1188
+ logger.info("... filtered %d -> %d" % (orig_len, len(self.seq_annots)))
1189
+
1190
+ if self.limit_sequences_to > 0:
1191
+ self.seq_annots = dict(
1192
+ islice(self.seq_annots.items(), self.limit_sequences_to)
1193
+ )
1194
+
1195
+ # retain only frames from retained sequences
1196
+ self.frame_annots = [
1197
+ f
1198
+ for f in self.frame_annots
1199
+ if f["frame_annotation"].sequence_name in self.seq_annots
1200
+ ]
1201
+
1202
+ self._invalidate_indexes()
1203
+
1204
+ if self.n_frames_per_sequence > 0:
1205
+ logger.info(f"Taking max {self.n_frames_per_sequence} per sequence.")
1206
+ keep_idx = []
1207
+ # pyre-ignore[16]
1208
+ for seq, seq_indices in self._seq_to_idx.items():
1209
+ # infer the seed from the sequence name, this is reproducible
1210
+ # and makes the selection differ for different sequences
1211
+ seed = _seq_name_to_seed(seq) + self.seed
1212
+ seq_idx_shuffled = random.Random(seed).sample(
1213
+ sorted(seq_indices), len(seq_indices)
1214
+ )
1215
+ keep_idx.extend(seq_idx_shuffled[: self.n_frames_per_sequence])
1216
+
1217
+ logger.info(
1218
+ "... filtered %d -> %d" % (len(self.frame_annots), len(keep_idx))
1219
+ )
1220
+ self.frame_annots = [self.frame_annots[i] for i in keep_idx]
1221
+ self._invalidate_indexes(filter_seq_annots=False)
1222
+ # sequences are not decimated, so self.seq_annots is valid
1223
+
1224
+ if self.limit_to > 0 and self.limit_to < len(self.frame_annots):
1225
+ logger.info(
1226
+ "limit_to: filtered %d -> %d" % (len(self.frame_annots), self.limit_to)
1227
+ )
1228
+ self.frame_annots = self.frame_annots[: self.limit_to]
1229
+ self._invalidate_indexes(filter_seq_annots=True)
1230
+
1231
+ def _invalidate_indexes(self, filter_seq_annots: bool = False) -> None:
1232
+ # update _seq_to_idx and filter seq_meta according to frame_annots change
1233
+ # if filter_seq_annots, also updates seq_annots based on the changed _seq_to_idx
1234
+ self._invalidate_seq_to_idx()
1235
+
1236
+ if filter_seq_annots:
1237
+ # pyre-ignore[16]
1238
+ self.seq_annots = {
1239
+ k: v
1240
+ for k, v in self.seq_annots.items()
1241
+ # pyre-ignore[16]
1242
+ if k in self._seq_to_idx
1243
+ }
1244
+
1245
+ def _invalidate_seq_to_idx(self) -> None:
1246
+ seq_to_idx = defaultdict(list)
1247
+ # pyre-ignore[16]
1248
+ for idx, entry in enumerate(self.frame_annots):
1249
+ seq_to_idx[entry["frame_annotation"].sequence_name].append(idx)
1250
+ # pyre-ignore[16]
1251
+ self._seq_to_idx = seq_to_idx
1252
+
1253
+ def _resize_image(
1254
+ self, image, mode="bilinear"
1255
+ ) -> Tuple[torch.Tensor, float, torch.Tensor]:
1256
+ image_height, image_width = self.image_height, self.image_width
1257
+ if image_height is None or image_width is None:
1258
+ # skip the resizing
1259
+ imre_ = torch.from_numpy(image)
1260
+ return imre_, 1.0, torch.ones_like(imre_[:1])
1261
+ # takes numpy array, returns pytorch tensor
1262
+ minscale = min(
1263
+ image_height / image.shape[-2],
1264
+ image_width / image.shape[-1],
1265
+ )
1266
+ imre = torch.nn.functional.interpolate(
1267
+ torch.from_numpy(image)[None],
1268
+ scale_factor=minscale,
1269
+ mode=mode,
1270
+ align_corners=False if mode == "bilinear" else None,
1271
+ recompute_scale_factor=True,
1272
+ )[0]
1273
+ # pyre-fixme[19]: Expected 1 positional argument.
1274
+ imre_ = torch.zeros(image.shape[0], self.image_height, self.image_width)
1275
+ imre_[:, 0 : imre.shape[1], 0 : imre.shape[2]] = imre
1276
+ # pyre-fixme[6]: For 2nd param expected `int` but got `Optional[int]`.
1277
+ # pyre-fixme[6]: For 3rd param expected `int` but got `Optional[int]`.
1278
+ mask = torch.zeros(1, self.image_height, self.image_width)
1279
+ mask[:, 0 : imre.shape[1], 0 : imre.shape[2]] = 1.0
1280
+ return imre_, minscale, mask
1281
+
1282
+ def _local_path(self, path: str) -> str:
1283
+ if self.path_manager is None:
1284
+ return path
1285
+ return self.path_manager.get_local_path(path)
1286
+
1287
+ def get_frame_numbers_and_timestamps(
1288
+ self, idxs: Sequence[int]
1289
+ ) -> List[Tuple[int, float]]:
1290
+ out: List[Tuple[int, float]] = []
1291
+ for idx in idxs:
1292
+ # pyre-ignore[16]
1293
+ frame_annotation = self.frame_annots[idx]["frame_annotation"]
1294
+ out.append(
1295
+ (frame_annotation.frame_number, frame_annotation.frame_timestamp)
1296
+ )
1297
+ return out
1298
+
1299
+ def get_eval_batches(self) -> Optional[List[List[int]]]:
1300
+ return self.eval_batches
1301
+
1302
+ def _get_frame_type(self, entry: FrameAnnotsEntry) -> Optional[str]:
1303
+ return entry["frame_annotation"].meta["frame_type"]
1304
+
1305
+
1306
+ class CO3DDataset(LightningDataModule):
1307
+ def __init__(
1308
+ self,
1309
+ root_dir,
1310
+ batch_size=2,
1311
+ shuffle=True,
1312
+ num_workers=10,
1313
+ prefetch_factor=2,
1314
+ category="hydrant",
1315
+ **kwargs,
1316
+ ):
1317
+ super().__init__()
1318
+
1319
+ self.batch_size = batch_size
1320
+ self.num_workers = num_workers
1321
+ self.prefetch_factor = prefetch_factor
1322
+ self.shuffle = shuffle
1323
+
1324
+ self.train_dataset = CO3Dv2Wrapper(
1325
+ root_dir=root_dir,
1326
+ stage="train",
1327
+ category=category,
1328
+ **kwargs,
1329
+ )
1330
+
1331
+ self.test_dataset = CO3Dv2Wrapper(
1332
+ root_dir=root_dir,
1333
+ stage="test",
1334
+ subset="fewview_dev",
1335
+ category=category,
1336
+ **kwargs,
1337
+ )
1338
+
1339
+ def train_dataloader(self):
1340
+ return DataLoader(
1341
+ self.train_dataset,
1342
+ batch_size=self.batch_size,
1343
+ shuffle=self.shuffle,
1344
+ num_workers=self.num_workers,
1345
+ prefetch_factor=self.prefetch_factor,
1346
+ collate_fn=self.train_dataset.collate_fn,
1347
+ )
1348
+
1349
+ def test_dataloader(self):
1350
+ return DataLoader(
1351
+ self.test_dataset,
1352
+ batch_size=self.batch_size,
1353
+ shuffle=self.shuffle,
1354
+ num_workers=self.num_workers,
1355
+ prefetch_factor=self.prefetch_factor,
1356
+ collate_fn=self.test_dataset.collate_fn,
1357
+ )
1358
+
1359
+ def val_dataloader(self):
1360
+ return DataLoader(
1361
+ self.test_dataset,
1362
+ batch_size=self.batch_size,
1363
+ shuffle=self.shuffle,
1364
+ num_workers=self.num_workers,
1365
+ prefetch_factor=self.prefetch_factor,
1366
+ collate_fn=video_collate_fn,
1367
+ )
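A minimal usage sketch for the CO3DDataset module added above (not part of the commit): it assumes a local CO3D-v2 download and an installed pytorch3d; the root path, category, and batch size are placeholder values, and the `pixelnerf_input` key only appears when the wrapper is built with `load_pixelnerf` enabled.

# Hypothetical driver code, assuming sgm is importable and /data/co3d holds CO3D-v2.
from sgm.data.co3d import CO3DDataset

datamodule = CO3DDataset(
    root_dir="/data/co3d",   # placeholder dataset root
    category="hydrant",
    batch_size=1,
    num_workers=4,
)
train_loader = datamodule.train_dataloader()
batch = next(iter(train_loader))
# Keys produced by CO3Dv2Wrapper.__getitem__ (exact tensor shapes depend on video_collate_fn):
print(batch.keys())
print(batch["frames"].shape, batch["cond_frames"].shape)
# Present only when load_pixelnerf is enabled: flattened 4x4 c2w (16) + 3x3 intrinsics (9) per frame.
print(batch["pixelnerf_input"]["cameras"].shape)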
sgm/data/colmap.py ADDED
@@ -0,0 +1,605 @@
1
+ # Copyright (c) 2023, ETH Zurich and UNC Chapel Hill.
2
+ # All rights reserved.
3
+ #
4
+ # Redistribution and use in source and binary forms, with or without
5
+ # modification, are permitted provided that the following conditions are met:
6
+ #
7
+ # * Redistributions of source code must retain the above copyright
8
+ # notice, this list of conditions and the following disclaimer.
9
+ #
10
+ # * Redistributions in binary form must reproduce the above copyright
11
+ # notice, this list of conditions and the following disclaimer in the
12
+ # documentation and/or other materials provided with the distribution.
13
+ #
14
+ # * Neither the name of ETH Zurich and UNC Chapel Hill nor the names of
15
+ # its contributors may be used to endorse or promote products derived
16
+ # from this software without specific prior written permission.
17
+ #
18
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
21
+ # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE
22
+ # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
23
+ # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
24
+ # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
25
+ # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
26
+ # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
27
+ # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
28
+ # POSSIBILITY OF SUCH DAMAGE.
29
+
30
+
31
+ import os
32
+ import collections
33
+ import numpy as np
34
+ import struct
35
+ import argparse
36
+
37
+
38
+ CameraModel = collections.namedtuple(
39
+ "CameraModel", ["model_id", "model_name", "num_params"]
40
+ )
41
+ Camera = collections.namedtuple(
42
+ "Camera", ["id", "model", "width", "height", "params"]
43
+ )
44
+ BaseImage = collections.namedtuple(
45
+ "Image", ["id", "qvec", "tvec", "camera_id", "name", "xys", "point3D_ids"]
46
+ )
47
+ Point3D = collections.namedtuple(
48
+ "Point3D", ["id", "xyz", "rgb", "error", "image_ids", "point2D_idxs"]
49
+ )
50
+
51
+
52
+ class Image(BaseImage):
53
+ def qvec2rotmat(self):
54
+ return qvec2rotmat(self.qvec)
55
+
56
+
57
+ CAMERA_MODELS = {
58
+ CameraModel(model_id=0, model_name="SIMPLE_PINHOLE", num_params=3),
59
+ CameraModel(model_id=1, model_name="PINHOLE", num_params=4),
60
+ CameraModel(model_id=2, model_name="SIMPLE_RADIAL", num_params=4),
61
+ CameraModel(model_id=3, model_name="RADIAL", num_params=5),
62
+ CameraModel(model_id=4, model_name="OPENCV", num_params=8),
63
+ CameraModel(model_id=5, model_name="OPENCV_FISHEYE", num_params=8),
64
+ CameraModel(model_id=6, model_name="FULL_OPENCV", num_params=12),
65
+ CameraModel(model_id=7, model_name="FOV", num_params=5),
66
+ CameraModel(model_id=8, model_name="SIMPLE_RADIAL_FISHEYE", num_params=4),
67
+ CameraModel(model_id=9, model_name="RADIAL_FISHEYE", num_params=5),
68
+ CameraModel(model_id=10, model_name="THIN_PRISM_FISHEYE", num_params=12),
69
+ }
70
+ CAMERA_MODEL_IDS = dict(
71
+ [(camera_model.model_id, camera_model) for camera_model in CAMERA_MODELS]
72
+ )
73
+ CAMERA_MODEL_NAMES = dict(
74
+ [(camera_model.model_name, camera_model) for camera_model in CAMERA_MODELS]
75
+ )
76
+
77
+
78
+ def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"):
79
+ """Read and unpack the next bytes from a binary file.
80
+ :param fid:
81
+ :param num_bytes: Sum of combination of {2, 4, 8}, e.g. 2, 6, 16, 30, etc.
82
+ :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}.
83
+ :param endian_character: Any of {@, =, <, >, !}
84
+ :return: Tuple of read and unpacked values.
85
+ """
86
+ data = fid.read(num_bytes)
87
+ return struct.unpack(endian_character + format_char_sequence, data)
88
+
89
+
90
+ def write_next_bytes(fid, data, format_char_sequence, endian_character="<"):
91
+ """pack and write to a binary file.
92
+ :param fid:
93
+ :param data: data to send, if multiple elements are sent at the same time,
94
+ they should be encapsulated either in a list or a tuple
95
+ :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}.
96
+ should be the same length as the data list or tuple
97
+ :param endian_character: Any of {@, =, <, >, !}
98
+ """
99
+ if isinstance(data, (list, tuple)):
100
+ bytes = struct.pack(endian_character + format_char_sequence, *data)
101
+ else:
102
+ bytes = struct.pack(endian_character + format_char_sequence, data)
103
+ fid.write(bytes)
104
+
105
+
106
+ def read_cameras_text(path):
107
+ """
108
+ see: src/colmap/scene/reconstruction.cc
109
+ void Reconstruction::WriteCamerasText(const std::string& path)
110
+ void Reconstruction::ReadCamerasText(const std::string& path)
111
+ """
112
+ cameras = {}
113
+ with open(path, "r") as fid:
114
+ while True:
115
+ line = fid.readline()
116
+ if not line:
117
+ break
118
+ line = line.strip()
119
+ if len(line) > 0 and line[0] != "#":
120
+ elems = line.split()
121
+ camera_id = int(elems[0])
122
+ model = elems[1]
123
+ width = int(elems[2])
124
+ height = int(elems[3])
125
+ params = np.array(tuple(map(float, elems[4:])))
126
+ cameras[camera_id] = Camera(
127
+ id=camera_id,
128
+ model=model,
129
+ width=width,
130
+ height=height,
131
+ params=params,
132
+ )
133
+ return cameras
134
+
135
+
136
+ def read_cameras_binary(path_to_model_file):
137
+ """
138
+ see: src/colmap/scene/reconstruction.cc
139
+ void Reconstruction::WriteCamerasBinary(const std::string& path)
140
+ void Reconstruction::ReadCamerasBinary(const std::string& path)
141
+ """
142
+ cameras = {}
143
+ with open(path_to_model_file, "rb") as fid:
144
+ num_cameras = read_next_bytes(fid, 8, "Q")[0]
145
+ for _ in range(num_cameras):
146
+ camera_properties = read_next_bytes(
147
+ fid, num_bytes=24, format_char_sequence="iiQQ"
148
+ )
149
+ camera_id = camera_properties[0]
150
+ model_id = camera_properties[1]
151
+ model_name = CAMERA_MODEL_IDS[camera_properties[1]].model_name
152
+ width = camera_properties[2]
153
+ height = camera_properties[3]
154
+ num_params = CAMERA_MODEL_IDS[model_id].num_params
155
+ params = read_next_bytes(
156
+ fid,
157
+ num_bytes=8 * num_params,
158
+ format_char_sequence="d" * num_params,
159
+ )
160
+ cameras[camera_id] = Camera(
161
+ id=camera_id,
162
+ model=model_name,
163
+ width=width,
164
+ height=height,
165
+ params=np.array(params),
166
+ )
167
+ assert len(cameras) == num_cameras
168
+ return cameras
169
+
170
+
171
+ def write_cameras_text(cameras, path):
172
+ """
173
+ see: src/colmap/scene/reconstruction.cc
174
+ void Reconstruction::WriteCamerasText(const std::string& path)
175
+ void Reconstruction::ReadCamerasText(const std::string& path)
176
+ """
177
+ HEADER = (
178
+ "# Camera list with one line of data per camera:\n"
179
+ + "# CAMERA_ID, MODEL, WIDTH, HEIGHT, PARAMS[]\n"
180
+ + "# Number of cameras: {}\n".format(len(cameras))
181
+ )
182
+ with open(path, "w") as fid:
183
+ fid.write(HEADER)
184
+ for _, cam in cameras.items():
185
+ to_write = [cam.id, cam.model, cam.width, cam.height, *cam.params]
186
+ line = " ".join([str(elem) for elem in to_write])
187
+ fid.write(line + "\n")
188
+
189
+
190
+ def write_cameras_binary(cameras, path_to_model_file):
191
+ """
192
+ see: src/colmap/scene/reconstruction.cc
193
+ void Reconstruction::WriteCamerasBinary(const std::string& path)
194
+ void Reconstruction::ReadCamerasBinary(const std::string& path)
195
+ """
196
+ with open(path_to_model_file, "wb") as fid:
197
+ write_next_bytes(fid, len(cameras), "Q")
198
+ for _, cam in cameras.items():
199
+ model_id = CAMERA_MODEL_NAMES[cam.model].model_id
200
+ camera_properties = [cam.id, model_id, cam.width, cam.height]
201
+ write_next_bytes(fid, camera_properties, "iiQQ")
202
+ for p in cam.params:
203
+ write_next_bytes(fid, float(p), "d")
204
+ return cameras
205
+
206
+
207
+ def read_images_text(path):
208
+ """
209
+ see: src/colmap/scene/reconstruction.cc
210
+ void Reconstruction::ReadImagesText(const std::string& path)
211
+ void Reconstruction::WriteImagesText(const std::string& path)
212
+ """
213
+ images = {}
214
+ with open(path, "r") as fid:
215
+ while True:
216
+ line = fid.readline()
217
+ if not line:
218
+ break
219
+ line = line.strip()
220
+ if len(line) > 0 and line[0] != "#":
221
+ elems = line.split()
222
+ image_id = int(elems[0])
223
+ qvec = np.array(tuple(map(float, elems[1:5])))
224
+ tvec = np.array(tuple(map(float, elems[5:8])))
225
+ camera_id = int(elems[8])
226
+ image_name = elems[9]
227
+ elems = fid.readline().split()
228
+ xys = np.column_stack(
229
+ [
230
+ tuple(map(float, elems[0::3])),
231
+ tuple(map(float, elems[1::3])),
232
+ ]
233
+ )
234
+ point3D_ids = np.array(tuple(map(int, elems[2::3])))
235
+ images[image_id] = Image(
236
+ id=image_id,
237
+ qvec=qvec,
238
+ tvec=tvec,
239
+ camera_id=camera_id,
240
+ name=image_name,
241
+ xys=xys,
242
+ point3D_ids=point3D_ids,
243
+ )
244
+ return images
245
+
246
+
247
+ def read_images_binary(path_to_model_file):
248
+ """
249
+ see: src/colmap/scene/reconstruction.cc
250
+ void Reconstruction::ReadImagesBinary(const std::string& path)
251
+ void Reconstruction::WriteImagesBinary(const std::string& path)
252
+ """
253
+ images = {}
254
+ with open(path_to_model_file, "rb") as fid:
255
+ num_reg_images = read_next_bytes(fid, 8, "Q")[0]
256
+ for _ in range(num_reg_images):
257
+ binary_image_properties = read_next_bytes(
258
+ fid, num_bytes=64, format_char_sequence="idddddddi"
259
+ )
260
+ image_id = binary_image_properties[0]
261
+ qvec = np.array(binary_image_properties[1:5])
262
+ tvec = np.array(binary_image_properties[5:8])
263
+ camera_id = binary_image_properties[8]
264
+ binary_image_name = b""
265
+ current_char = read_next_bytes(fid, 1, "c")[0]
266
+ while current_char != b"\x00": # look for the ASCII 0 entry
267
+ binary_image_name += current_char
268
+ current_char = read_next_bytes(fid, 1, "c")[0]
269
+ image_name = binary_image_name.decode("utf-8")
270
+ num_points2D = read_next_bytes(
271
+ fid, num_bytes=8, format_char_sequence="Q"
272
+ )[0]
273
+ x_y_id_s = read_next_bytes(
274
+ fid,
275
+ num_bytes=24 * num_points2D,
276
+ format_char_sequence="ddq" * num_points2D,
277
+ )
278
+ xys = np.column_stack(
279
+ [
280
+ tuple(map(float, x_y_id_s[0::3])),
281
+ tuple(map(float, x_y_id_s[1::3])),
282
+ ]
283
+ )
284
+ point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3])))
285
+ images[image_id] = Image(
286
+ id=image_id,
287
+ qvec=qvec,
288
+ tvec=tvec,
289
+ camera_id=camera_id,
290
+ name=image_name,
291
+ xys=xys,
292
+ point3D_ids=point3D_ids,
293
+ )
294
+ return images
295
+
296
+
297
+ def write_images_text(images, path):
298
+ """
299
+ see: src/colmap/scene/reconstruction.cc
300
+ void Reconstruction::ReadImagesText(const std::string& path)
301
+ void Reconstruction::WriteImagesText(const std::string& path)
302
+ """
303
+ if len(images) == 0:
304
+ mean_observations = 0
305
+ else:
306
+ mean_observations = sum(
307
+ (len(img.point3D_ids) for _, img in images.items())
308
+ ) / len(images)
309
+ HEADER = (
310
+ "# Image list with two lines of data per image:\n"
311
+ + "# IMAGE_ID, QW, QX, QY, QZ, TX, TY, TZ, CAMERA_ID, NAME\n"
312
+ + "# POINTS2D[] as (X, Y, POINT3D_ID)\n"
313
+ + "# Number of images: {}, mean observations per image: {}\n".format(
314
+ len(images), mean_observations
315
+ )
316
+ )
317
+
318
+ with open(path, "w") as fid:
319
+ fid.write(HEADER)
320
+ for _, img in images.items():
321
+ image_header = [
322
+ img.id,
323
+ *img.qvec,
324
+ *img.tvec,
325
+ img.camera_id,
326
+ img.name,
327
+ ]
328
+ first_line = " ".join(map(str, image_header))
329
+ fid.write(first_line + "\n")
330
+
331
+ points_strings = []
332
+ for xy, point3D_id in zip(img.xys, img.point3D_ids):
333
+ points_strings.append(" ".join(map(str, [*xy, point3D_id])))
334
+ fid.write(" ".join(points_strings) + "\n")
335
+
336
+
337
+ def write_images_binary(images, path_to_model_file):
338
+ """
339
+ see: src/colmap/scene/reconstruction.cc
340
+ void Reconstruction::ReadImagesBinary(const std::string& path)
341
+ void Reconstruction::WriteImagesBinary(const std::string& path)
342
+ """
343
+ with open(path_to_model_file, "wb") as fid:
344
+ write_next_bytes(fid, len(images), "Q")
345
+ for _, img in images.items():
346
+ write_next_bytes(fid, img.id, "i")
347
+ write_next_bytes(fid, img.qvec.tolist(), "dddd")
348
+ write_next_bytes(fid, img.tvec.tolist(), "ddd")
349
+ write_next_bytes(fid, img.camera_id, "i")
350
+ for char in img.name:
351
+ write_next_bytes(fid, char.encode("utf-8"), "c")
352
+ write_next_bytes(fid, b"\x00", "c")
353
+ write_next_bytes(fid, len(img.point3D_ids), "Q")
354
+ for xy, p3d_id in zip(img.xys, img.point3D_ids):
355
+ write_next_bytes(fid, [*xy, p3d_id], "ddq")
356
+
357
+
358
+ def read_points3D_text(path):
359
+ """
360
+ see: src/colmap/scene/reconstruction.cc
361
+ void Reconstruction::ReadPoints3DText(const std::string& path)
362
+ void Reconstruction::WritePoints3DText(const std::string& path)
363
+ """
364
+ points3D = {}
365
+ with open(path, "r") as fid:
366
+ while True:
367
+ line = fid.readline()
368
+ if not line:
369
+ break
370
+ line = line.strip()
371
+ if len(line) > 0 and line[0] != "#":
372
+ elems = line.split()
373
+ point3D_id = int(elems[0])
374
+ xyz = np.array(tuple(map(float, elems[1:4])))
375
+ rgb = np.array(tuple(map(int, elems[4:7])))
376
+ error = float(elems[7])
377
+ image_ids = np.array(tuple(map(int, elems[8::2])))
378
+ point2D_idxs = np.array(tuple(map(int, elems[9::2])))
379
+ points3D[point3D_id] = Point3D(
380
+ id=point3D_id,
381
+ xyz=xyz,
382
+ rgb=rgb,
383
+ error=error,
384
+ image_ids=image_ids,
385
+ point2D_idxs=point2D_idxs,
386
+ )
387
+ return points3D
388
+
389
+
390
+ def read_points3D_binary(path_to_model_file):
391
+ """
392
+ see: src/colmap/scene/reconstruction.cc
393
+ void Reconstruction::ReadPoints3DBinary(const std::string& path)
394
+ void Reconstruction::WritePoints3DBinary(const std::string& path)
395
+ """
396
+ points3D = {}
397
+ with open(path_to_model_file, "rb") as fid:
398
+ num_points = read_next_bytes(fid, 8, "Q")[0]
399
+ for _ in range(num_points):
400
+ binary_point_line_properties = read_next_bytes(
401
+ fid, num_bytes=43, format_char_sequence="QdddBBBd"
402
+ )
403
+ point3D_id = binary_point_line_properties[0]
404
+ xyz = np.array(binary_point_line_properties[1:4])
405
+ rgb = np.array(binary_point_line_properties[4:7])
406
+ error = np.array(binary_point_line_properties[7])
407
+ track_length = read_next_bytes(
408
+ fid, num_bytes=8, format_char_sequence="Q"
409
+ )[0]
410
+ track_elems = read_next_bytes(
411
+ fid,
412
+ num_bytes=8 * track_length,
413
+ format_char_sequence="ii" * track_length,
414
+ )
415
+ image_ids = np.array(tuple(map(int, track_elems[0::2])))
416
+ point2D_idxs = np.array(tuple(map(int, track_elems[1::2])))
417
+ points3D[point3D_id] = Point3D(
418
+ id=point3D_id,
419
+ xyz=xyz,
420
+ rgb=rgb,
421
+ error=error,
422
+ image_ids=image_ids,
423
+ point2D_idxs=point2D_idxs,
424
+ )
425
+ return points3D
426
+
427
+
428
+ def write_points3D_text(points3D, path):
429
+ """
430
+ see: src/colmap/scene/reconstruction.cc
431
+ void Reconstruction::ReadPoints3DText(const std::string& path)
432
+ void Reconstruction::WritePoints3DText(const std::string& path)
433
+ """
434
+ if len(points3D) == 0:
435
+ mean_track_length = 0
436
+ else:
437
+ mean_track_length = sum(
438
+ (len(pt.image_ids) for _, pt in points3D.items())
439
+ ) / len(points3D)
440
+ HEADER = (
441
+ "# 3D point list with one line of data per point:\n"
442
+ + "# POINT3D_ID, X, Y, Z, R, G, B, ERROR, TRACK[] as (IMAGE_ID, POINT2D_IDX)\n"
443
+ + "# Number of points: {}, mean track length: {}\n".format(
444
+ len(points3D), mean_track_length
445
+ )
446
+ )
447
+
448
+ with open(path, "w") as fid:
449
+ fid.write(HEADER)
450
+ for _, pt in points3D.items():
451
+ point_header = [pt.id, *pt.xyz, *pt.rgb, pt.error]
452
+ fid.write(" ".join(map(str, point_header)) + " ")
453
+ track_strings = []
454
+ for image_id, point2D in zip(pt.image_ids, pt.point2D_idxs):
455
+ track_strings.append(" ".join(map(str, [image_id, point2D])))
456
+ fid.write(" ".join(track_strings) + "\n")
457
+
458
+
459
+ def write_points3D_binary(points3D, path_to_model_file):
460
+ """
461
+ see: src/colmap/scene/reconstruction.cc
462
+ void Reconstruction::ReadPoints3DBinary(const std::string& path)
463
+ void Reconstruction::WritePoints3DBinary(const std::string& path)
464
+ """
465
+ with open(path_to_model_file, "wb") as fid:
466
+ write_next_bytes(fid, len(points3D), "Q")
467
+ for _, pt in points3D.items():
468
+ write_next_bytes(fid, pt.id, "Q")
469
+ write_next_bytes(fid, pt.xyz.tolist(), "ddd")
470
+ write_next_bytes(fid, pt.rgb.tolist(), "BBB")
471
+ write_next_bytes(fid, pt.error, "d")
472
+ track_length = pt.image_ids.shape[0]
473
+ write_next_bytes(fid, track_length, "Q")
474
+ for image_id, point2D_id in zip(pt.image_ids, pt.point2D_idxs):
475
+ write_next_bytes(fid, [image_id, point2D_id], "ii")
476
+
477
+
478
+ def detect_model_format(path, ext):
479
+ if (
480
+ os.path.isfile(os.path.join(path, "cameras" + ext))
481
+ and os.path.isfile(os.path.join(path, "images" + ext))
482
+ and os.path.isfile(os.path.join(path, "points3D" + ext))
483
+ ):
484
+ print("Detected model format: '" + ext + "'")
485
+ return True
486
+
487
+ return False
488
+
489
+
490
+ def read_model(path, ext=""):
491
+ # try to detect the extension automatically
492
+ if ext == "":
493
+ if detect_model_format(path, ".bin"):
494
+ ext = ".bin"
495
+ elif detect_model_format(path, ".txt"):
496
+ ext = ".txt"
497
+ else:
498
+ print("Provide model format: '.bin' or '.txt'")
499
+ return
500
+
501
+ if ext == ".txt":
502
+ cameras = read_cameras_text(os.path.join(path, "cameras" + ext))
503
+ images = read_images_text(os.path.join(path, "images" + ext))
504
+ points3D = read_points3D_text(os.path.join(path, "points3D") + ext)
505
+ else:
506
+ cameras = read_cameras_binary(os.path.join(path, "cameras" + ext))
507
+ images = read_images_binary(os.path.join(path, "images" + ext))
508
+ points3D = read_points3D_binary(os.path.join(path, "points3D") + ext)
509
+ return cameras, images, points3D
510
+
511
+
512
+ def write_model(cameras, images, points3D, path, ext=".bin"):
513
+ if ext == ".txt":
514
+ write_cameras_text(cameras, os.path.join(path, "cameras" + ext))
515
+ write_images_text(images, os.path.join(path, "images" + ext))
516
+ write_points3D_text(points3D, os.path.join(path, "points3D") + ext)
517
+ else:
518
+ write_cameras_binary(cameras, os.path.join(path, "cameras" + ext))
519
+ write_images_binary(images, os.path.join(path, "images" + ext))
520
+ write_points3D_binary(points3D, os.path.join(path, "points3D") + ext)
521
+ return cameras, images, points3D
522
+
523
+
524
+ def qvec2rotmat(qvec):
525
+ return np.array(
526
+ [
527
+ [
528
+ 1 - 2 * qvec[2] ** 2 - 2 * qvec[3] ** 2,
529
+ 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3],
530
+ 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2],
531
+ ],
532
+ [
533
+ 2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3],
534
+ 1 - 2 * qvec[1] ** 2 - 2 * qvec[3] ** 2,
535
+ 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1],
536
+ ],
537
+ [
538
+ 2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2],
539
+ 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1],
540
+ 1 - 2 * qvec[1] ** 2 - 2 * qvec[2] ** 2,
541
+ ],
542
+ ]
543
+ )
544
+
545
+
546
+ def rotmat2qvec(R):
547
+ Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat
548
+ K = (
549
+ np.array(
550
+ [
551
+ [Rxx - Ryy - Rzz, 0, 0, 0],
552
+ [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0],
553
+ [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0],
554
+ [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz],
555
+ ]
556
+ )
557
+ / 3.0
558
+ )
559
+ eigvals, eigvecs = np.linalg.eigh(K)
560
+ qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)]
561
+ if qvec[0] < 0:
562
+ qvec *= -1
563
+ return qvec
564
+
565
+
566
+ def main():
567
+ parser = argparse.ArgumentParser(
568
+ description="Read and write COLMAP binary and text models"
569
+ )
570
+ parser.add_argument("--input_model", help="path to input model folder")
571
+ parser.add_argument(
572
+ "--input_format",
573
+ choices=[".bin", ".txt"],
574
+ help="input model format",
575
+ default="",
576
+ )
577
+ parser.add_argument("--output_model", help="path to output model folder")
578
+ parser.add_argument(
579
+ "--output_format",
580
+ choices=[".bin", ".txt"],
581
+ help="output model format",
582
+ default=".txt",
583
+ )
584
+ args = parser.parse_args()
585
+
586
+ cameras, images, points3D = read_model(
587
+ path=args.input_model, ext=args.input_format
588
+ )
589
+
590
+ print("num_cameras:", len(cameras))
591
+ print("num_images:", len(images))
592
+ print("num_points3D:", len(points3D))
593
+
594
+ if args.output_model is not None:
595
+ write_model(
596
+ cameras,
597
+ images,
598
+ points3D,
599
+ path=args.output_model,
600
+ ext=args.output_format,
601
+ )
602
+
603
+
604
+ if __name__ == "__main__":
605
+ main()
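For reference, a short sketch (not in the commit) of how the COLMAP readers above are typically combined to recover camera-to-world poses; "sparse/0" is a placeholder reconstruction path, and the inversion follows the standard COLMAP convention in which qvec/tvec store the world-to-camera transform.

# Hypothetical usage of the COLMAP model readers defined above.
import numpy as np
from sgm.data.colmap import read_model, qvec2rotmat

cameras, images, points3D = read_model("sparse/0")  # auto-detects .bin vs .txt

c2ws = {}
for image_id, img in images.items():
    w2c = np.eye(4)
    w2c[:3, :3] = qvec2rotmat(img.qvec)  # world-to-camera rotation
    w2c[:3, 3] = img.tvec                # world-to-camera translation
    c2ws[img.name] = np.linalg.inv(w2c)  # camera-to-world pose

print(len(cameras), "cameras,", len(images), "images,", len(points3D), "points")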
sgm/data/dataset.py ADDED
@@ -0,0 +1,80 @@
1
+ from typing import Optional
2
+
3
+ import torchdata.datapipes.iter
4
+ import webdataset as wds
5
+ from omegaconf import DictConfig
6
+ from pytorch_lightning import LightningDataModule
7
+
8
+ try:
9
+ from sdata import create_dataset, create_dummy_dataset, create_loader
10
+ except ImportError as e:
11
+ print("#" * 100)
12
+ print("Datasets not yet available")
13
+ print("to enable, we need to add stable-datasets as a submodule")
14
+ print("please use ``git submodule update --init --recursive``")
15
+ print("and do ``pip install -e stable-datasets/`` from the root of this repo")
16
+ print("#" * 100)
17
+ exit(1)
18
+
19
+
20
+ class StableDataModuleFromConfig(LightningDataModule):
21
+ def __init__(
22
+ self,
23
+ train: DictConfig,
24
+ validation: Optional[DictConfig] = None,
25
+ test: Optional[DictConfig] = None,
26
+ skip_val_loader: bool = False,
27
+ dummy: bool = False,
28
+ ):
29
+ super().__init__()
30
+ self.train_config = train
31
+ assert (
32
+ "datapipeline" in self.train_config and "loader" in self.train_config
33
+ ), "train config requires the fields `datapipeline` and `loader`"
34
+
35
+ self.val_config = validation
36
+ if not skip_val_loader:
37
+ if self.val_config is not None:
38
+ assert (
39
+ "datapipeline" in self.val_config and "loader" in self.val_config
40
+ ), "validation config requires the fields `datapipeline` and `loader`"
41
+ else:
42
+ print(
43
+ "Warning: No Validation datapipeline defined, using that one from training"
44
+ )
45
+ self.val_config = train
46
+
47
+ self.test_config = test
48
+ if self.test_config is not None:
49
+ assert (
50
+ "datapipeline" in self.test_config and "loader" in self.test_config
51
+ ), "test config requires the fields `datapipeline` and `loader`"
52
+
53
+ self.dummy = dummy
54
+ if self.dummy:
55
+ print("#" * 100)
56
+ print("USING DUMMY DATASET: HOPE YOU'RE DEBUGGING ;)")
57
+ print("#" * 100)
58
+
59
+ def setup(self, stage: str) -> None:
60
+ print("Preparing datasets")
61
+ if self.dummy:
62
+ data_fn = create_dummy_dataset
63
+ else:
64
+ data_fn = create_dataset
65
+
66
+ self.train_datapipeline = data_fn(**self.train_config.datapipeline)
67
+ if self.val_config:
68
+ self.val_datapipeline = data_fn(**self.val_config.datapipeline)
69
+ if self.test_config:
70
+ self.test_datapipeline = data_fn(**self.test_config.datapipeline)
71
+
72
+ def train_dataloader(self) -> torchdata.datapipes.iter.IterDataPipe:
73
+ loader = create_loader(self.train_datapipeline, **self.train_config.loader)
74
+ return loader
75
+
76
+ def val_dataloader(self) -> wds.DataPipeline:
77
+ return create_loader(self.val_datapipeline, **self.val_config.loader)
78
+
79
+ def test_dataloader(self) -> wds.DataPipeline:
80
+ return create_loader(self.test_datapipeline, **self.test_config.loader)
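A hedged sketch (not in the commit) of the config shape StableDataModuleFromConfig expects: the class only checks that `datapipeline` and `loader` keys exist, while the inner fields are defined by the optional stable-datasets submodule, so the values below are purely illustrative.

# Illustrative only; importing sgm.data.dataset exits early unless the
# stable-datasets submodule mentioned in the import guard above is installed.
from omegaconf import OmegaConf
from sgm.data.dataset import StableDataModuleFromConfig

train_cfg = OmegaConf.create(
    {
        "datapipeline": {"urls": ["/data/shards/{00000..00099}.tar"]},  # placeholder fields
        "loader": {"batch_size": 4, "num_workers": 2},
    }
)
dm = StableDataModuleFromConfig(train=train_cfg, dummy=True)  # dummy=True for debugging
dm.setup("fit")
train_loader = dm.train_dataloader()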
sgm/data/joint3d.py ADDED
@@ -0,0 +1,10 @@
1
+ import torch
2
+ from torch.utils.data import Dataset
3
+
4
+ default_sub_data_config = {}
5
+
6
+
7
+ class Joint3D(Dataset):
8
+ def __init__(self, sub_data_config: dict) -> None:
9
+ super().__init__()
10
+ self.sub_data_config = sub_data_config
sgm/data/json_index_dataset.py ADDED
@@ -0,0 +1,1080 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import copy
8
+ import functools
9
+ import gzip
10
+ import hashlib
11
+ import json
12
+ import logging
13
+ import os
14
+ import random
15
+ import warnings
16
+ from collections import defaultdict
17
+ from itertools import islice
18
+ from pathlib import Path
19
+ from typing import (
20
+ Any,
21
+ ClassVar,
22
+ Dict,
23
+ Iterable,
24
+ List,
25
+ Optional,
26
+ Sequence,
27
+ Tuple,
28
+ Type,
29
+ TYPE_CHECKING,
30
+ Union,
31
+ )
32
+
33
+ import numpy as np
34
+ import torch
35
+ from PIL import Image
36
+ from pytorch3d.implicitron.tools.config import registry, ReplaceableBase
37
+ from pytorch3d.io import IO
38
+ from pytorch3d.renderer.camera_utils import join_cameras_as_batch
39
+ from pytorch3d.renderer.cameras import CamerasBase, PerspectiveCameras
40
+ from pytorch3d.structures.pointclouds import Pointclouds
41
+ from tqdm import tqdm
42
+
43
+ from pytorch3d.implicitron.dataset import types
44
+ from pytorch3d.implicitron.dataset.dataset_base import DatasetBase, FrameData
45
+ from pytorch3d.implicitron.dataset.utils import is_known_frame_scalar
46
+
47
+
48
+ logger = logging.getLogger(__name__)
49
+
50
+
51
+ if TYPE_CHECKING:
52
+ from typing import TypedDict
53
+
54
+ class FrameAnnotsEntry(TypedDict):
55
+ subset: Optional[str]
56
+ frame_annotation: types.FrameAnnotation
57
+
58
+ else:
59
+ FrameAnnotsEntry = dict
60
+
61
+
62
+ @registry.register
63
+ class JsonIndexDataset(DatasetBase, ReplaceableBase):
64
+ """
65
+ A dataset with annotations in json files like the Common Objects in 3D
66
+ (CO3D) dataset.
67
+
68
+ Args:
69
+ frame_annotations_file: A zipped json file containing metadata of the
70
+ frames in the dataset, serialized List[types.FrameAnnotation].
71
+ sequence_annotations_file: A zipped json file containing metadata of the
72
+ sequences in the dataset, serialized List[types.SequenceAnnotation].
73
+ subset_lists_file: A json file containing the lists of frames corresponding
74
+ to different subsets (e.g. train/val/test) of the dataset;
75
+ format: {subset: (sequence_name, frame_id, file_path)}.
76
+ subsets: Restrict frames/sequences only to the given list of subsets
77
+ as defined in subset_lists_file (see above).
78
+ limit_to: Limit the dataset to the first #limit_to frames (after other
79
+ filters have been applied).
80
+ limit_sequences_to: Limit the dataset to the first
81
+ #limit_sequences_to sequences (after other sequence filters have been
82
+ applied but before frame-based filters).
83
+ pick_sequence: A list of sequence names to restrict the dataset to.
84
+ exclude_sequence: A list of the names of the sequences to exclude.
85
+ limit_category_to: Restrict the dataset to the given list of categories.
86
+ dataset_root: The root folder of the dataset; all the paths in jsons are
87
+ specified relative to this root (but not json paths themselves).
88
+ load_images: Enable loading the frame RGB data.
89
+ load_depths: Enable loading the frame depth maps.
90
+ load_depth_masks: Enable loading the frame depth map masks denoting the
91
+ depth values used for evaluation (the points consistent across views).
92
+ load_masks: Enable loading frame foreground masks.
93
+ load_point_clouds: Enable loading sequence-level point clouds.
94
+ max_points: Cap on the number of loaded points in the point cloud;
95
+ if reached, they are randomly sampled without replacement.
96
+ mask_images: Whether to mask the images with the loaded foreground masks;
97
+ 0 value is used for background.
98
+ mask_depths: Whether to mask the depth maps with the loaded foreground
99
+ masks; 0 value is used for background.
100
+ image_height: The height of the returned images, masks, and depth maps;
101
+ aspect ratio is preserved during cropping/resizing.
102
+ image_width: The width of the returned images, masks, and depth maps;
103
+ aspect ratio is preserved during cropping/resizing.
104
+ box_crop: Enable cropping of the image around the bounding box inferred
105
+ from the foreground region of the loaded segmentation mask; masks
106
+ and depth maps are cropped accordingly; cameras are corrected.
107
+ box_crop_mask_thr: The threshold used to separate pixels into foreground
108
+ and background based on the foreground_probability mask; if no value
109
+ is greater than this threshold, the loader lowers it and repeats.
110
+ box_crop_context: The amount of additional padding added to each
111
+ dimension of the cropping bounding box, relative to box size.
112
+ remove_empty_masks: Removes the frames with no active foreground pixels
113
+ in the segmentation mask after thresholding (see box_crop_mask_thr).
114
+ n_frames_per_sequence: If > 0, randomly samples #n_frames_per_sequence
115
+ frames in each sequences uniformly without replacement if it has
116
+ more frames than that; applied before other frame-level filters.
117
+ seed: The seed of the random generator sampling #n_frames_per_sequence
118
+ random frames per sequence.
119
+ sort_frames: Enable frame annotations sorting to group frames from the
120
+ same sequences together and order them by timestamps.
121
+ eval_batches: A list of batches that form the evaluation set;
122
+ list of batch-sized lists of indices corresponding to __getitem__
123
+ of this class, thus it can be used directly as a batch sampler.
124
+ eval_batch_index:
125
+ ( Optional[List[List[Union[Tuple[str, int, str], Tuple[str, int]]]]] )
126
+ A list of batches of frames described as (sequence_name, frame_idx)
127
+ that can form the evaluation set, `eval_batches` will be set from this.
128
+
129
+ """
130
+
131
+ frame_annotations_type: ClassVar[
132
+ Type[types.FrameAnnotation]
133
+ ] = types.FrameAnnotation
134
+
135
+ path_manager: Any = None
136
+ frame_annotations_file: str = ""
137
+ sequence_annotations_file: str = ""
138
+ subset_lists_file: str = ""
139
+ subsets: Optional[List[str]] = None
140
+ limit_to: int = 0
141
+ limit_sequences_to: int = 0
142
+ pick_sequence: Tuple[str, ...] = ()
143
+ exclude_sequence: Tuple[str, ...] = ()
144
+ limit_category_to: Tuple[int, ...] = ()
145
+ dataset_root: str = ""
146
+ load_images: bool = True
147
+ load_depths: bool = True
148
+ load_depth_masks: bool = True
149
+ load_masks: bool = True
150
+ load_point_clouds: bool = False
151
+ max_points: int = 0
152
+ mask_images: bool = False
153
+ mask_depths: bool = False
154
+ image_height: Optional[int] = 800
155
+ image_width: Optional[int] = 800
156
+ box_crop: bool = True
157
+ box_crop_mask_thr: float = 0.4
158
+ box_crop_context: float = 0.3
159
+ remove_empty_masks: bool = True
160
+ n_frames_per_sequence: int = -1
161
+ seed: int = 0
162
+ sort_frames: bool = False
163
+ eval_batches: Any = None
164
+ eval_batch_index: Any = None
165
+ # frame_annots: List[FrameAnnotsEntry] = field(init=False)
166
+ # seq_annots: Dict[str, types.SequenceAnnotation] = field(init=False)
167
+
168
+ def __post_init__(self) -> None:
169
+ # pyre-fixme[16]: `JsonIndexDataset` has no attribute `subset_to_image_path`.
170
+ self.subset_to_image_path = None
171
+ self._load_frames()
172
+ self._load_sequences()
173
+ if self.sort_frames:
174
+ self._sort_frames()
175
+ self._load_subset_lists()
176
+ self._filter_db() # also computes sequence indices
177
+ self._extract_and_set_eval_batches()
178
+ logger.info(str(self))
179
+
180
+ def _extract_and_set_eval_batches(self):
181
+ """
182
+ Sets eval_batches based on input eval_batch_index.
183
+ """
184
+ if self.eval_batch_index is not None:
185
+ if self.eval_batches is not None:
186
+ raise ValueError(
187
+ "Cannot define both eval_batch_index and eval_batches."
188
+ )
189
+ self.eval_batches = self.seq_frame_index_to_dataset_index(
190
+ self.eval_batch_index
191
+ )
192
+
193
+ def join(self, other_datasets: Iterable[DatasetBase]) -> None:
194
+ """
195
+ Join the dataset with other JsonIndexDataset objects.
196
+
197
+ Args:
198
+ other_datasets: A list of JsonIndexDataset objects to be joined
199
+ into the current dataset.
200
+ """
201
+ if not all(isinstance(d, JsonIndexDataset) for d in other_datasets):
202
+ raise ValueError("This function can only join a list of JsonIndexDataset")
203
+ # pyre-ignore[16]
204
+ self.frame_annots.extend([fa for d in other_datasets for fa in d.frame_annots])
205
+ # pyre-ignore[16]
206
+ self.seq_annots.update(
207
+ # https://gist.github.com/treyhunner/f35292e676efa0be1728
208
+ functools.reduce(
209
+ lambda a, b: {**a, **b},
210
+ [d.seq_annots for d in other_datasets], # pyre-ignore[16]
211
+ )
212
+ )
213
+ all_eval_batches = [
214
+ self.eval_batches,
215
+ # pyre-ignore
216
+ *[d.eval_batches for d in other_datasets],
217
+ ]
218
+ if not (
219
+ all(ba is None for ba in all_eval_batches)
220
+ or all(ba is not None for ba in all_eval_batches)
221
+ ):
222
+ raise ValueError(
223
+ "When joining datasets, either all joined datasets have to have their"
224
+ " eval_batches defined, or all should have their eval batches undefined."
225
+ )
226
+ if self.eval_batches is not None:
227
+ self.eval_batches = sum(all_eval_batches, [])
228
+ self._invalidate_indexes(filter_seq_annots=True)
229
+
230
+ def is_filtered(self) -> bool:
231
+ """
232
+ Returns `True` in case the dataset has been filtered and thus some frame annotations
233
+ stored on the disk might be missing in the dataset object.
234
+
235
+ Returns:
236
+ is_filtered: `True` if the dataset has been filtered, else `False`.
237
+ """
238
+ return (
239
+ self.remove_empty_masks
240
+ or self.limit_to > 0
241
+ or self.limit_sequences_to > 0
242
+ or len(self.pick_sequence) > 0
243
+ or len(self.exclude_sequence) > 0
244
+ or len(self.limit_category_to) > 0
245
+ or self.n_frames_per_sequence > 0
246
+ )
247
+
248
+ def seq_frame_index_to_dataset_index(
249
+ self,
250
+ seq_frame_index: List[List[Union[Tuple[str, int, str], Tuple[str, int]]]],
251
+ allow_missing_indices: bool = False,
252
+ remove_missing_indices: bool = False,
253
+ suppress_missing_index_warning: bool = True,
254
+ ) -> List[List[Union[Optional[int], int]]]:
255
+ """
256
+ Obtain indices into the dataset object given a list of frame ids.
257
+
258
+ Args:
259
+ seq_frame_index: The list of frame ids specified as
260
+ `List[List[Tuple[sequence_name:str, frame_number:int]]]`. Optionally,
261
+ Image paths relative to the dataset_root can be specified as well:
262
+ `List[List[Tuple[sequence_name:str, frame_number:int, image_path:str]]]`
263
+ allow_missing_indices: If `False`, throws an IndexError upon reaching the first
264
+ entry from `seq_frame_index` which is missing in the dataset.
265
+ Otherwise, depending on `remove_missing_indices`, either returns `None`
266
+ in place of missing entries or removes the indices of missing entries.
267
+ remove_missing_indices: Active when `allow_missing_indices=True`.
268
+ If `False`, returns `None` in place of `seq_frame_index` entries that
269
+ are not present in the dataset.
270
+ If `True` removes missing indices from the returned indices.
271
+ suppress_missing_index_warning:
272
+ Active if `allow_missing_indices==True`. Suppresses a warning message
273
+ in case an entry from `seq_frame_index` is missing in the dataset
274
+ (expected in certain cases - e.g. when setting
275
+ `self.remove_empty_masks=True`).
276
+
277
+ Returns:
278
+ dataset_idx: Indices of dataset entries corresponding to `seq_frame_index`.
279
+ """
280
+ _dataset_seq_frame_n_index = {
281
+ seq: {
282
+ # pyre-ignore[16]
283
+ self.frame_annots[idx]["frame_annotation"].frame_number: idx
284
+ for idx in seq_idx
285
+ }
286
+ # pyre-ignore[16]
287
+ for seq, seq_idx in self._seq_to_idx.items()
288
+ }
289
+
290
+ def _get_dataset_idx(
291
+ seq_name: str, frame_no: int, path: Optional[str] = None
292
+ ) -> Optional[int]:
293
+ idx_seq = _dataset_seq_frame_n_index.get(seq_name, None)
294
+ idx = idx_seq.get(frame_no, None) if idx_seq is not None else None
295
+ if idx is None:
296
+ msg = (
297
+ f"sequence_name={seq_name} / frame_number={frame_no}"
298
+ " not in the dataset!"
299
+ )
300
+ if not allow_missing_indices:
301
+ raise IndexError(msg)
302
+ if not suppress_missing_index_warning:
303
+ warnings.warn(msg)
304
+ return idx
305
+ if path is not None:
306
+ # Check that the loaded frame path is consistent
307
+ # with the one stored in self.frame_annots.
308
+ assert os.path.normpath(
309
+ # pyre-ignore[16]
310
+ self.frame_annots[idx]["frame_annotation"].image.path
311
+ ) == os.path.normpath(
312
+ path
313
+ ), f"Inconsistent frame indices {seq_name, frame_no, path}."
314
+ return idx
315
+
316
+ dataset_idx = [
317
+ [_get_dataset_idx(*b) for b in batch] # pyre-ignore [6]
318
+ for batch in seq_frame_index
319
+ ]
320
+
321
+ if allow_missing_indices and remove_missing_indices:
322
+ # remove all None indices, and also batches with only None entries
323
+ valid_dataset_idx = [
324
+ [b for b in batch if b is not None] for batch in dataset_idx
325
+ ]
326
+ return [ # pyre-ignore[7]
327
+ batch for batch in valid_dataset_idx if len(batch) > 0
328
+ ]
329
+
330
+ return dataset_idx
331
+
332
+ def subset_from_frame_index(
333
+ self,
334
+ frame_index: List[Union[Tuple[str, int], Tuple[str, int, str]]],
335
+ allow_missing_indices: bool = True,
336
+ ) -> "JsonIndexDataset":
337
+ """
338
+ Generate a dataset subset given the list of frames specified in `frame_index`.
339
+
340
+ Args:
341
+ frame_index: The list of frame identifiers (as stored in the metadata)
342
+ specified as `List[Tuple[sequence_name:str, frame_number:int]]`. Optionally,
343
+ Image paths relative to the dataset_root can be specified as well:
344
+ `List[Tuple[sequence_name:str, frame_number:int, image_path:str]]`,
345
+ in the latter case, if image_path does not match the stored paths, an error
346
+ is raised.
347
+ allow_missing_indices: If `False`, throws an IndexError upon reaching the first
348
+ entry from `frame_index` which is missing in the dataset.
349
+ Otherwise, generates a subset consisting of frames entries that actually
350
+ exist in the dataset.
351
+ """
352
+ # Get the indices into the frame annots.
353
+ dataset_indices = self.seq_frame_index_to_dataset_index(
354
+ [frame_index],
355
+ allow_missing_indices=self.is_filtered() and allow_missing_indices,
356
+ )[0]
357
+ valid_dataset_indices = [i for i in dataset_indices if i is not None]
358
+
359
+ # Deep copy the whole dataset except frame_annots, which are large so we
360
+ # deep copy only the requested subset of frame_annots.
361
+ memo = {id(self.frame_annots): None} # pyre-ignore[16]
362
+ dataset_new = copy.deepcopy(self, memo)
363
+ dataset_new.frame_annots = copy.deepcopy(
364
+ [self.frame_annots[i] for i in valid_dataset_indices]
365
+ )
366
+
367
+ # This will kill all unneeded sequence annotations.
368
+ dataset_new._invalidate_indexes(filter_seq_annots=True)
369
+
370
+ # Finally annotate the frame annotations with the name of the subset
371
+ # stored in meta.
372
+ for frame_annot in dataset_new.frame_annots:
373
+ frame_annotation = frame_annot["frame_annotation"]
374
+ if frame_annotation.meta is not None:
375
+ frame_annot["subset"] = frame_annotation.meta.get("frame_type", None)
376
+
377
+ # A sanity check - this will crash in case some entries from frame_index are missing
378
+ # in dataset_new.
379
+ valid_frame_index = [
380
+ fi for fi, di in zip(frame_index, dataset_indices) if di is not None
381
+ ]
382
+ dataset_new.seq_frame_index_to_dataset_index(
383
+ [valid_frame_index], allow_missing_indices=False
384
+ )
385
+
386
+ return dataset_new
387
+
388
+ def __str__(self) -> str:
389
+ # pyre-ignore[16]
390
+ return f"JsonIndexDataset #frames={len(self.frame_annots)}"
391
+
392
+ def __len__(self) -> int:
393
+ # pyre-ignore[16]
394
+ return len(self.frame_annots)
395
+
396
+ def _get_frame_type(self, entry: FrameAnnotsEntry) -> Optional[str]:
397
+ return entry["subset"]
398
+
399
+ def get_all_train_cameras(self) -> CamerasBase:
400
+ """
401
+ Returns the cameras corresponding to all the known frames.
402
+ """
403
+ logger.info("Loading all train cameras.")
404
+ cameras = []
405
+ # pyre-ignore[16]
406
+ for frame_idx, frame_annot in enumerate(tqdm(self.frame_annots)):
407
+ frame_type = self._get_frame_type(frame_annot)
408
+ if frame_type is None:
409
+ raise ValueError("subsets not loaded")
410
+ if is_known_frame_scalar(frame_type):
411
+ cameras.append(self[frame_idx].camera)
412
+ return join_cameras_as_batch(cameras)
413
+
414
+ def __getitem__(self, index) -> FrameData:
415
+ # pyre-ignore[16]
416
+ if index >= len(self.frame_annots):
417
+ raise IndexError(f"index {index} out of range {len(self.frame_annots)}")
418
+
419
+ entry = self.frame_annots[index]["frame_annotation"]
420
+ # pyre-ignore[16]
421
+ point_cloud = self.seq_annots[entry.sequence_name].point_cloud
422
+ frame_data = FrameData(
423
+ frame_number=_safe_as_tensor(entry.frame_number, torch.long),
424
+ frame_timestamp=_safe_as_tensor(entry.frame_timestamp, torch.float),
425
+ sequence_name=entry.sequence_name,
426
+ sequence_category=self.seq_annots[entry.sequence_name].category,
427
+ camera_quality_score=_safe_as_tensor(
428
+ self.seq_annots[entry.sequence_name].viewpoint_quality_score,
429
+ torch.float,
430
+ ),
431
+ point_cloud_quality_score=_safe_as_tensor(
432
+ point_cloud.quality_score, torch.float
433
+ )
434
+ if point_cloud is not None
435
+ else None,
436
+ )
437
+
438
+ # The rest of the fields are optional
439
+ frame_data.frame_type = self._get_frame_type(self.frame_annots[index])
440
+
441
+ (
442
+ frame_data.fg_probability,
443
+ frame_data.mask_path,
444
+ frame_data.bbox_xywh,
445
+ clamp_bbox_xyxy,
446
+ frame_data.crop_bbox_xywh,
447
+ ) = self._load_crop_fg_probability(entry)
448
+
449
+ scale = 1.0
450
+ if self.load_images and entry.image is not None:
451
+ # original image size
452
+ frame_data.image_size_hw = _safe_as_tensor(entry.image.size, torch.long)
453
+
454
+ (
455
+ frame_data.image_rgb,
456
+ frame_data.image_path,
457
+ frame_data.mask_crop,
458
+ scale,
459
+ ) = self._load_crop_images(
460
+ entry, frame_data.fg_probability, clamp_bbox_xyxy
461
+ )
462
+
463
+ if self.load_depths and entry.depth is not None:
464
+ (
465
+ frame_data.depth_map,
466
+ frame_data.depth_path,
467
+ frame_data.depth_mask,
468
+ ) = self._load_mask_depth(entry, clamp_bbox_xyxy, frame_data.fg_probability)
469
+
470
+ if entry.viewpoint is not None:
471
+ frame_data.camera = self._get_pytorch3d_camera(
472
+ entry,
473
+ scale,
474
+ clamp_bbox_xyxy,
475
+ )
476
+
477
+ if self.load_point_clouds and point_cloud is not None:
478
+ pcl_path = self._fix_point_cloud_path(point_cloud.path)
479
+ frame_data.sequence_point_cloud = _load_pointcloud(
480
+ self._local_path(pcl_path), max_points=self.max_points
481
+ )
482
+ frame_data.sequence_point_cloud_path = pcl_path
483
+
484
+ return frame_data
485
+
486
+ def _fix_point_cloud_path(self, path: str) -> str:
487
+ """
488
+ Fix up a point cloud path from the dataset.
489
+ Some files in Co3Dv2 have an accidental absolute path stored.
490
+ """
491
+ unwanted_prefix = (
492
+ "/large_experiments/p3/replay/datasets/co3d/co3d45k_220512/export_v23/"
493
+ )
494
+ if path.startswith(unwanted_prefix):
495
+ path = path[len(unwanted_prefix) :]
496
+ return os.path.join(self.dataset_root, path)
497
+
498
+ def _load_crop_fg_probability(
499
+ self, entry: types.FrameAnnotation
500
+ ) -> Tuple[
501
+ Optional[torch.Tensor],
502
+ Optional[str],
503
+ Optional[torch.Tensor],
504
+ Optional[torch.Tensor],
505
+ Optional[torch.Tensor],
506
+ ]:
507
+ fg_probability = None
508
+ full_path = None
509
+ bbox_xywh = None
510
+ clamp_bbox_xyxy = None
511
+ crop_box_xywh = None
512
+
513
+ if (self.load_masks or self.box_crop) and entry.mask is not None:
514
+ full_path = os.path.join(self.dataset_root, entry.mask.path)
515
+ mask = _load_mask(self._local_path(full_path))
516
+
517
+ if mask.shape[-2:] != entry.image.size:
518
+ raise ValueError(
519
+ f"bad mask size: {mask.shape[-2:]} vs {entry.image.size}!"
520
+ )
521
+
522
+ bbox_xywh = torch.tensor(_get_bbox_from_mask(mask, self.box_crop_mask_thr))
523
+
524
+ if self.box_crop:
525
+ clamp_bbox_xyxy = _clamp_box_to_image_bounds_and_round(
526
+ _get_clamp_bbox(
527
+ bbox_xywh,
528
+ image_path=entry.image.path,
529
+ box_crop_context=self.box_crop_context,
530
+ ),
531
+ image_size_hw=tuple(mask.shape[-2:]),
532
+ )
533
+ crop_box_xywh = _bbox_xyxy_to_xywh(clamp_bbox_xyxy)
534
+
535
+ mask = _crop_around_box(mask, clamp_bbox_xyxy, full_path)
536
+
537
+ fg_probability, _, _ = self._resize_image(mask, mode="nearest")
538
+
539
+ return fg_probability, full_path, bbox_xywh, clamp_bbox_xyxy, crop_box_xywh
540
+
541
+ def _load_crop_images(
542
+ self,
543
+ entry: types.FrameAnnotation,
544
+ fg_probability: Optional[torch.Tensor],
545
+ clamp_bbox_xyxy: Optional[torch.Tensor],
546
+ ) -> Tuple[torch.Tensor, str, torch.Tensor, float]:
547
+ assert self.dataset_root is not None and entry.image is not None
548
+ path = os.path.join(self.dataset_root, entry.image.path)
549
+ image_rgb = _load_image(self._local_path(path))
550
+
551
+ if image_rgb.shape[-2:] != entry.image.size:
552
+ raise ValueError(
553
+ f"bad image size: {image_rgb.shape[-2:]} vs {entry.image.size}!"
554
+ )
555
+
556
+ if self.box_crop:
557
+ assert clamp_bbox_xyxy is not None
558
+ image_rgb = _crop_around_box(image_rgb, clamp_bbox_xyxy, path)
559
+
560
+ image_rgb, scale, mask_crop = self._resize_image(image_rgb)
561
+
562
+ if self.mask_images:
563
+ assert fg_probability is not None
564
+ image_rgb *= fg_probability
565
+
566
+ return image_rgb, path, mask_crop, scale
567
+
568
+ def _load_mask_depth(
569
+ self,
570
+ entry: types.FrameAnnotation,
571
+ clamp_bbox_xyxy: Optional[torch.Tensor],
572
+ fg_probability: Optional[torch.Tensor],
573
+ ) -> Tuple[torch.Tensor, str, torch.Tensor]:
574
+ entry_depth = entry.depth
575
+ assert entry_depth is not None
576
+ path = os.path.join(self.dataset_root, entry_depth.path)
577
+ depth_map = _load_depth(self._local_path(path), entry_depth.scale_adjustment)
578
+
579
+ if self.box_crop:
580
+ assert clamp_bbox_xyxy is not None
581
+ depth_bbox_xyxy = _rescale_bbox(
582
+ clamp_bbox_xyxy, entry.image.size, depth_map.shape[-2:]
583
+ )
584
+ depth_map = _crop_around_box(depth_map, depth_bbox_xyxy, path)
585
+
586
+ depth_map, _, _ = self._resize_image(depth_map, mode="nearest")
587
+
588
+ if self.mask_depths:
589
+ assert fg_probability is not None
590
+ depth_map *= fg_probability
591
+
592
+ if self.load_depth_masks:
593
+ assert entry_depth.mask_path is not None
594
+ mask_path = os.path.join(self.dataset_root, entry_depth.mask_path)
595
+ depth_mask = _load_depth_mask(self._local_path(mask_path))
596
+
597
+ if self.box_crop:
598
+ assert clamp_bbox_xyxy is not None
599
+ depth_mask_bbox_xyxy = _rescale_bbox(
600
+ clamp_bbox_xyxy, entry.image.size, depth_mask.shape[-2:]
601
+ )
602
+ depth_mask = _crop_around_box(
603
+ depth_mask, depth_mask_bbox_xyxy, mask_path
604
+ )
605
+
606
+ depth_mask, _, _ = self._resize_image(depth_mask, mode="nearest")
607
+ else:
608
+ depth_mask = torch.ones_like(depth_map)
609
+
610
+ return depth_map, path, depth_mask
611
+
612
+ def _get_pytorch3d_camera(
613
+ self,
614
+ entry: types.FrameAnnotation,
615
+ scale: float,
616
+ clamp_bbox_xyxy: Optional[torch.Tensor],
617
+ ) -> PerspectiveCameras:
618
+ entry_viewpoint = entry.viewpoint
619
+ assert entry_viewpoint is not None
620
+ # principal point and focal length
621
+ principal_point = torch.tensor(
622
+ entry_viewpoint.principal_point, dtype=torch.float
623
+ )
624
+ focal_length = torch.tensor(entry_viewpoint.focal_length, dtype=torch.float)
625
+
626
+ half_image_size_wh_orig = (
627
+ torch.tensor(list(reversed(entry.image.size)), dtype=torch.float) / 2.0
628
+ )
629
+
630
+ # first, we convert from the dataset's NDC convention to pixels
631
+ format = entry_viewpoint.intrinsics_format
632
+ if format.lower() == "ndc_norm_image_bounds":
633
+ # this is e.g. currently used in CO3D for storing intrinsics
634
+ rescale = half_image_size_wh_orig
635
+ elif format.lower() == "ndc_isotropic":
636
+ rescale = half_image_size_wh_orig.min()
637
+ else:
638
+ raise ValueError(f"Unknown intrinsics format: {format}")
639
+
640
+ # principal point and focal length in pixels
641
+ principal_point_px = half_image_size_wh_orig - principal_point * rescale
642
+ focal_length_px = focal_length * rescale
643
+ if self.box_crop:
644
+ assert clamp_bbox_xyxy is not None
645
+ principal_point_px -= clamp_bbox_xyxy[:2]
646
+
647
+ # now, convert from pixels to PyTorch3D v0.5+ NDC convention
648
+ if self.image_height is None or self.image_width is None:
649
+ out_size = list(reversed(entry.image.size))
650
+ else:
651
+ out_size = [self.image_width, self.image_height]
652
+
653
+ half_image_size_output = torch.tensor(out_size, dtype=torch.float) / 2.0
654
+ half_min_image_size_output = half_image_size_output.min()
655
+
656
+ # rescaled principal point and focal length in ndc
657
+ principal_point = (
658
+ half_image_size_output - principal_point_px * scale
659
+ ) / half_min_image_size_output
660
+ focal_length = focal_length_px * scale / half_min_image_size_output
661
+
662
+ return PerspectiveCameras(
663
+ focal_length=focal_length[None],
664
+ principal_point=principal_point[None],
665
+ R=torch.tensor(entry_viewpoint.R, dtype=torch.float)[None],
666
+ T=torch.tensor(entry_viewpoint.T, dtype=torch.float)[None],
667
+ )
668
+
669
+ def _load_frames(self) -> None:
670
+ logger.info(f"Loading Co3D frames from {self.frame_annotations_file}.")
671
+ local_file = self._local_path(self.frame_annotations_file)
672
+ with gzip.open(local_file, "rt", encoding="utf8") as zipfile:
673
+ frame_annots_list = types.load_dataclass(
674
+ zipfile, List[self.frame_annotations_type]
675
+ )
676
+ if not frame_annots_list:
677
+ raise ValueError("Empty dataset!")
678
+ # pyre-ignore[16]
679
+ self.frame_annots = [
680
+ FrameAnnotsEntry(frame_annotation=a, subset=None) for a in frame_annots_list
681
+ ]
682
+
683
+ def _load_sequences(self) -> None:
684
+ logger.info(f"Loading Co3D sequences from {self.sequence_annotations_file}.")
685
+ local_file = self._local_path(self.sequence_annotations_file)
686
+ with gzip.open(local_file, "rt", encoding="utf8") as zipfile:
687
+ seq_annots = types.load_dataclass(zipfile, List[types.SequenceAnnotation])
688
+ if not seq_annots:
689
+ raise ValueError("Empty sequences file!")
690
+ # pyre-ignore[16]
691
+ self.seq_annots = {entry.sequence_name: entry for entry in seq_annots}
692
+
693
+ def _load_subset_lists(self) -> None:
694
+ logger.info(f"Loading Co3D subset lists from {self.subset_lists_file}.")
695
+ if not self.subset_lists_file:
696
+ return
697
+
698
+ with open(self._local_path(self.subset_lists_file), "r") as f:
699
+ subset_to_seq_frame = json.load(f)
700
+
701
+ frame_path_to_subset = {
702
+ path: subset
703
+ for subset, frames in subset_to_seq_frame.items()
704
+ for _, _, path in frames
705
+ }
706
+ # pyre-ignore[16]
707
+ for frame in self.frame_annots:
708
+ frame["subset"] = frame_path_to_subset.get(
709
+ frame["frame_annotation"].image.path, None
710
+ )
711
+ if frame["subset"] is None:
712
+ warnings.warn(
713
+ "Subset lists are given but don't include "
714
+ + frame["frame_annotation"].image.path
715
+ )
716
+
717
+ def _sort_frames(self) -> None:
718
+ # Sort frames to have them grouped by sequence, ordered by timestamp
719
+ # pyre-ignore[16]
720
+ self.frame_annots = sorted(
721
+ self.frame_annots,
722
+ key=lambda f: (
723
+ f["frame_annotation"].sequence_name,
724
+ f["frame_annotation"].frame_timestamp or 0,
725
+ ),
726
+ )
727
+
728
+ def _filter_db(self) -> None:
729
+ if self.remove_empty_masks:
730
+ logger.info("Removing images with empty masks.")
731
+ # pyre-ignore[16]
732
+ old_len = len(self.frame_annots)
733
+
734
+ msg = "remove_empty_masks needs every MaskAnnotation.mass to be set."
735
+
736
+ def positive_mass(frame_annot: types.FrameAnnotation) -> bool:
737
+ mask = frame_annot.mask
738
+ if mask is None:
739
+ return False
740
+ if mask.mass is None:
741
+ raise ValueError(msg)
742
+ return mask.mass > 1
743
+
744
+ self.frame_annots = [
745
+ frame
746
+ for frame in self.frame_annots
747
+ if positive_mass(frame["frame_annotation"])
748
+ ]
749
+ logger.info("... filtered %d -> %d" % (old_len, len(self.frame_annots)))
750
+
751
+ # this has to be called after joining with categories!!
752
+ subsets = self.subsets
753
+ if subsets:
754
+ if not self.subset_lists_file:
755
+ raise ValueError(
756
+ "Subset filter is on but subset_lists_file was not given"
757
+ )
758
+
759
+ logger.info(f"Limiting Co3D dataset to the '{subsets}' subsets.")
760
+
761
+ # keep only the frames that belong to the requested subsets
762
+ self.frame_annots = [
763
+ entry for entry in self.frame_annots if entry["subset"] in subsets
764
+ ]
765
+ if len(self.frame_annots) == 0:
766
+ raise ValueError(f"There are no frames in the '{subsets}' subsets!")
767
+
768
+ self._invalidate_indexes(filter_seq_annots=True)
769
+
770
+ if len(self.limit_category_to) > 0:
771
+ logger.info(f"Limiting dataset to categories: {self.limit_category_to}")
772
+ # pyre-ignore[16]
773
+ self.seq_annots = {
774
+ name: entry
775
+ for name, entry in self.seq_annots.items()
776
+ if entry.category in self.limit_category_to
777
+ }
778
+
779
+ # sequence filters
780
+ for prefix in ("pick", "exclude"):
781
+ orig_len = len(self.seq_annots)
782
+ attr = f"{prefix}_sequence"
783
+ arr = getattr(self, attr)
784
+ if len(arr) > 0:
785
+ logger.info(f"{attr}: {str(arr)}")
786
+ self.seq_annots = {
787
+ name: entry
788
+ for name, entry in self.seq_annots.items()
789
+ if (name in arr) == (prefix == "pick")
790
+ }
791
+ logger.info("... filtered %d -> %d" % (orig_len, len(self.seq_annots)))
792
+
793
+ if self.limit_sequences_to > 0:
794
+ self.seq_annots = dict(
795
+ islice(self.seq_annots.items(), self.limit_sequences_to)
796
+ )
797
+
798
+ # retain only frames from retained sequences
799
+ self.frame_annots = [
800
+ f
801
+ for f in self.frame_annots
802
+ if f["frame_annotation"].sequence_name in self.seq_annots
803
+ ]
804
+
805
+ self._invalidate_indexes()
806
+
807
+ if self.n_frames_per_sequence > 0:
808
+ logger.info(f"Taking max {self.n_frames_per_sequence} per sequence.")
809
+ keep_idx = []
810
+ # pyre-ignore[16]
811
+ for seq, seq_indices in self._seq_to_idx.items():
812
+ # infer the seed from the sequence name, this is reproducible
813
+ # and makes the selection differ for different sequences
814
+ seed = _seq_name_to_seed(seq) + self.seed
815
+ seq_idx_shuffled = random.Random(seed).sample(
816
+ sorted(seq_indices), len(seq_indices)
817
+ )
818
+ keep_idx.extend(seq_idx_shuffled[: self.n_frames_per_sequence])
819
+
820
+ logger.info(
821
+ "... filtered %d -> %d" % (len(self.frame_annots), len(keep_idx))
822
+ )
823
+ self.frame_annots = [self.frame_annots[i] for i in keep_idx]
824
+ self._invalidate_indexes(filter_seq_annots=False)
825
+ # sequences are not decimated, so self.seq_annots is valid
826
+
827
+ if self.limit_to > 0 and self.limit_to < len(self.frame_annots):
828
+ logger.info(
829
+ "limit_to: filtered %d -> %d" % (len(self.frame_annots), self.limit_to)
830
+ )
831
+ self.frame_annots = self.frame_annots[: self.limit_to]
832
+ self._invalidate_indexes(filter_seq_annots=True)
833
+
834
+ def _invalidate_indexes(self, filter_seq_annots: bool = False) -> None:
835
+ # update _seq_to_idx and filter seq_meta according to frame_annots change
836
+ # if filter_seq_annots, also updates seq_annots based on the changed _seq_to_idx
837
+ self._invalidate_seq_to_idx()
838
+
839
+ if filter_seq_annots:
840
+ # pyre-ignore[16]
841
+ self.seq_annots = {
842
+ k: v
843
+ for k, v in self.seq_annots.items()
844
+ # pyre-ignore[16]
845
+ if k in self._seq_to_idx
846
+ }
847
+
848
+ def _invalidate_seq_to_idx(self) -> None:
849
+ seq_to_idx = defaultdict(list)
850
+ # pyre-ignore[16]
851
+ for idx, entry in enumerate(self.frame_annots):
852
+ seq_to_idx[entry["frame_annotation"].sequence_name].append(idx)
853
+ # pyre-ignore[16]
854
+ self._seq_to_idx = seq_to_idx
855
+
856
+ def _resize_image(
857
+ self, image, mode="bilinear"
858
+ ) -> Tuple[torch.Tensor, float, torch.Tensor]:
859
+ image_height, image_width = self.image_height, self.image_width
860
+ if image_height is None or image_width is None:
861
+ # skip the resizing
862
+ imre_ = torch.from_numpy(image)
863
+ return imre_, 1.0, torch.ones_like(imre_[:1])
864
+ # takes numpy array, returns pytorch tensor
865
+ minscale = min(
866
+ image_height / image.shape[-2],
867
+ image_width / image.shape[-1],
868
+ )
869
+ imre = torch.nn.functional.interpolate(
870
+ torch.from_numpy(image)[None],
871
+ scale_factor=minscale,
872
+ mode=mode,
873
+ align_corners=False if mode == "bilinear" else None,
874
+ recompute_scale_factor=True,
875
+ )[0]
876
+ # pyre-fixme[19]: Expected 1 positional argument.
877
+ imre_ = torch.zeros(image.shape[0], self.image_height, self.image_width)
878
+ imre_[:, 0 : imre.shape[1], 0 : imre.shape[2]] = imre
879
+ # pyre-fixme[6]: For 2nd param expected `int` but got `Optional[int]`.
880
+ # pyre-fixme[6]: For 3rd param expected `int` but got `Optional[int]`.
881
+ mask = torch.zeros(1, self.image_height, self.image_width)
882
+ mask[:, 0 : imre.shape[1], 0 : imre.shape[2]] = 1.0
883
+ return imre_, minscale, mask
884
+
885
+ def _local_path(self, path: str) -> str:
886
+ if self.path_manager is None:
887
+ return path
888
+ return self.path_manager.get_local_path(path)
889
+
890
+ def get_frame_numbers_and_timestamps(
891
+ self, idxs: Sequence[int]
892
+ ) -> List[Tuple[int, float]]:
893
+ out: List[Tuple[int, float]] = []
894
+ for idx in idxs:
895
+ # pyre-ignore[16]
896
+ frame_annotation = self.frame_annots[idx]["frame_annotation"]
897
+ out.append(
898
+ (frame_annotation.frame_number, frame_annotation.frame_timestamp)
899
+ )
900
+ return out
901
+
902
+ def category_to_sequence_names(self) -> Dict[str, List[str]]:
903
+ c2seq = defaultdict(list)
904
+ # pyre-ignore
905
+ for sequence_name, sa in self.seq_annots.items():
906
+ c2seq[sa.category].append(sequence_name)
907
+ return dict(c2seq)
908
+
909
+ def get_eval_batches(self) -> Optional[List[List[int]]]:
910
+ return self.eval_batches
911
+
912
+
913
+ def _seq_name_to_seed(seq_name) -> int:
914
+ return int(hashlib.sha1(seq_name.encode("utf-8")).hexdigest(), 16)
915
+
916
+
917
+ def _load_image(path) -> np.ndarray:
918
+ with Image.open(path) as pil_im:
919
+ im = np.array(pil_im.convert("RGB"))
920
+ im = im.transpose((2, 0, 1))
921
+ im = im.astype(np.float32) / 255.0
922
+ return im
923
+
924
+
925
+ def _load_16big_png_depth(depth_png) -> np.ndarray:
926
+ with Image.open(depth_png) as depth_pil:
927
+ # the image is stored with 16-bit depth but PIL reads it as I (32 bit).
928
+ # we cast it to uint16, then reinterpret as float16, then cast to float32
929
+ depth = (
930
+ np.frombuffer(np.array(depth_pil, dtype=np.uint16), dtype=np.float16)
931
+ .astype(np.float32)
932
+ .reshape((depth_pil.size[1], depth_pil.size[0]))
933
+ )
934
+ return depth
935
+
936
+
937
+ def _load_1bit_png_mask(file: str) -> np.ndarray:
938
+ with Image.open(file) as pil_im:
939
+ mask = (np.array(pil_im.convert("L")) > 0.0).astype(np.float32)
940
+ return mask
941
+
942
+
943
+ def _load_depth_mask(path: str) -> np.ndarray:
944
+ if not path.lower().endswith(".png"):
945
+ raise ValueError('unsupported depth mask file name "%s"' % path)
946
+ m = _load_1bit_png_mask(path)
947
+ return m[None] # fake feature channel
948
+
949
+
950
+ def _load_depth(path, scale_adjustment) -> np.ndarray:
951
+ if not path.lower().endswith(".png"):
952
+ raise ValueError('unsupported depth file name "%s"' % path)
953
+
954
+ d = _load_16big_png_depth(path) * scale_adjustment
955
+ d[~np.isfinite(d)] = 0.0
956
+ return d[None] # fake feature channel
957
+
958
+
959
+ def _load_mask(path) -> np.ndarray:
960
+ with Image.open(path) as pil_im:
961
+ mask = np.array(pil_im)
962
+ mask = mask.astype(np.float32) / 255.0
963
+ return mask[None] # fake feature channel
964
+
965
+
966
+ def _get_1d_bounds(arr) -> Tuple[int, int]:
967
+ nz = np.flatnonzero(arr)
968
+ return nz[0], nz[-1] + 1
969
+
970
+
971
+ def _get_bbox_from_mask(
972
+ mask, thr, decrease_quant: float = 0.05
973
+ ) -> Tuple[int, int, int, int]:
974
+ # bbox in xywh
975
+ masks_for_box = np.zeros_like(mask)
976
+ while masks_for_box.sum() <= 1.0:
977
+ masks_for_box = (mask > thr).astype(np.float32)
978
+ thr -= decrease_quant
979
+ if thr <= 0.0:
980
+ warnings.warn(f"Empty masks_for_bbox (thr={thr}) => using full image.")
981
+
982
+ x0, x1 = _get_1d_bounds(masks_for_box.sum(axis=-2))
983
+ y0, y1 = _get_1d_bounds(masks_for_box.sum(axis=-1))
984
+
985
+ return x0, y0, x1 - x0, y1 - y0
986
+
987
+
988
+ def _get_clamp_bbox(
989
+ bbox: torch.Tensor,
990
+ box_crop_context: float = 0.0,
991
+ image_path: str = "",
992
+ ) -> torch.Tensor:
993
+ # box_crop_context: rate of expansion for bbox
994
+ # returns possibly expanded bbox xyxy as float
995
+
996
+ bbox = bbox.clone() # do not edit bbox in place
997
+
998
+ # increase box size
999
+ if box_crop_context > 0.0:
1000
+ c = box_crop_context
1001
+ bbox = bbox.float()
1002
+ bbox[0] -= bbox[2] * c / 2
1003
+ bbox[1] -= bbox[3] * c / 2
1004
+ bbox[2] += bbox[2] * c
1005
+ bbox[3] += bbox[3] * c
1006
+
1007
+ if (bbox[2:] <= 1.0).any():
1008
+ raise ValueError(
1009
+ f"squashed image {image_path}!! The bounding box contains no pixels."
1010
+ )
1011
+
1012
+ bbox[2:] = torch.clamp(bbox[2:], 2) # set min height, width to 2 along both axes
1013
+ bbox_xyxy = _bbox_xywh_to_xyxy(bbox, clamp_size=2)
1014
+
1015
+ return bbox_xyxy
1016
+
1017
+
1018
+ def _crop_around_box(tensor, bbox, impath: str = ""):
1019
+ # bbox is xyxy, where the upper bound is corrected with +1
1020
+ bbox = _clamp_box_to_image_bounds_and_round(
1021
+ bbox,
1022
+ image_size_hw=tensor.shape[-2:],
1023
+ )
1024
+ tensor = tensor[..., bbox[1] : bbox[3], bbox[0] : bbox[2]]
1025
+ assert all(c > 0 for c in tensor.shape), f"squashed image {impath}"
1026
+ return tensor
1027
+
1028
+
1029
+ def _clamp_box_to_image_bounds_and_round(
1030
+ bbox_xyxy: torch.Tensor,
1031
+ image_size_hw: Tuple[int, int],
1032
+ ) -> torch.LongTensor:
1033
+ bbox_xyxy = bbox_xyxy.clone()
1034
+ bbox_xyxy[[0, 2]] = torch.clamp(bbox_xyxy[[0, 2]], 0, image_size_hw[-1])
1035
+ bbox_xyxy[[1, 3]] = torch.clamp(bbox_xyxy[[1, 3]], 0, image_size_hw[-2])
1036
+ if not isinstance(bbox_xyxy, torch.LongTensor):
1037
+ bbox_xyxy = bbox_xyxy.round().long()
1038
+ return bbox_xyxy # pyre-ignore [7]
1039
+
1040
+
1041
+ def _rescale_bbox(bbox: torch.Tensor, orig_res, new_res) -> torch.Tensor:
1042
+ assert bbox is not None
1043
+ assert np.prod(orig_res) > 1e-8
1044
+ # average ratio of dimensions
1045
+ rel_size = (new_res[0] / orig_res[0] + new_res[1] / orig_res[1]) / 2.0
1046
+ return bbox * rel_size
1047
+
1048
+
1049
+ def _bbox_xyxy_to_xywh(xyxy: torch.Tensor) -> torch.Tensor:
1050
+ wh = xyxy[2:] - xyxy[:2]
1051
+ xywh = torch.cat([xyxy[:2], wh])
1052
+ return xywh
1053
+
1054
+
1055
+ def _bbox_xywh_to_xyxy(
1056
+ xywh: torch.Tensor, clamp_size: Optional[int] = None
1057
+ ) -> torch.Tensor:
1058
+ xyxy = xywh.clone()
1059
+ if clamp_size is not None:
1060
+ xyxy[2:] = torch.clamp(xyxy[2:], clamp_size)
1061
+ xyxy[2:] += xyxy[:2]
1062
+ return xyxy
1063
+
1064
+
1065
+ def _safe_as_tensor(data, dtype):
1066
+ if data is None:
1067
+ return None
1068
+ return torch.tensor(data, dtype=dtype)
1069
+
1070
+
1071
+ # NOTE this cache is per-worker; they are implemented as processes.
1072
+ # each batch is loaded and collated by a single worker;
1073
+ # since sequences tend to co-occur within batches, this is useful.
1074
+ @functools.lru_cache(maxsize=256)
1075
+ def _load_pointcloud(pcl_path: Union[str, Path], max_points: int = 0) -> Pointclouds:
1076
+ pcl = IO().load_pointcloud(pcl_path)
1077
+ if max_points > 0:
1078
+ pcl = pcl.subsample(max_points)
1079
+
1080
+ return pcl
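A minimal usage sketch for the dataset above (a sketch only: the paths, category folder, and subset name are placeholders, not part of this commit; the import path assumes this repo's layout, and depending on the pytorch3d version, expand_args_fields is needed to materialize the declared fields into a constructor):

from pytorch3d.implicitron.tools.config import expand_args_fields
from sgm.data.json_index_dataset import JsonIndexDataset

expand_args_fields(JsonIndexDataset)  # generate an __init__ from the declared fields
category_root = "/data/co3d/teddybear"  # hypothetical CO3D-style category folder
dataset = JsonIndexDataset(
    frame_annotations_file=f"{category_root}/frame_annotations.jgz",
    sequence_annotations_file=f"{category_root}/sequence_annotations.jgz",
    subset_lists_file=f"{category_root}/set_lists.json",
    subsets=["train_known"],
    dataset_root="/data/co3d",
    image_height=256,
    image_width=256,
    n_frames_per_sequence=24,
)
frame = dataset[0]  # FrameData with image_rgb, fg_probability, camera, depth_map, ...
print(len(dataset), frame.sequence_name, frame.image_rgb.shape)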
sgm/data/latent_objaverse.py ADDED
@@ -0,0 +1,52 @@
1
+ import numpy as np
2
+ from pathlib import Path
3
+ from PIL import Image
4
+ import json
5
+ import torch
6
+ from torch.utils.data import Dataset, DataLoader, default_collate
7
+ from torchvision.transforms import ToTensor, Normalize, Compose, Resize
8
+ from pytorch_lightning import LightningDataModule
9
+ from einops import rearrange
10
+
11
+
12
+ class LatentObjaverseSpiral(Dataset):
13
+ def __init__(
14
+ self,
15
+ root_dir,
16
+ split="train",
17
+ transform=None,
18
+ random_front=False,
19
+ max_item=None,
20
+ cond_aug_mean=-3.0,
21
+ cond_aug_std=0.5,
22
+ condition_on_elevation=False,
23
+ **unused_kwargs,
24
+ ):
25
+ print("Using LVIS subset with precomputed Latents")
26
+ self.root_dir = Path(root_dir)
27
+ self.split = split
28
+ self.random_front = random_front
29
+ self.transform = transform
30
+
31
+ self.latent_dir = Path("/mnt/vepfs/3Ddataset/render_results/latents512")
32
+
33
+ self.ids = json.load(open("./assets/lvis_uids.json", "r"))
34
+ self.n_views = 18
35
+ valid_ids = []
36
+ for idx in self.ids:
37
+ if (self.root_dir / idx).exists():
38
+ valid_ids.append(idx)
39
+ self.ids = valid_ids
40
+ print("=" * 30)
41
+ print("Number of valid ids: ", len(self.ids))
42
+ print("=" * 30)
43
+
44
+ self.cond_aug_mean = cond_aug_mean
45
+ self.cond_aug_std = cond_aug_std
46
+ self.condition_on_elevation = condition_on_elevation
47
+
48
+ if max_item is not None:
49
+ self.ids = self.ids[:max_item]
50
+
51
+ ## debug
52
+ self.ids = self.ids * 10000
sgm/data/mnist.py ADDED
@@ -0,0 +1,85 @@
1
+ import pytorch_lightning as pl
2
+ import torchvision
3
+ from torch.utils.data import DataLoader, Dataset
4
+ from torchvision import transforms
5
+
6
+
7
+ class MNISTDataDictWrapper(Dataset):
8
+ def __init__(self, dset):
9
+ super().__init__()
10
+ self.dset = dset
11
+
12
+ def __getitem__(self, i):
13
+ x, y = self.dset[i]
14
+ return {"jpg": x, "cls": y}
15
+
16
+ def __len__(self):
17
+ return len(self.dset)
18
+
19
+
20
+ class MNISTLoader(pl.LightningDataModule):
21
+ def __init__(self, batch_size, num_workers=0, prefetch_factor=2, shuffle=True):
22
+ super().__init__()
23
+
24
+ transform = transforms.Compose(
25
+ [transforms.ToTensor(), transforms.Lambda(lambda x: x * 2.0 - 1.0)]
26
+ )
27
+
28
+ self.batch_size = batch_size
29
+ self.num_workers = num_workers
30
+ self.prefetch_factor = prefetch_factor if num_workers > 0 else 0
31
+ self.shuffle = shuffle
32
+ self.train_dataset = MNISTDataDictWrapper(
33
+ torchvision.datasets.MNIST(
34
+ root=".data/", train=True, download=True, transform=transform
35
+ )
36
+ )
37
+ self.test_dataset = MNISTDataDictWrapper(
38
+ torchvision.datasets.MNIST(
39
+ root=".data/", train=False, download=True, transform=transform
40
+ )
41
+ )
42
+
43
+ def prepare_data(self):
44
+ pass
45
+
46
+ def train_dataloader(self):
47
+ return DataLoader(
48
+ self.train_dataset,
49
+ batch_size=self.batch_size,
50
+ shuffle=self.shuffle,
51
+ num_workers=self.num_workers,
52
+ prefetch_factor=self.prefetch_factor,
53
+ )
54
+
55
+ def test_dataloader(self):
56
+ return DataLoader(
57
+ self.test_dataset,
58
+ batch_size=self.batch_size,
59
+ shuffle=self.shuffle,
60
+ num_workers=self.num_workers,
61
+ prefetch_factor=self.prefetch_factor,
62
+ )
63
+
64
+ def val_dataloader(self):
65
+ return DataLoader(
66
+ self.test_dataset,
67
+ batch_size=self.batch_size,
68
+ shuffle=self.shuffle,
69
+ num_workers=self.num_workers,
70
+ prefetch_factor=self.prefetch_factor,
71
+ )
72
+
73
+
74
+ if __name__ == "__main__":
75
+ dset = MNISTDataDictWrapper(
76
+ torchvision.datasets.MNIST(
77
+ root=".data/",
78
+ train=False,
79
+ download=True,
80
+ transform=transforms.Compose(
81
+ [transforms.ToTensor(), transforms.Lambda(lambda x: x * 2.0 - 1.0)]
82
+ ),
83
+ )
84
+ )
85
+ ex = dset[0]
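A short usage sketch for the loader above (batch size and worker count are illustrative; the import path assumes this repo's layout):

from sgm.data.mnist import MNISTLoader

loader = MNISTLoader(batch_size=64, num_workers=4, shuffle=True)
batch = next(iter(loader.train_dataloader()))
# batch["jpg"]: (64, 1, 28, 28) images rescaled to [-1, 1]; batch["cls"]: digit labels
print(batch["jpg"].shape, batch["jpg"].min().item(), batch["jpg"].max().item())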
sgm/data/mvimagenet.py ADDED
@@ -0,0 +1,408 @@
1
+ import numpy as np
2
+ import torch
3
+ from torch.utils.data import Dataset, DataLoader, default_collate
4
+ from pathlib import Path
5
+ from PIL import Image
6
+ from scipy.spatial.transform import Rotation
7
+ import rembg
8
+ from rembg import remove, new_session
9
+ from einops import rearrange
10
+
11
+ from torchvision.transforms import ToTensor, Normalize, Compose, Resize
12
+ from torchvision.transforms.functional import to_tensor
13
+ from pytorch_lightning import LightningDataModule
14
+
15
+ from sgm.data.colmap import read_cameras_binary, read_images_binary
16
+ from sgm.data.objaverse import video_collate_fn, FLATTEN_FIELDS, flatten_for_video
17
+
18
+
19
+ def qvec2rotmat(qvec):
20
+ return np.array(
21
+ [
22
+ [
23
+ 1 - 2 * qvec[2] ** 2 - 2 * qvec[3] ** 2,
24
+ 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3],
25
+ 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2],
26
+ ],
27
+ [
28
+ 2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3],
29
+ 1 - 2 * qvec[1] ** 2 - 2 * qvec[3] ** 2,
30
+ 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1],
31
+ ],
32
+ [
33
+ 2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2],
34
+ 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1],
35
+ 1 - 2 * qvec[1] ** 2 - 2 * qvec[2] ** 2,
36
+ ],
37
+ ]
38
+ )
39
+
40
+
41
+ def qt2c2w(q, t):
42
+ # NOTE: remember to convert to opengl coordinate system
43
+ # rot = Rotation.from_quat(q).as_matrix()
44
+ rot = qvec2rotmat(q)
45
+ c2w = np.eye(4)
46
+ c2w[:3, :3] = np.transpose(rot)
47
+ c2w[:3, 3] = -np.transpose(rot) @ t
48
+ c2w[..., 1:3] *= -1
49
+ return c2w
50
+
51
+
52
+ def random_crop():
53
+ pass
54
+
55
+
56
+ class MVImageNet(Dataset):
57
+ def __init__(
58
+ self,
59
+ root_dir,
60
+ split,
61
+ transform,
62
+ reso: int = 256,
63
+ mask_type: str = "random",
64
+ cond_aug_mean=-3.0,
65
+ cond_aug_std=0.5,
66
+ condition_on_elevation=False,
67
+ fps_id=0.0,
68
+ motion_bucket_id=300.0,
69
+ num_frames: int = 24,
70
+ use_mask: bool = True,
71
+ load_pixelnerf: bool = False,
72
+ scale_pose: bool = False,
73
+ max_n_cond: int = 1,
74
+ min_n_cond: int = 1,
75
+ cond_on_multi: bool = False,
76
+ ) -> None:
77
+ super().__init__()
78
+
79
+ self.root_dir = Path(root_dir)
80
+ self.split = split
81
+
82
+ avails = self.root_dir.glob("*/*")
83
+ self.ids = list(
84
+ map(
85
+ lambda x: str(x.relative_to(self.root_dir)),
86
+ filter(lambda x: x.is_dir(), avails),
87
+ )
88
+ )
89
+
90
+ self.transform = transform
91
+ self.reso = reso
92
+ self.num_frames = num_frames
93
+ self.cond_aug_mean = cond_aug_mean
94
+ self.cond_aug_std = cond_aug_std
95
+ self.condition_on_elevation = condition_on_elevation
96
+ self.fps_id = fps_id
97
+ self.motion_bucket_id = motion_bucket_id
98
+ self.mask_type = mask_type
99
+ self.use_mask = use_mask
100
+ self.load_pixelnerf = load_pixelnerf
101
+ self.scale_pose = scale_pose
102
+ self.max_n_cond = max_n_cond
103
+ self.min_n_cond = min_n_cond
104
+ self.cond_on_multi = cond_on_multi
105
+
106
+ if self.cond_on_multi:
107
+ assert self.min_n_cond == self.max_n_cond
108
+ self.session = new_session()
109
+
110
+ def __getitem__(self, index: int):
111
+ # mvimgnet starts with idx==1
112
+ idx_list = np.arange(0, self.num_frames)
113
+ this_image_dir = self.root_dir / self.ids[index] / "images"
114
+ this_camera_dir = self.root_dir / self.ids[index] / "sparse/0"
115
+
116
+ # while not this_camera_dir.exists():
117
+ # index = (index + 1) % len(self.ids)
118
+ # this_image_dir = self.root_dir / self.ids[index] / "images"
119
+ # this_camera_dir = self.root_dir / self.ids[index] / "sparse/0"
120
+ if not this_camera_dir.exists():
121
+ index = 0
122
+ this_image_dir = self.root_dir / self.ids[index] / "images"
123
+ this_camera_dir = self.root_dir / self.ids[index] / "sparse/0"
124
+
125
+ this_images = read_images_binary(this_camera_dir / "images.bin")
126
+ # filenames = list(map(lambda x: f"{x:03d}", this_images.keys()))
127
+ filenames = list(this_images.keys())
128
+
129
+ if len(filenames) == 0:
130
+ index = 0
131
+ this_image_dir = self.root_dir / self.ids[index] / "images"
132
+ this_camera_dir = self.root_dir / self.ids[index] / "sparse/0"
133
+ this_images = read_images_binary(this_camera_dir / "images.bin")
134
+ # filenames = list(map(lambda x: f"{x:03d}", this_images.keys()))
135
+ filenames = list(this_images.keys())
136
+
137
+ filenames = list(
138
+ filter(lambda x: (this_image_dir / this_images[x].name).exists(), filenames)
139
+ )
140
+
141
+ filenames = sorted(filenames, key=lambda x: this_images[x].name)
142
+
143
+ # # debug
144
+ # names = []
145
+ # for v in filenames:
146
+ # names.append(this_images[v].name)
147
+ # breakpoint()
148
+
149
+ while len(filenames) < self.num_frames:
150
+ num_surpass = self.num_frames - len(filenames)
151
+ filenames += list(reversed(filenames[-num_surpass:]))
152
+
153
+ if len(filenames) < self.num_frames:
154
+ print(f"\n\n{self.ids[index]}\n\n")
155
+
156
+ frames = []
157
+ cameras = []
158
+ downsampled_rgb = []
159
+ for view_idx in idx_list:
160
+ this_id = filenames[view_idx]
161
+ frame = Image.open(this_image_dir / this_images[this_id].name)
162
+ w, h = frame.size
163
+
164
+ if self.mask_type == "random":
165
+ image_size = min(h, w)
166
+ left = np.random.randint(0, w - image_size + 1)
167
+ right = left + image_size
168
+ top = np.random.randint(0, h - image_size + 1)
169
+ bottom = top + image_size
170
+ ## need to assign left, right, top, bottom, image_size
171
+ elif self.mask_type == "object":
172
+ pass
173
+ elif self.mask_type == "rembg":
174
+ image_size = min(h, w)
175
+ if (
176
+ cached := this_image_dir
177
+ / f"{this_images[this_id].name[:-4]}_rembg.png"
178
+ ).exists():
179
+ try:
180
+ mask = np.asarray(Image.open(cached, formats=["png"]))[..., 3]
181
+ except:
182
+ mask = remove(frame, session=self.session)
183
+ mask.save(cached)
184
+ mask = np.asarray(mask)[..., 3]
185
+ else:
186
+ mask = remove(frame, session=self.session)
187
+ mask.save(cached)
188
+ mask = np.asarray(mask)[..., 3]
189
+ # in h,w order
190
+ y, x = np.array(mask.nonzero())
191
+ bbox_cx = x.mean()
192
+ bbox_cy = y.mean()
193
+
194
+ if bbox_cy - image_size / 2 < 0:
195
+ top = 0
196
+ elif bbox_cy + image_size / 2 > h:
197
+ top = h - image_size
198
+ else:
199
+ top = int(bbox_cy - image_size / 2)
200
+
201
+ if bbox_cx - image_size / 2 < 0:
202
+ left = 0
203
+ elif bbox_cx + image_size / 2 > w:
204
+ left = w - image_size
205
+ else:
206
+ left = int(bbox_cx - image_size / 2)
207
+
208
+ # top = max(int(bbox_cy - image_size / 2), 0)
209
+ # left = max(int(bbox_cx - image_size / 2), 0)
210
+ bottom = top + image_size
211
+ right = left + image_size
212
+ else:
213
+ raise ValueError(f"Unknown mask type: {self.mask_type}")
214
+
215
+ frame = frame.crop((left, top, right, bottom))
216
+ frame = frame.resize((self.reso, self.reso))
217
+ frames.append(self.transform(frame))
218
+
219
+ if self.load_pixelnerf:
220
+ # extrinsics
221
+ extrinsics = this_images[this_id]
222
+ c2w = qt2c2w(extrinsics.qvec, extrinsics.tvec)
223
+ # intrinsics
224
+ intrinsics = read_cameras_binary(this_camera_dir / "cameras.bin")
225
+ assert len(intrinsics) == 1
226
+ intrinsics = intrinsics[1]
227
+ f, cx, cy, _ = intrinsics.params
228
+ f *= 1 / image_size
229
+ cx -= left
230
+ cy -= top
231
+ cx *= 1 / image_size
232
+ cy *= 1 / image_size # all are relative values
233
+ intrinsics = np.array([[f, 0, cx], [0, f, cy], [0, 0, 1]])
234
+
235
+ this_camera = np.zeros(25)
236
+ this_camera[:16] = c2w.reshape(-1)
237
+ this_camera[16:] = intrinsics.reshape(-1)
238
+
239
+ cameras.append(this_camera)
240
+ downsampled = frame.resize((self.reso // 8, self.reso // 8))
241
+ downsampled_rgb.append((self.transform(downsampled) + 1.0) * 0.5)
242
+
243
+ data = dict()
244
+
245
+ cond_aug = np.exp(
246
+ np.random.randn(1)[0] * self.cond_aug_std + self.cond_aug_mean
247
+ )
248
+ frames = torch.stack(frames)
249
+ cond = frames[0]
250
+ # setting all things in data
251
+ data["frames"] = frames
252
+ data["cond_frames_without_noise"] = cond
253
+ data["cond_aug"] = torch.as_tensor([cond_aug] * self.num_frames)
254
+ data["cond_frames"] = cond + cond_aug * torch.randn_like(cond)
255
+ data["fps_id"] = torch.as_tensor([self.fps_id] * self.num_frames)
256
+ data["motion_bucket_id"] = torch.as_tensor(
257
+ [self.motion_bucket_id] * self.num_frames
258
+ )
259
+ data["num_video_frames"] = self.num_frames
260
+ data["image_only_indicator"] = torch.as_tensor([0.0] * self.num_frames)
261
+
262
+ if self.load_pixelnerf:
263
+ # TODO: normalize camera poses
264
+ data["pixelnerf_input"] = dict()
265
+ data["pixelnerf_input"]["frames"] = frames
266
+ data["pixelnerf_input"]["rgb"] = torch.stack(downsampled_rgb)
267
+
268
+ cameras = torch.from_numpy(np.stack(cameras)).float()
269
+ if self.scale_pose:
270
+ c2ws = cameras[..., :16].reshape(-1, 4, 4)
271
+ center = c2ws[:, :3, 3].mean(0)
272
+ radius = (c2ws[:, :3, 3] - center).norm(dim=-1).max()
273
+ scale = 1.5 / radius
274
+ c2ws[..., :3, 3] = (c2ws[..., :3, 3] - center) * scale
275
+ cameras[..., :16] = c2ws.reshape(-1, 16)
276
+
277
+ # if self.max_n_cond > 1:
278
+ # # TODO implement this
279
+ # n_cond = np.random.randint(1, self.max_n_cond + 1)
280
+ # # debug
281
+ # source_index = [0]
282
+ # if n_cond > 1:
283
+ # source_index += np.random.choice(
284
+ # np.arange(1, self.num_frames),
285
+ # self.max_n_cond - 1,
286
+ # replace=False,
287
+ # ).tolist()
288
+ # data["pixelnerf_input"]["source_index"] = torch.as_tensor(
289
+ # source_index
290
+ # )
291
+ # data["pixelnerf_input"]["n_cond"] = n_cond
292
+ # data["pixelnerf_input"]["source_images"] = frames[source_index]
293
+ # data["pixelnerf_input"]["source_cameras"] = cameras[source_index]
294
+
295
+ data["pixelnerf_input"]["cameras"] = cameras
296
+
297
+ return data
298
+
299
+ def __len__(self):
300
+ return len(self.ids)
301
+
302
+ def collate_fn(self, batch):
303
+ # a hack to add source index and keep consistent within a batch
304
+ if self.max_n_cond > 1:
305
+ # TODO implement this
306
+ n_cond = np.random.randint(self.min_n_cond, self.max_n_cond + 1)
307
+ # debug
308
+ # source_index = [0]
309
+ if n_cond > 1:
310
+ for b in batch:
311
+ source_index = [0] + np.random.choice(
312
+ np.arange(1, self.num_frames),
313
+ self.max_n_cond - 1,
314
+ replace=False,
315
+ ).tolist()
316
+ b["pixelnerf_input"]["source_index"] = torch.as_tensor(source_index)
317
+ b["pixelnerf_input"]["n_cond"] = n_cond
318
+ b["pixelnerf_input"]["source_images"] = b["frames"][source_index]
319
+ b["pixelnerf_input"]["source_cameras"] = b["pixelnerf_input"][
320
+ "cameras"
321
+ ][source_index]
322
+
323
+ if self.cond_on_multi:
324
+ b["cond_frames_without_noise"] = b["frames"][source_index]
325
+
326
+ ret = video_collate_fn(batch)
327
+
328
+ if self.cond_on_multi:
329
+ ret["cond_frames_without_noise"] = rearrange(ret["cond_frames_without_noise"], "b t ... -> (b t) ...")
330
+
331
+ return ret
332
+
333
+
334
+ class MVImageNetFixedCond(MVImageNet):
335
+ def __init__(self, *args, **kwargs):
336
+ super().__init__(*args, **kwargs)
337
+
338
+
339
+ class MVImageNetDataset(LightningDataModule):
340
+ def __init__(
341
+ self,
342
+ root_dir,
343
+ batch_size=2,
344
+ shuffle=True,
345
+ num_workers=10,
346
+ prefetch_factor=2,
347
+ **kwargs,
348
+ ):
349
+ super().__init__()
350
+
351
+ self.batch_size = batch_size
352
+ self.num_workers = num_workers
353
+ self.prefetch_factor = prefetch_factor
354
+ self.shuffle = shuffle
355
+
356
+ self.transform = Compose(
357
+ [
358
+ ToTensor(),
359
+ Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
360
+ ]
361
+ )
362
+
363
+ self.train_dataset = MVImageNet(
364
+ root_dir=root_dir,
365
+ split="train",
366
+ transform=self.transform,
367
+ **kwargs,
368
+ )
369
+
370
+ self.test_dataset = MVImageNet(
371
+ root_dir=root_dir,
372
+ split="test",
373
+ transform=self.transform,
374
+ **kwargs,
375
+ )
376
+
377
+ def train_dataloader(self):
378
+ def worker_init_fn(worker_id):
379
+ np.random.seed(np.random.get_state()[1][0])
380
+
381
+ return DataLoader(
382
+ self.train_dataset,
383
+ batch_size=self.batch_size,
384
+ shuffle=self.shuffle,
385
+ num_workers=self.num_workers,
386
+ prefetch_factor=self.prefetch_factor,
387
+ collate_fn=self.train_dataset.collate_fn,
388
+ )
389
+
390
+ def test_dataloader(self):
391
+ return DataLoader(
392
+ self.test_dataset,
393
+ batch_size=self.batch_size,
394
+ shuffle=self.shuffle,
395
+ num_workers=self.num_workers,
396
+ prefetch_factor=self.prefetch_factor,
397
+ collate_fn=self.test_dataset.collate_fn,
398
+ )
399
+
400
+ def val_dataloader(self):
401
+ return DataLoader(
402
+ self.test_dataset,
403
+ batch_size=self.batch_size,
404
+ shuffle=self.shuffle,
405
+ num_workers=self.num_workers,
406
+ prefetch_factor=self.prefetch_factor,
407
+ collate_fn=video_collate_fn,
408
+ )
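A small sanity check for the COLMAP-to-OpenGL conversion defined above: for the identity quaternion and zero translation, qt2c2w should return a camera-to-world matrix whose rotation is the identity with the y and z columns flipped and whose translation stays at the origin (the import path assumes this repo's layout):

import numpy as np
from sgm.data.mvimagenet import qt2c2w

q = np.array([1.0, 0.0, 0.0, 0.0])  # COLMAP qvec order: w, x, y, z
t = np.zeros(3)
c2w = qt2c2w(q, t)
assert np.allclose(c2w[:3, 3], 0.0)
assert np.allclose(c2w[:3, :3], np.diag([1.0, -1.0, -1.0]))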
sgm/data/objaverse.py ADDED
@@ -0,0 +1,882 @@
1
+ import numpy as np
2
+ from pathlib import Path
3
+ from PIL import Image
4
+ import json
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from torch.utils.data import Dataset, DataLoader, default_collate
8
+ from torchvision.transforms import ToTensor, Normalize, Compose, Resize
9
+ from torchvision.transforms.functional import to_tensor
10
+ from pytorch_lightning import LightningDataModule
11
+ from einops import rearrange
12
+
13
+
14
+ def read_camera_matrix_single(json_file):
15
+ # for gobjaverse
16
+ with open(json_file, "r", encoding="utf8") as reader:
17
+ json_content = json.load(reader)
18
+
19
+ # negative sign for opencv to opengl
20
+ camera_matrix = torch.zeros(3, 4)
21
+ camera_matrix[:3, 0] = torch.tensor(json_content["x"])
22
+ camera_matrix[:3, 1] = -torch.tensor(json_content["y"])
23
+ camera_matrix[:3, 2] = -torch.tensor(json_content["z"])
24
+ camera_matrix[:3, 3] = torch.tensor(json_content["origin"])
25
+ """
26
+ camera_matrix = np.eye(4)
27
+ camera_matrix[:3, 0] = np.array(json_content['x'])
28
+ camera_matrix[:3, 1] = np.array(json_content['y'])
29
+ camera_matrix[:3, 2] = np.array(json_content['z'])
30
+ camera_matrix[:3, 3] = np.array(json_content['origin'])
31
+ # print(camera_matrix)
32
+ """
33
+
34
+ return camera_matrix
35
+
36
+
37
+ def read_camera_instrinsics_single(json_file, h: int, w: int, scale: float = 1.0):
38
+ with open(json_file, "r", encoding="utf8") as reader:
39
+ json_content = json.load(reader)
40
+
41
+ h = int(h * scale)
42
+ w = int(w * scale)
43
+
44
+ y_fov = json_content["y_fov"]
45
+ x_fov = json_content["x_fov"]
46
+
47
+ fy = h / 2 / np.tan(y_fov / 2)
48
+ fx = w / 2 / np.tan(x_fov / 2)
49
+
50
+ cx = w // 2
51
+ cy = h // 2
52
+
53
+ intrinsics = torch.tensor(
54
+ [
55
+ [fx, fy],
56
+ [cx, cy],
57
+ [w, h],
58
+ ],
59
+ dtype=torch.float32,
60
+ )
61
+ return intrinsics
62
+
63
+
64
+ def compose_extrinsic_RT(RT: torch.Tensor):
65
+ """
66
+ Compose the standard form extrinsic matrix from RT.
67
+ Batched I/O.
68
+ """
69
+ return torch.cat(
70
+ [
71
+ RT,
72
+ torch.tensor([[[0, 0, 0, 1]]], dtype=torch.float32).repeat(
73
+ RT.shape[0], 1, 1
74
+ ),
75
+ ],
76
+ dim=1,
77
+ )
78
+
79
+
80
+ def get_normalized_camera_intrinsics(intrinsics: torch.Tensor):
81
+ """
82
+ intrinsics: (N, 3, 2), [[fx, fy], [cx, cy], [width, height]]
83
+ Return batched fx, fy, cx, cy
84
+ """
85
+ fx, fy = intrinsics[:, 0, 0], intrinsics[:, 0, 1]
86
+ cx, cy = intrinsics[:, 1, 0], intrinsics[:, 1, 1]
87
+ width, height = intrinsics[:, 2, 0], intrinsics[:, 2, 1]
88
+ fx, fy = fx / width, fy / height
89
+ cx, cy = cx / width, cy / height
90
+ return fx, fy, cx, cy
91
+
92
+
93
+ def build_camera_standard(RT: torch.Tensor, intrinsics: torch.Tensor):
94
+ """
95
+ RT: (N, 3, 4)
96
+ intrinsics: (N, 3, 2), [[fx, fy], [cx, cy], [width, height]]
97
+ """
98
+ E = compose_extrinsic_RT(RT)
99
+ fx, fy, cx, cy = get_normalized_camera_intrinsics(intrinsics)
100
+ I = torch.stack(
101
+ [
102
+ torch.stack([fx, torch.zeros_like(fx), cx], dim=-1),
103
+ torch.stack([torch.zeros_like(fy), fy, cy], dim=-1),
104
+ torch.tensor([[0, 0, 1]], dtype=torch.float32).repeat(RT.shape[0], 1),
105
+ ],
106
+ dim=1,
107
+ )
108
+ return torch.cat(
109
+ [
110
+ E.reshape(-1, 16),
111
+ I.reshape(-1, 9),
112
+ ],
113
+ dim=-1,
114
+ )
115
+
116
+
117
+ def calc_elevation(c2w):
118
+ ## works for single or batched c2w
119
+ ## assume world up is (0, 0, 1)
120
+ pos = c2w[..., :3, 3]
121
+
122
+ return np.arcsin(pos[..., 2] / np.linalg.norm(pos, axis=-1, keepdims=False))
123
+
124
+
125
+ def read_camera_matrix_single(json_file):
126
+ with open(json_file, "r", encoding="utf8") as reader:
127
+ json_content = json.load(reader)
128
+
129
+ # negative sign for opencv to opengl
130
+ # camera_matrix = np.zeros([3, 4])
131
+ # camera_matrix[:3, 0] = np.array(json_content["x"])
132
+ # camera_matrix[:3, 1] = -np.array(json_content["y"])
133
+ # camera_matrix[:3, 2] = -np.array(json_content["z"])
134
+ # camera_matrix[:3, 3] = np.array(json_content["origin"])
135
+ camera_matrix = torch.zeros([3, 4])
136
+ camera_matrix[:3, 0] = torch.tensor(json_content["x"])
137
+ camera_matrix[:3, 1] = -torch.tensor(json_content["y"])
138
+ camera_matrix[:3, 2] = -torch.tensor(json_content["z"])
139
+ camera_matrix[:3, 3] = torch.tensor(json_content["origin"])
140
+ """
141
+ camera_matrix = np.eye(4)
142
+ camera_matrix[:3, 0] = np.array(json_content['x'])
143
+ camera_matrix[:3, 1] = np.array(json_content['y'])
144
+ camera_matrix[:3, 2] = np.array(json_content['z'])
145
+ camera_matrix[:3, 3] = np.array(json_content['origin'])
146
+ # print(camera_matrix)
147
+ """
148
+
149
+ return camera_matrix
150
+
151
+
152
+ def blend_white_bg(image):
153
+ new_image = Image.new("RGB", image.size, (255, 255, 255))
154
+ new_image.paste(image, mask=image.split()[3])
155
+
156
+ return new_image
157
+
158
+
159
+ def flatten_for_video(input):
160
+ return input.flatten()
161
+
162
+
163
+ FLATTEN_FIELDS = ["fps_id", "motion_bucket_id", "cond_aug", "elevation"]
164
+
165
+
166
+ def video_collate_fn(batch: list[dict], *args, **kwargs):
167
+ out = {}
168
+ for key in batch[0].keys():
169
+ if key in FLATTEN_FIELDS:
170
+ out[key] = default_collate([item[key] for item in batch])
171
+ out[key] = flatten_for_video(out[key])
172
+ elif key == "num_video_frames":
173
+ out[key] = batch[0][key]
174
+ elif key in ["frames", "latents", "rgb"]:
175
+ out[key] = default_collate([item[key] for item in batch])
176
+ out[key] = rearrange(out[key], "b t c h w -> (b t) c h w")
177
+ else:
178
+ out[key] = default_collate([item[key] for item in batch])
179
+
180
+ if "pixelnerf_input" in out:
181
+ out["pixelnerf_input"]["rgb"] = rearrange(
182
+ out["pixelnerf_input"]["rgb"], "b t c h w -> (b t) c h w"
183
+ )
184
+
185
+ return out
186
+
187
+
188
+ class GObjaverse(Dataset):
189
+ def __init__(
190
+ self,
191
+ root_dir,
192
+ split="train",
193
+ transform=None,
194
+ random_front=False,
195
+ max_item=None,
196
+ cond_aug_mean=-3.0,
197
+ cond_aug_std=0.5,
198
+ condition_on_elevation=False,
199
+ fps_id=0.0,
200
+ motion_bucket_id=300.0,
201
+ use_latents=False,
202
+ load_caps=False,
203
+ front_view_selection="random",
204
+ load_pixelnerf=False,
205
+ debug_base_idx=None,
206
+ scale_pose: bool = False,
207
+ max_n_cond: int = 1,
208
+ **unused_kwargs,
209
+ ):
210
+ self.root_dir = Path(root_dir)
211
+ self.split = split
212
+ self.random_front = random_front
213
+ self.transform = transform
214
+ self.use_latents = use_latents
215
+
216
+ self.ids = json.load(open(self.root_dir / "valid_uids.json", "r"))
217
+ self.n_views = 24
218
+
219
+ self.load_caps = load_caps
220
+ if self.load_caps:
221
+ self.caps = json.load(open(self.root_dir / "text_captions_cap3d.json", "r"))
222
+
223
+ self.cond_aug_mean = cond_aug_mean
224
+ self.cond_aug_std = cond_aug_std
225
+ self.condition_on_elevation = condition_on_elevation
226
+ self.fps_id = fps_id
227
+ self.motion_bucket_id = motion_bucket_id
228
+ self.load_pixelnerf = load_pixelnerf
229
+ self.scale_pose = scale_pose
230
+ self.max_n_cond = max_n_cond
231
+
232
+ if self.use_latents:
233
+ self.latents_dir = self.root_dir / "latents256"
234
+ self.clip_dir = self.root_dir / "clip_emb256"
235
+
236
+ self.front_view_selection = front_view_selection
237
+ if self.front_view_selection == "random":
238
+ pass
239
+ elif self.front_view_selection == "fixed":
240
+ pass
241
+ elif self.front_view_selection.startswith("clip_score"):
242
+ self.clip_scores = torch.load(self.root_dir / "clip_score_per_view.pt")
243
+ self.ids = list(self.clip_scores.keys())
244
+ else:
245
+ raise ValueError(
246
+ f"Unknown front view selection method {self.front_view_selection}"
247
+ )
248
+
249
+ if max_item is not None:
250
+ self.ids = self.ids[:max_item]
251
+ ## debug
252
+ self.ids = self.ids * 10000
253
+
254
+ if debug_base_idx is not None:
255
+ print(f"debug mode with base idx: {debug_base_idx}")
256
+ self.debug_base_idx = debug_base_idx
257
+
258
+ def __getitem__(self, idx: int):
259
+ if hasattr(self, "debug_base_idx"):
260
+ idx = (idx + self.debug_base_idx) % len(self.ids)
261
+ data = {}
262
+ idx_list = np.arange(self.n_views)
263
+ # if self.random_front:
264
+ # roll_idx = np.random.randint(self.n_views)
265
+ # idx_list = np.roll(idx_list, roll_idx)
266
+ if self.front_view_selection == "random":
267
+ roll_idx = np.random.randint(self.n_views)
268
+ idx_list = np.roll(idx_list, roll_idx)
269
+ elif self.front_view_selection == "fixed":
270
+ pass
271
+ elif self.front_view_selection == "clip_score_softmax":
272
+ this_clip_score = (
273
+ F.softmax(self.clip_scores[self.ids[idx]], dim=-1).cpu().numpy()
274
+ )
275
+ roll_idx = np.random.choice(idx_list, p=this_clip_score)
276
+ idx_list = np.roll(idx_list, roll_idx)
277
+ elif self.front_view_selection == "clip_score_max":
278
+ this_clip_score = (
279
+ F.softmax(self.clip_scores[self.ids[idx]], dim=-1).cpu().numpy()
280
+ )
281
+ roll_idx = np.argmax(this_clip_score)
282
+ idx_list = np.roll(idx_list, roll_idx)
283
+ frames = []
284
+ if not self.use_latents:
285
+ try:
286
+ for view_idx in idx_list:
287
+ frame = Image.open(
288
+ self.root_dir
289
+ / "gobjaverse"
290
+ / self.ids[idx]
291
+ / f"{view_idx:05d}/{view_idx:05d}.png"
292
+ )
293
+ frames.append(self.transform(frame))
294
+ except Exception:
295
+ idx = 0
296
+ frames = []
297
+ for view_idx in idx_list:
298
+ frame = Image.open(
299
+ self.root_dir
300
+ / "gobjaverse"
301
+ / self.ids[idx]
302
+ / f"{view_idx:05d}/{view_idx:05d}.png"
303
+ )
304
+ frames.append(self.transform(frame))
305
+ # workaround for occasional broken items in gobjaverse:
306
+ # fall back to idx 0; the repeated item is resolved when gathering results, and the number of valid items can be checked via the length of the results
307
+ frames = torch.stack(frames, dim=0)
308
+ cond = frames[0]
309
+
310
+ cond_aug = np.exp(
311
+ np.random.randn(1)[0] * self.cond_aug_std + self.cond_aug_mean
312
+ )
313
+
314
+ data.update(
315
+ {
316
+ "frames": frames,
317
+ "cond_frames_without_noise": cond,
318
+ "cond_aug": torch.as_tensor([cond_aug] * self.n_views),
319
+ "cond_frames": cond + cond_aug * torch.randn_like(cond),
320
+ "fps_id": torch.as_tensor([self.fps_id] * self.n_views),
321
+ "motion_bucket_id": torch.as_tensor(
322
+ [self.motion_bucket_id] * self.n_views
323
+ ),
324
+ "num_video_frames": 24,
325
+ "image_only_indicator": torch.as_tensor([0.0] * self.n_views),
326
+ }
327
+ )
328
+ else:
329
+ latents = torch.load(self.latents_dir / f"{self.ids[idx]}.pt")[idx_list]
330
+ clip_emb = torch.load(self.clip_dir / f"{self.ids[idx]}.pt")[idx_list][0]
331
+
332
+ cond = latents[0]
333
+
334
+ cond_aug = np.exp(
335
+ np.random.randn(1)[0] * self.cond_aug_std + self.cond_aug_mean
336
+ )
337
+
338
+ data.update(
339
+ {
340
+ "latents": latents,
341
+ "cond_frames_without_noise": clip_emb,
342
+ "cond_aug": torch.as_tensor([cond_aug] * self.n_views),
343
+ "cond_frames": cond + cond_aug * torch.randn_like(cond),
344
+ "fps_id": torch.as_tensor([self.fps_id] * self.n_views),
345
+ "motion_bucket_id": torch.as_tensor(
346
+ [self.motion_bucket_id] * self.n_views
347
+ ),
348
+ "num_video_frames": 24,
349
+ "image_only_indicator": torch.as_tensor([0.0] * self.n_views),
350
+ }
351
+ )
352
+
353
+ if self.condition_on_elevation:
354
+ sample_c2w = read_camera_matrix_single(
355
+ self.root_dir / self.ids[idx] / f"00000/00000.json"
356
+ )
357
+ elevation = calc_elevation(sample_c2w)
358
+ data["elevation"] = torch.as_tensor([elevation] * self.n_views)
359
+
360
+ if self.load_pixelnerf:
361
+ assert "frames" in data, f"pixelnerf cannot work with latents only mode"
362
+ data["pixelnerf_input"] = {}
363
+ RTs = []
364
+ intrinsics = []
365
+ for view_idx in idx_list:
366
+ meta = (
367
+ self.root_dir
368
+ / "gobjaverse"
369
+ / self.ids[idx]
370
+ / f"{view_idx:05d}/{view_idx:05d}.json"
371
+ )
372
+ RTs.append(read_camera_matrix_single(meta)[:3])
373
+ intrinsics.append(read_camera_instrinsics_single(meta, 256, 256))
374
+ RTs = torch.stack(RTs, dim=0)
375
+ intrinsics = torch.stack(intrinsics, dim=0)
376
+ cameras = build_camera_standard(RTs, intrinsics)
377
+ data["pixelnerf_input"]["cameras"] = cameras
378
+
379
+ downsampled = []
380
+ for view_idx in idx_list:
381
+ frame = Image.open(
382
+ self.root_dir
383
+ / "gobjaverse"
384
+ / self.ids[idx]
385
+ / f"{view_idx:05d}/{view_idx:05d}.png"
386
+ ).resize((32, 32))
387
+ downsampled.append(to_tensor(blend_white_bg(frame)))
388
+ data["pixelnerf_input"]["rgb"] = torch.stack(downsampled, dim=0)
389
+ data["pixelnerf_input"]["frames"] = data["frames"]
390
+ if self.scale_pose:
391
+ c2ws = cameras[..., :16].reshape(-1, 4, 4)
392
+ center = c2ws[:, :3, 3].mean(0)
393
+ radius = (c2ws[:, :3, 3] - center).norm(dim=-1).max()
394
+ scale = 1.5 / radius
395
+ c2ws[..., :3, 3] = (c2ws[..., :3, 3] - center) * scale
396
+ cameras[..., :16] = c2ws.reshape(-1, 16)
397
+
398
+ if self.load_caps:
399
+ data["caption"] = self.caps[self.ids[idx]]
400
+ data["ids"] = self.ids[idx]
401
+
402
+ return data
403
+
404
+ def __len__(self):
405
+ return len(self.ids)
406
+
407
+ def collate_fn(self, batch):
408
+ if self.max_n_cond > 1:
409
+ n_cond = np.random.randint(1, self.max_n_cond + 1)
410
+ if n_cond > 1:
411
+ for b in batch:
412
+ source_index = [0] + np.random.choice(
413
+ np.arange(1, self.n_views),
414
+ self.max_n_cond - 1,
415
+ replace=False,
416
+ ).tolist()
417
+ b["pixelnerf_input"]["source_index"] = torch.as_tensor(source_index)
418
+ b["pixelnerf_input"]["n_cond"] = n_cond
419
+ b["pixelnerf_input"]["source_images"] = b["frames"][source_index]
420
+ b["pixelnerf_input"]["source_cameras"] = b["pixelnerf_input"][
421
+ "cameras"
422
+ ][source_index]
423
+
424
+ return video_collate_fn(batch)
425
+
426
+
427
+ class ObjaverseSpiral(Dataset):
428
+ def __init__(
429
+ self,
430
+ root_dir,
431
+ split="train",
432
+ transform=None,
433
+ random_front=False,
434
+ max_item=None,
435
+ cond_aug_mean=-3.0,
436
+ cond_aug_std=0.5,
437
+ condition_on_elevation=False,
438
+ **unused_kwargs,
439
+ ):
440
+ self.root_dir = Path(root_dir)
441
+ self.split = split
442
+ self.random_front = random_front
443
+ self.transform = transform
444
+
445
+ self.ids = json.load(open(self.root_dir / f"{split}_ids.json", "r"))
446
+ self.n_views = 24
447
+ valid_ids = []
448
+ for idx in self.ids:
449
+ if (self.root_dir / idx).exists():
450
+ valid_ids.append(idx)
451
+ self.ids = valid_ids
452
+
453
+ self.cond_aug_mean = cond_aug_mean
454
+ self.cond_aug_std = cond_aug_std
455
+ self.condition_on_elevation = condition_on_elevation
456
+
457
+ if max_item is not None:
458
+ self.ids = self.ids[:max_item]
459
+
460
+ ## debug
461
+ self.ids = self.ids * 10000
462
+
463
+ def __getitem__(self, idx: int):
464
+ frames = []
465
+ idx_list = np.arange(self.n_views)
466
+ if self.random_front:
467
+ roll_idx = np.random.randint(self.n_views)
468
+ idx_list = np.roll(idx_list, roll_idx)
469
+ for view_idx in idx_list:
470
+ frame = Image.open(
471
+ self.root_dir / self.ids[idx] / f"{view_idx:05d}/{view_idx:05d}.png"
472
+ )
473
+ frames.append(self.transform(frame))
474
+
475
+ # data = {"jpg": torch.stack(frames, dim=0)} # [T, C, H, W]
476
+ frames = torch.stack(frames, dim=0)
477
+ cond = frames[0]
478
+
479
+ cond_aug = np.exp(
480
+ np.random.randn(1)[0] * self.cond_aug_std + self.cond_aug_mean
481
+ )
482
+
483
+ data = {
484
+ "frames": frames,
485
+ "cond_frames_without_noise": cond,
486
+ "cond_aug": torch.as_tensor([cond_aug] * self.n_views),
487
+ "cond_frames": cond + cond_aug * torch.randn_like(cond),
488
+ "fps_id": torch.as_tensor([1.0] * self.n_views),
489
+ "motion_bucket_id": torch.as_tensor([300.0] * self.n_views),
490
+ "num_video_frames": 24,
491
+ "image_only_indicator": torch.as_tensor([0.0] * self.n_views),
492
+ }
493
+
494
+ if self.condition_on_elevation:
495
+ sample_c2w = read_camera_matrix_single(
496
+ self.root_dir / self.ids[idx] / f"00000/00000.json"
497
+ )
498
+ elevation = calc_elevation(sample_c2w)
499
+ data["elevation"] = torch.as_tensor([elevation] * self.n_views)
500
+
501
+ return data
502
+
503
+ def __len__(self):
504
+ return len(self.ids)
505
+
506
+
507
+ class ObjaverseLVISSpiral(Dataset):
508
+ def __init__(
509
+ self,
510
+ root_dir,
511
+ split="train",
512
+ transform=None,
513
+ random_front=False,
514
+ max_item=None,
515
+ cond_aug_mean=-3.0,
516
+ cond_aug_std=0.5,
517
+ condition_on_elevation=False,
518
+ use_precomputed_latents=False,
519
+ **unused_kwargs,
520
+ ):
521
+ print("Using LVIS subset")
522
+ self.root_dir = Path(root_dir)
523
+ self.latent_dir = Path("/mnt/vepfs/3Ddataset/render_results/latents512")
524
+ self.split = split
525
+ self.random_front = random_front
526
+ self.transform = transform
527
+ self.use_precomputed_latents = use_precomputed_latents
528
+
529
+ self.ids = json.load(open("./assets/lvis_uids.json", "r"))
530
+ self.n_views = 18
531
+ valid_ids = []
532
+ for idx in self.ids:
533
+ if (self.root_dir / idx).exists():
534
+ valid_ids.append(idx)
535
+ self.ids = valid_ids
536
+ print("=" * 30)
537
+ print("Number of valid ids: ", len(self.ids))
538
+ print("=" * 30)
539
+
540
+ self.cond_aug_mean = cond_aug_mean
541
+ self.cond_aug_std = cond_aug_std
542
+ self.condition_on_elevation = condition_on_elevation
543
+
544
+ if max_item is not None:
545
+ self.ids = self.ids[:max_item]
546
+
547
+ ## debug
548
+ self.ids = self.ids * 10000
549
+
550
+ def __getitem__(self, idx: int):
551
+ frames = []
552
+ idx_list = np.arange(self.n_views)
553
+ if self.random_front:
554
+ roll_idx = np.random.randint(self.n_views)
555
+ idx_list = np.roll(idx_list, roll_idx)
556
+ for view_idx in idx_list:
557
+ frame = Image.open(
558
+ self.root_dir
559
+ / self.ids[idx]
560
+ / "elevations_0"
561
+ / f"colors_{view_idx * 2}.png"
562
+ )
563
+ frames.append(self.transform(frame))
564
+
565
+ frames = torch.stack(frames, dim=0)
566
+ cond = frames[0]
567
+
568
+ cond_aug = np.exp(
569
+ np.random.randn(1)[0] * self.cond_aug_std + self.cond_aug_mean
570
+ )
571
+
572
+ data = {
573
+ "frames": frames,
574
+ "cond_frames_without_noise": cond,
575
+ "cond_aug": torch.as_tensor([cond_aug] * self.n_views),
576
+ "cond_frames": cond + cond_aug * torch.randn_like(cond),
577
+ "fps_id": torch.as_tensor([0.0] * self.n_views),
578
+ "motion_bucket_id": torch.as_tensor([300.0] * self.n_views),
579
+ "num_video_frames": self.n_views,
580
+ "image_only_indicator": torch.as_tensor([0.0] * self.n_views),
581
+ }
582
+
583
+ if self.use_precomputed_latents:
584
+ data["latents"] = torch.load(self.latent_dir / f"{self.ids[idx]}.pt")
585
+
586
+ if self.condition_on_elevation:
587
+ # sample_c2w = read_camera_matrix_single(
588
+ # self.root_dir / self.ids[idx] / f"00000/00000.json"
589
+ # )
590
+ # elevation = calc_elevation(sample_c2w)
591
+ # data["elevation"] = torch.as_tensor([elevation] * self.n_views)
592
+ assert False, "currently assumes elevation 0"
593
+
594
+ return data
595
+
596
+ def __len__(self):
597
+ return len(self.ids)
598
+
599
+
600
+ class ObjaverseALLSpiral(ObjaverseLVISSpiral):
601
+ def __init__(
602
+ self,
603
+ root_dir,
604
+ split="train",
605
+ transform=None,
606
+ random_front=False,
607
+ max_item=None,
608
+ cond_aug_mean=-3.0,
609
+ cond_aug_std=0.5,
610
+ condition_on_elevation=False,
611
+ use_precomputed_latents=False,
612
+ **unused_kwargs,
613
+ ):
614
+ print("Using ALL objects in Objaverse")
615
+ self.root_dir = Path(root_dir)
616
+ self.split = split
617
+ self.random_front = random_front
618
+ self.transform = transform
619
+ self.use_precomputed_latents = use_precomputed_latents
620
+ self.latent_dir = Path("/mnt/vepfs/3Ddataset/render_results/latents512")
621
+
622
+ self.ids = json.load(open("./assets/all_ids.json", "r"))
623
+ self.n_views = 18
624
+ valid_ids = []
625
+ for idx in self.ids:
626
+ if (self.root_dir / idx).exists() and (self.root_dir / idx).is_dir():
627
+ valid_ids.append(idx)
628
+ self.ids = valid_ids
629
+ print("=" * 30)
630
+ print("Number of valid ids: ", len(self.ids))
631
+ print("=" * 30)
632
+
633
+ self.cond_aug_mean = cond_aug_mean
634
+ self.cond_aug_std = cond_aug_std
635
+ self.condition_on_elevation = condition_on_elevation
636
+
637
+ if max_item is not None:
638
+ self.ids = self.ids[:max_item]
639
+
640
+ ## debug
641
+ self.ids = self.ids * 10000
642
+
643
+
644
+ class ObjaverseWithPose(Dataset):
645
+ def __init__(
646
+ self,
647
+ root_dir,
648
+ split="train",
649
+ transform=None,
650
+ random_front=False,
651
+ max_item=None,
652
+ cond_aug_mean=-3.0,
653
+ cond_aug_std=0.5,
654
+ condition_on_elevation=False,
655
+ use_precomputed_latents=False,
656
+ **unused_kwargs,
657
+ ):
658
+ print("Using Objaverse with poses")
659
+ self.root_dir = Path(root_dir)
660
+ self.split = split
661
+ self.random_front = random_front
662
+ self.transform = transform
663
+ self.use_precomputed_latents = use_precomputed_latents
664
+ self.latent_dir = Path("/mnt/vepfs/3Ddataset/render_results/latents512")
665
+
666
+ self.ids = json.load(open("./assets/all_ids.json", "r"))
667
+ self.n_views = 18
668
+ valid_ids = []
669
+ for idx in self.ids:
670
+ if (self.root_dir / idx).exists() and (self.root_dir / idx).is_dir():
671
+ valid_ids.append(idx)
672
+ self.ids = valid_ids
673
+ print("=" * 30)
674
+ print("Number of valid ids: ", len(self.ids))
675
+ print("=" * 30)
676
+
677
+ self.cond_aug_mean = cond_aug_mean
678
+ self.cond_aug_std = cond_aug_std
679
+ self.condition_on_elevation = condition_on_elevation
680
+
681
+ def __getitem__(self, idx: int):
682
+ frames = []
683
+ idx_list = np.arange(self.n_views)
684
+ if self.random_front:
685
+ roll_idx = np.random.randint(self.n_views)
686
+ idx_list = np.roll(idx_list, roll_idx)
687
+ for view_idx in idx_list:
688
+ frame = Image.open(
689
+ self.root_dir
690
+ / self.ids[idx]
691
+ / "elevations_0"
692
+ / f"colors_{view_idx * 2}.png"
693
+ )
694
+ frames.append(self.transform(frame))
695
+
696
+ frames = torch.stack(frames, dim=0)
697
+ cond = frames[0]
698
+
699
+ cond_aug = np.exp(
700
+ np.random.randn(1)[0] * self.cond_aug_std + self.cond_aug_mean
701
+ )
702
+
703
+ data = {
704
+ "frames": frames,
705
+ "cond_frames_without_noise": cond,
706
+ "cond_aug": torch.as_tensor([cond_aug] * self.n_views),
707
+ "cond_frames": cond + cond_aug * torch.randn_like(cond),
708
+ "fps_id": torch.as_tensor([0.0] * self.n_views),
709
+ "motion_bucket_id": torch.as_tensor([300.0] * self.n_views),
710
+ "num_video_frames": self.n_views,
711
+ "image_only_indicator": torch.as_tensor([0.0] * self.n_views),
712
+ }
713
+
714
+ if self.use_precomputed_latents:
715
+ data["latents"] = torch.load(self.latent_dir / f"{self.ids[idx]}.pt")
716
+
717
+ if self.condition_on_elevation:
718
+ assert False, "currently assumes elevation 0"
719
+
720
+ return data
721
+
722
+
723
+ class LatentObjaverse(Dataset):
724
+ def __init__(
725
+ self,
726
+ root_dir,
727
+ split="train",
728
+ random_front=False,
729
+ subset="lvis",
730
+ fps_id=1.0,
731
+ motion_bucket_id=300.0,
732
+ cond_aug_mean=-3.0,
733
+ cond_aug_std=0.5,
734
+ **unused_kwargs,
735
+ ):
736
+ self.root_dir = Path(root_dir)
737
+ self.split = split
738
+ self.random_front = random_front
739
+ self.ids = json.load(open(Path("./assets") / f"{subset}_ids.json", "r"))
740
+ self.clip_emb_dir = self.root_dir / ".." / "clip_emb512"
741
+ self.n_views = 18
742
+ self.fps_id = fps_id
743
+ self.motion_bucket_id = motion_bucket_id
744
+ self.cond_aug_mean = cond_aug_mean
745
+ self.cond_aug_std = cond_aug_std
746
+ if self.random_front:
747
+ print("Using a random view as front view")
748
+
749
+ valid_ids = []
750
+ for idx in self.ids:
751
+ if (self.root_dir / f"{idx}.pt").exists() and (
752
+ self.clip_emb_dir / f"{idx}.pt"
753
+ ).exists():
754
+ valid_ids.append(idx)
755
+ self.ids = valid_ids
756
+ print("=" * 30)
757
+ print("Number of valid ids: ", len(self.ids))
758
+ print("=" * 30)
759
+
760
+ def __getitem__(self, idx: int):
761
+ uid = self.ids[idx]
762
+ idx_list = torch.arange(self.n_views)
763
+ latents = torch.load(self.root_dir / f"{uid}.pt")
764
+ clip_emb = torch.load(self.clip_emb_dir / f"{uid}.pt")
765
+ if self.random_front:
766
+ idx_list = torch.roll(idx_list, np.random.randint(self.n_views))
767
+ latents = latents[idx_list]
768
+ clip_emb = clip_emb[idx_list][0]
769
+
770
+ cond_aug = np.exp(
771
+ np.random.randn(1)[0] * self.cond_aug_std + self.cond_aug_mean
772
+ )
773
+ cond = latents[0]
774
+
775
+ data = {
776
+ "latents": latents,
777
+ "cond_frames_without_noise": clip_emb,
778
+ "cond_frames": cond + cond_aug * torch.randn_like(cond),
779
+ "fps_id": torch.as_tensor([self.fps_id] * self.n_views),
780
+ "motion_bucket_id": torch.as_tensor([self.motion_bucket_id] * self.n_views),
781
+ "cond_aug": torch.as_tensor([cond_aug] * self.n_views),
782
+ "num_video_frames": self.n_views,
783
+ "image_only_indicator": torch.as_tensor([0.0] * self.n_views),
784
+ }
785
+
786
+ return data
787
+
788
+ def __len__(self):
789
+ return len(self.ids)
790
+
791
+
792
+ class ObjaverseSpiralDataset(LightningDataModule):
793
+ def __init__(
794
+ self,
795
+ root_dir,
796
+ random_front=False,
797
+ batch_size=2,
798
+ num_workers=10,
799
+ prefetch_factor=2,
800
+ shuffle=True,
801
+ max_item=None,
802
+ dataset_cls="richdreamer",
803
+ reso: int = 256,
804
+ **kwargs,
805
+ ) -> None:
806
+ super().__init__()
807
+
808
+ self.batch_size = batch_size
809
+ self.num_workers = num_workers
810
+ self.prefetch_factor = prefetch_factor
811
+ self.shuffle = shuffle
812
+ self.max_item = max_item
813
+
814
+ self.transform = Compose(
815
+ [
816
+ blend_white_bg,
817
+ Resize((reso, reso)),
818
+ ToTensor(),
819
+ Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
820
+ ]
821
+ )
822
+
823
+ data_cls = {
824
+ "richdreamer": ObjaverseSpiral,
825
+ "lvis": ObjaverseLVISSpiral,
826
+ "shengshu_all": ObjaverseALLSpiral,
827
+ "latent": LatentObjaverse,
828
+ "gobjaverse": GObjaverse,
829
+ }[dataset_cls]
830
+
831
+ self.train_dataset = data_cls(
832
+ root_dir=root_dir,
833
+ split="train",
834
+ random_front=random_front,
835
+ transform=self.transform,
836
+ max_item=self.max_item,
837
+ **kwargs,
838
+ )
839
+ self.test_dataset = data_cls(
840
+ root_dir=root_dir,
841
+ split="val",
842
+ random_front=random_front,
843
+ transform=self.transform,
844
+ max_item=self.max_item,
845
+ **kwargs,
846
+ )
847
+
848
+ def train_dataloader(self):
849
+ return DataLoader(
850
+ self.train_dataset,
851
+ batch_size=self.batch_size,
852
+ shuffle=self.shuffle,
853
+ num_workers=self.num_workers,
854
+ prefetch_factor=self.prefetch_factor,
855
+ collate_fn=video_collate_fn
856
+ if not hasattr(self.train_dataset, "collate_fn")
857
+ else self.train_dataset.collate_fn,
858
+ )
859
+
860
+ def test_dataloader(self):
861
+ return DataLoader(
862
+ self.test_dataset,
863
+ batch_size=self.batch_size,
864
+ shuffle=self.shuffle,
865
+ num_workers=self.num_workers,
866
+ prefetch_factor=self.prefetch_factor,
867
+ collate_fn=video_collate_fn
868
+ if not hasattr(self.test_dataset, "collate_fn")
869
+ else self.test_dataset.collate_fn,
870
+ )
871
+
872
+ def val_dataloader(self):
873
+ return DataLoader(
874
+ self.test_dataset,
875
+ batch_size=self.batch_size,
876
+ shuffle=self.shuffle,
877
+ num_workers=self.num_workers,
878
+ prefetch_factor=self.prefetch_factor,
879
+ collate_fn=video_collate_fn
880
+ if not hasattr(self.test_dataset, "collate_fn")
881
+ else self.test_dataset.collate_fn,
882
+ )
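For reference, the collation used throughout these datasets folds the view axis into the batch axis so that every frame is treated as a batch element downstream. A small self-contained sketch of that behaviour with dummy tensors (shapes are arbitrary):

```python
import torch
from einops import rearrange
from torch.utils.data import default_collate

# Two fake multi-view samples, 24 views each, mirroring the keys used above.
batch = [
    {"frames": torch.randn(24, 3, 64, 64), "fps_id": torch.zeros(24)}
    for _ in range(2)
]

out = {k: default_collate([b[k] for b in batch]) for k in batch[0]}
out["frames"] = rearrange(out["frames"], "b t c h w -> (b t) c h w")
out["fps_id"] = out["fps_id"].flatten()

print(out["frames"].shape)  # torch.Size([48, 3, 64, 64])
print(out["fps_id"].shape)  # torch.Size([48])
```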
sgm/inference/api.py ADDED
@@ -0,0 +1,385 @@
1
+ import pathlib
2
+ from dataclasses import asdict, dataclass
3
+ from enum import Enum
4
+ from typing import Optional
5
+
6
+ from omegaconf import OmegaConf
7
+
8
+ from sgm.inference.helpers import (Img2ImgDiscretizationWrapper, do_img2img,
9
+ do_sample)
10
+ from sgm.modules.diffusionmodules.sampling import (DPMPP2MSampler,
11
+ DPMPP2SAncestralSampler,
12
+ EulerAncestralSampler,
13
+ EulerEDMSampler,
14
+ HeunEDMSampler,
15
+ LinearMultistepSampler)
16
+ from sgm.util import load_model_from_config
17
+
18
+
19
+ class ModelArchitecture(str, Enum):
20
+ SD_2_1 = "stable-diffusion-v2-1"
21
+ SD_2_1_768 = "stable-diffusion-v2-1-768"
22
+ SDXL_V0_9_BASE = "stable-diffusion-xl-v0-9-base"
23
+ SDXL_V0_9_REFINER = "stable-diffusion-xl-v0-9-refiner"
24
+ SDXL_V1_BASE = "stable-diffusion-xl-v1-base"
25
+ SDXL_V1_REFINER = "stable-diffusion-xl-v1-refiner"
26
+
27
+
28
+ class Sampler(str, Enum):
29
+ EULER_EDM = "EulerEDMSampler"
30
+ HEUN_EDM = "HeunEDMSampler"
31
+ EULER_ANCESTRAL = "EulerAncestralSampler"
32
+ DPMPP2S_ANCESTRAL = "DPMPP2SAncestralSampler"
33
+ DPMPP2M = "DPMPP2MSampler"
34
+ LINEAR_MULTISTEP = "LinearMultistepSampler"
35
+
36
+
37
+ class Discretization(str, Enum):
38
+ LEGACY_DDPM = "LegacyDDPMDiscretization"
39
+ EDM = "EDMDiscretization"
40
+
41
+
42
+ class Guider(str, Enum):
43
+ VANILLA = "VanillaCFG"
44
+ IDENTITY = "IdentityGuider"
45
+
46
+
47
+ class Thresholder(str, Enum):
48
+ NONE = "None"
49
+
50
+
51
+ @dataclass
52
+ class SamplingParams:
53
+ width: int = 1024
54
+ height: int = 1024
55
+ steps: int = 50
56
+ sampler: Sampler = Sampler.DPMPP2M
57
+ discretization: Discretization = Discretization.LEGACY_DDPM
58
+ guider: Guider = Guider.VANILLA
59
+ thresholder: Thresholder = Thresholder.NONE
60
+ scale: float = 6.0
61
+ aesthetic_score: float = 5.0
62
+ negative_aesthetic_score: float = 5.0
63
+ img2img_strength: float = 1.0
64
+ orig_width: int = 1024
65
+ orig_height: int = 1024
66
+ crop_coords_top: int = 0
67
+ crop_coords_left: int = 0
68
+ sigma_min: float = 0.0292
69
+ sigma_max: float = 14.6146
70
+ rho: float = 3.0
71
+ s_churn: float = 0.0
72
+ s_tmin: float = 0.0
73
+ s_tmax: float = 999.0
74
+ s_noise: float = 1.0
75
+ eta: float = 1.0
76
+ order: int = 4
77
+
78
+
79
+ @dataclass
80
+ class SamplingSpec:
81
+ width: int
82
+ height: int
83
+ channels: int
84
+ factor: int
85
+ is_legacy: bool
86
+ config: str
87
+ ckpt: str
88
+ is_guided: bool
89
+
90
+
91
+ model_specs = {
92
+ ModelArchitecture.SD_2_1: SamplingSpec(
93
+ height=512,
94
+ width=512,
95
+ channels=4,
96
+ factor=8,
97
+ is_legacy=True,
98
+ config="sd_2_1.yaml",
99
+ ckpt="v2-1_512-ema-pruned.safetensors",
100
+ is_guided=True,
101
+ ),
102
+ ModelArchitecture.SD_2_1_768: SamplingSpec(
103
+ height=768,
104
+ width=768,
105
+ channels=4,
106
+ factor=8,
107
+ is_legacy=True,
108
+ config="sd_2_1_768.yaml",
109
+ ckpt="v2-1_768-ema-pruned.safetensors",
110
+ is_guided=True,
111
+ ),
112
+ ModelArchitecture.SDXL_V0_9_BASE: SamplingSpec(
113
+ height=1024,
114
+ width=1024,
115
+ channels=4,
116
+ factor=8,
117
+ is_legacy=False,
118
+ config="sd_xl_base.yaml",
119
+ ckpt="sd_xl_base_0.9.safetensors",
120
+ is_guided=True,
121
+ ),
122
+ ModelArchitecture.SDXL_V0_9_REFINER: SamplingSpec(
123
+ height=1024,
124
+ width=1024,
125
+ channels=4,
126
+ factor=8,
127
+ is_legacy=True,
128
+ config="sd_xl_refiner.yaml",
129
+ ckpt="sd_xl_refiner_0.9.safetensors",
130
+ is_guided=True,
131
+ ),
132
+ ModelArchitecture.SDXL_V1_BASE: SamplingSpec(
133
+ height=1024,
134
+ width=1024,
135
+ channels=4,
136
+ factor=8,
137
+ is_legacy=False,
138
+ config="sd_xl_base.yaml",
139
+ ckpt="sd_xl_base_1.0.safetensors",
140
+ is_guided=True,
141
+ ),
142
+ ModelArchitecture.SDXL_V1_REFINER: SamplingSpec(
143
+ height=1024,
144
+ width=1024,
145
+ channels=4,
146
+ factor=8,
147
+ is_legacy=True,
148
+ config="sd_xl_refiner.yaml",
149
+ ckpt="sd_xl_refiner_1.0.safetensors",
150
+ is_guided=True,
151
+ ),
152
+ }
153
+
154
+
155
+ class SamplingPipeline:
156
+ def __init__(
157
+ self,
158
+ model_id: ModelArchitecture,
159
+ model_path="checkpoints",
160
+ config_path="configs/inference",
161
+ device="cuda",
162
+ use_fp16=True,
163
+ ) -> None:
164
+ if model_id not in model_specs:
165
+ raise ValueError(f"Model {model_id} not supported")
166
+ self.model_id = model_id
167
+ self.specs = model_specs[self.model_id]
168
+ self.config = str(pathlib.Path(config_path, self.specs.config))
169
+ self.ckpt = str(pathlib.Path(model_path, self.specs.ckpt))
170
+ self.device = device
171
+ self.model = self._load_model(device=device, use_fp16=use_fp16)
172
+
173
+ def _load_model(self, device="cuda", use_fp16=True):
174
+ config = OmegaConf.load(self.config)
175
+ model = load_model_from_config(config, self.ckpt)
176
+ if model is None:
177
+ raise ValueError(f"Model {self.model_id} could not be loaded")
178
+ model.to(device)
179
+ if use_fp16:
180
+ model.conditioner.half()
181
+ model.model.half()
182
+ return model
183
+
184
+ def text_to_image(
185
+ self,
186
+ params: SamplingParams,
187
+ prompt: str,
188
+ negative_prompt: str = "",
189
+ samples: int = 1,
190
+ return_latents: bool = False,
191
+ ):
192
+ sampler = get_sampler_config(params)
193
+ value_dict = asdict(params)
194
+ value_dict["prompt"] = prompt
195
+ value_dict["negative_prompt"] = negative_prompt
196
+ value_dict["target_width"] = params.width
197
+ value_dict["target_height"] = params.height
198
+ return do_sample(
199
+ self.model,
200
+ sampler,
201
+ value_dict,
202
+ samples,
203
+ params.height,
204
+ params.width,
205
+ self.specs.channels,
206
+ self.specs.factor,
207
+ force_uc_zero_embeddings=["txt"] if not self.specs.is_legacy else [],
208
+ return_latents=return_latents,
209
+ filter=None,
210
+ )
211
+
212
+ def image_to_image(
213
+ self,
214
+ params: SamplingParams,
215
+ image,
216
+ prompt: str,
217
+ negative_prompt: str = "",
218
+ samples: int = 1,
219
+ return_latents: bool = False,
220
+ ):
221
+ sampler = get_sampler_config(params)
222
+
223
+ if params.img2img_strength < 1.0:
224
+ sampler.discretization = Img2ImgDiscretizationWrapper(
225
+ sampler.discretization,
226
+ strength=params.img2img_strength,
227
+ )
228
+ height, width = image.shape[2], image.shape[3]
229
+ value_dict = asdict(params)
230
+ value_dict["prompt"] = prompt
231
+ value_dict["negative_prompt"] = negative_prompt
232
+ value_dict["target_width"] = width
233
+ value_dict["target_height"] = height
234
+ return do_img2img(
235
+ image,
236
+ self.model,
237
+ sampler,
238
+ value_dict,
239
+ samples,
240
+ force_uc_zero_embeddings=["txt"] if not self.specs.is_legacy else [],
241
+ return_latents=return_latents,
242
+ filter=None,
243
+ )
244
+
245
+ def refiner(
246
+ self,
247
+ params: SamplingParams,
248
+ image,
249
+ prompt: str,
250
+ negative_prompt: Optional[str] = None,
251
+ samples: int = 1,
252
+ return_latents: bool = False,
253
+ ):
254
+ sampler = get_sampler_config(params)
255
+ value_dict = {
256
+ "orig_width": image.shape[3] * 8,
257
+ "orig_height": image.shape[2] * 8,
258
+ "target_width": image.shape[3] * 8,
259
+ "target_height": image.shape[2] * 8,
260
+ "prompt": prompt,
261
+ "negative_prompt": negative_prompt,
262
+ "crop_coords_top": 0,
263
+ "crop_coords_left": 0,
264
+ "aesthetic_score": 6.0,
265
+ "negative_aesthetic_score": 2.5,
266
+ }
267
+
268
+ return do_img2img(
269
+ image,
270
+ self.model,
271
+ sampler,
272
+ value_dict,
273
+ samples,
274
+ skip_encode=True,
275
+ return_latents=return_latents,
276
+ filter=None,
277
+ )
278
+
279
+
280
+ def get_guider_config(params: SamplingParams):
281
+ if params.guider == Guider.IDENTITY:
282
+ guider_config = {
283
+ "target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"
284
+ }
285
+ elif params.guider == Guider.VANILLA:
286
+ scale = params.scale
287
+
288
+ thresholder = params.thresholder
289
+
290
+ if thresholder == Thresholder.NONE:
291
+ dyn_thresh_config = {
292
+ "target": "sgm.modules.diffusionmodules.sampling_utils.NoDynamicThresholding"
293
+ }
294
+ else:
295
+ raise NotImplementedError
296
+
297
+ guider_config = {
298
+ "target": "sgm.modules.diffusionmodules.guiders.VanillaCFG",
299
+ "params": {"scale": scale, "dyn_thresh_config": dyn_thresh_config},
300
+ }
301
+ else:
302
+ raise NotImplementedError
303
+ return guider_config
304
+
305
+
306
+ def get_discretization_config(params: SamplingParams):
307
+ if params.discretization == Discretization.LEGACY_DDPM:
308
+ discretization_config = {
309
+ "target": "sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization",
310
+ }
311
+ elif params.discretization == Discretization.EDM:
312
+ discretization_config = {
313
+ "target": "sgm.modules.diffusionmodules.discretizer.EDMDiscretization",
314
+ "params": {
315
+ "sigma_min": params.sigma_min,
316
+ "sigma_max": params.sigma_max,
317
+ "rho": params.rho,
318
+ },
319
+ }
320
+ else:
321
+ raise ValueError(f"unknown discretization {params.discretization}")
322
+ return discretization_config
323
+
324
+
325
+ def get_sampler_config(params: SamplingParams):
326
+ discretization_config = get_discretization_config(params)
327
+ guider_config = get_guider_config(params)
328
+ sampler = None
329
+ if params.sampler == Sampler.EULER_EDM:
330
+ return EulerEDMSampler(
331
+ num_steps=params.steps,
332
+ discretization_config=discretization_config,
333
+ guider_config=guider_config,
334
+ s_churn=params.s_churn,
335
+ s_tmin=params.s_tmin,
336
+ s_tmax=params.s_tmax,
337
+ s_noise=params.s_noise,
338
+ verbose=True,
339
+ )
340
+ if params.sampler == Sampler.HEUN_EDM:
341
+ return HeunEDMSampler(
342
+ num_steps=params.steps,
343
+ discretization_config=discretization_config,
344
+ guider_config=guider_config,
345
+ s_churn=params.s_churn,
346
+ s_tmin=params.s_tmin,
347
+ s_tmax=params.s_tmax,
348
+ s_noise=params.s_noise,
349
+ verbose=True,
350
+ )
351
+ if params.sampler == Sampler.EULER_ANCESTRAL:
352
+ return EulerAncestralSampler(
353
+ num_steps=params.steps,
354
+ discretization_config=discretization_config,
355
+ guider_config=guider_config,
356
+ eta=params.eta,
357
+ s_noise=params.s_noise,
358
+ verbose=True,
359
+ )
360
+ if params.sampler == Sampler.DPMPP2S_ANCESTRAL:
361
+ return DPMPP2SAncestralSampler(
362
+ num_steps=params.steps,
363
+ discretization_config=discretization_config,
364
+ guider_config=guider_config,
365
+ eta=params.eta,
366
+ s_noise=params.s_noise,
367
+ verbose=True,
368
+ )
369
+ if params.sampler == Sampler.DPMPP2M:
370
+ return DPMPP2MSampler(
371
+ num_steps=params.steps,
372
+ discretization_config=discretization_config,
373
+ guider_config=guider_config,
374
+ verbose=True,
375
+ )
376
+ if params.sampler == Sampler.LINEAR_MULTISTEP:
377
+ return LinearMultistepSampler(
378
+ num_steps=params.steps,
379
+ discretization_config=discretization_config,
380
+ guider_config=guider_config,
381
+ order=params.order,
382
+ verbose=True,
383
+ )
384
+
385
+ raise ValueError(f"unknown sampler {params.sampler}!")
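A hedged usage sketch of the API above, assuming the `sd_xl_base.yaml` config and `sd_xl_base_1.0.safetensors` checkpoint referenced in `model_specs` are actually present under `configs/inference` and `checkpoints` (prompt and settings are illustrative):

```python
from sgm.inference.api import (ModelArchitecture, Sampler, SamplingParams,
                               SamplingPipeline)

# Loads the SDXL base model and samples one image from a text prompt.
pipeline = SamplingPipeline(ModelArchitecture.SDXL_V1_BASE)
params = SamplingParams(
    width=1024,
    height=1024,
    steps=30,
    sampler=Sampler.EULER_EDM,
    scale=6.0,
)
samples = pipeline.text_to_image(params, prompt="a studio photo of a sneaker")
print(samples.shape)  # (1, 3, 1024, 1024), values in [0, 1]
```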
sgm/inference/helpers.py ADDED
@@ -0,0 +1,305 @@
1
+ import math
2
+ import os
3
+ from typing import List, Optional, Union
4
+
5
+ import numpy as np
6
+ import torch
7
+ from einops import rearrange
8
+ from imwatermark import WatermarkEncoder
9
+ from omegaconf import ListConfig
10
+ from PIL import Image
11
+ from torch import autocast
12
+
13
+ from sgm.util import append_dims
14
+
15
+
16
+ class WatermarkEmbedder:
17
+ def __init__(self, watermark):
18
+ self.watermark = watermark
19
+ self.num_bits = len(WATERMARK_BITS)
20
+ self.encoder = WatermarkEncoder()
21
+ self.encoder.set_watermark("bits", self.watermark)
22
+
23
+ def __call__(self, image: torch.Tensor) -> torch.Tensor:
24
+ """
25
+ Adds a predefined watermark to the input image
26
+
27
+ Args:
28
+ image: ([N,] B, RGB, H, W) in range [0, 1]
29
+
30
+ Returns:
31
+ same as input but watermarked
32
+ """
33
+ squeeze = len(image.shape) == 4
34
+ if squeeze:
35
+ image = image[None, ...]
36
+ n = image.shape[0]
37
+ image_np = rearrange(
38
+ (255 * image).detach().cpu(), "n b c h w -> (n b) h w c"
39
+ ).numpy()[:, :, :, ::-1]
40
+ # torch (b, c, h, w) in [0, 1] -> numpy (b, h, w, c) [0, 255]
41
+ # watermarking library expects input as cv2 BGR format
42
+ for k in range(image_np.shape[0]):
43
+ image_np[k] = self.encoder.encode(image_np[k], "dwtDct")
44
+ image = torch.from_numpy(
45
+ rearrange(image_np[:, :, :, ::-1], "(n b) h w c -> n b c h w", n=n)
46
+ ).to(image.device)
47
+ image = torch.clamp(image / 255, min=0.0, max=1.0)
48
+ if squeeze:
49
+ image = image[0]
50
+ return image
51
+
52
+
53
+ # A fixed 48-bit message that was chosen at random
54
+ # WATERMARK_MESSAGE = 0xB3EC907BB19E
55
+ WATERMARK_MESSAGE = 0b101100111110110010010000011110111011000110011110
56
+ # bin(x)[2:] gives bits of x as str, use int to convert them to 0/1
57
+ WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]]
58
+ embed_watermark = WatermarkEmbedder(WATERMARK_BITS)
59
+
60
+
61
+ def get_unique_embedder_keys_from_conditioner(conditioner):
62
+ return list({x.input_key for x in conditioner.embedders})
63
+
64
+
65
+ def perform_save_locally(save_path, samples):
66
+ os.makedirs(os.path.join(save_path), exist_ok=True)
67
+ base_count = len(os.listdir(os.path.join(save_path)))
68
+ samples = embed_watermark(samples)
69
+ for sample in samples:
70
+ sample = 255.0 * rearrange(sample.cpu().numpy(), "c h w -> h w c")
71
+ Image.fromarray(sample.astype(np.uint8)).save(
72
+ os.path.join(save_path, f"{base_count:09}.png")
73
+ )
74
+ base_count += 1
75
+
76
+
77
+ class Img2ImgDiscretizationWrapper:
78
+ """
79
+ wraps a discretizer, and prunes the sigmas
80
+ params:
81
+ strength: float between 0.0 and 1.0. 1.0 means full sampling (all sigmas are returned)
82
+ """
83
+
84
+ def __init__(self, discretization, strength: float = 1.0):
85
+ self.discretization = discretization
86
+ self.strength = strength
87
+ assert 0.0 <= self.strength <= 1.0
88
+
89
+ def __call__(self, *args, **kwargs):
90
+ # sigmas start large first, and decrease then
91
+ sigmas = self.discretization(*args, **kwargs)
92
+ print(f"sigmas after discretization, before pruning img2img: ", sigmas)
93
+ sigmas = torch.flip(sigmas, (0,))
94
+ sigmas = sigmas[: max(int(self.strength * len(sigmas)), 1)]
95
+ print("prune index:", max(int(self.strength * len(sigmas)), 1))
96
+ sigmas = torch.flip(sigmas, (0,))
97
+ print(f"sigmas after pruning: ", sigmas)
98
+ return sigmas
99
+
100
+
101
+ def do_sample(
102
+ model,
103
+ sampler,
104
+ value_dict,
105
+ num_samples,
106
+ H,
107
+ W,
108
+ C,
109
+ F,
110
+ force_uc_zero_embeddings: Optional[List] = None,
111
+ batch2model_input: Optional[List] = None,
112
+ return_latents=False,
113
+ filter=None,
114
+ device="cuda",
115
+ ):
116
+ if force_uc_zero_embeddings is None:
117
+ force_uc_zero_embeddings = []
118
+ if batch2model_input is None:
119
+ batch2model_input = []
120
+
121
+ with torch.no_grad():
122
+ with autocast(device) as precision_scope:
123
+ with model.ema_scope():
124
+ num_samples = [num_samples]
125
+ batch, batch_uc = get_batch(
126
+ get_unique_embedder_keys_from_conditioner(model.conditioner),
127
+ value_dict,
128
+ num_samples,
129
+ )
130
+ for key in batch:
131
+ if isinstance(batch[key], torch.Tensor):
132
+ print(key, batch[key].shape)
133
+ elif isinstance(batch[key], list):
134
+ print(key, [len(l) for l in batch[key]])
135
+ else:
136
+ print(key, batch[key])
137
+ c, uc = model.conditioner.get_unconditional_conditioning(
138
+ batch,
139
+ batch_uc=batch_uc,
140
+ force_uc_zero_embeddings=force_uc_zero_embeddings,
141
+ )
142
+
143
+ for k in c:
144
+ if not k == "crossattn":
145
+ c[k], uc[k] = map(
146
+ lambda y: y[k][: math.prod(num_samples)].to(device), (c, uc)
147
+ )
148
+
149
+ additional_model_inputs = {}
150
+ for k in batch2model_input:
151
+ additional_model_inputs[k] = batch[k]
152
+
153
+ shape = (math.prod(num_samples), C, H // F, W // F)
154
+ randn = torch.randn(shape).to(device)
155
+
156
+ def denoiser(input, sigma, c):
157
+ return model.denoiser(
158
+ model.model, input, sigma, c, **additional_model_inputs
159
+ )
160
+
161
+ samples_z = sampler(denoiser, randn, cond=c, uc=uc)
162
+ samples_x = model.decode_first_stage(samples_z)
163
+ samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
164
+
165
+ if filter is not None:
166
+ samples = filter(samples)
167
+
168
+ if return_latents:
169
+ return samples, samples_z
170
+ return samples
171
+
172
+
173
+ def get_batch(keys, value_dict, N: Union[List, ListConfig], device="cuda"):
174
+ # Hardcoded demo setups; might undergo some changes in the future
175
+
176
+ batch = {}
177
+ batch_uc = {}
178
+
179
+ for key in keys:
180
+ if key == "txt":
181
+ batch["txt"] = (
182
+ np.repeat([value_dict["prompt"]], repeats=math.prod(N))
183
+ .reshape(N)
184
+ .tolist()
185
+ )
186
+ batch_uc["txt"] = (
187
+ np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N))
188
+ .reshape(N)
189
+ .tolist()
190
+ )
191
+ elif key == "original_size_as_tuple":
192
+ batch["original_size_as_tuple"] = (
193
+ torch.tensor([value_dict["orig_height"], value_dict["orig_width"]])
194
+ .to(device)
195
+ .repeat(*N, 1)
196
+ )
197
+ elif key == "crop_coords_top_left":
198
+ batch["crop_coords_top_left"] = (
199
+ torch.tensor(
200
+ [value_dict["crop_coords_top"], value_dict["crop_coords_left"]]
201
+ )
202
+ .to(device)
203
+ .repeat(*N, 1)
204
+ )
205
+ elif key == "aesthetic_score":
206
+ batch["aesthetic_score"] = (
207
+ torch.tensor([value_dict["aesthetic_score"]]).to(device).repeat(*N, 1)
208
+ )
209
+ batch_uc["aesthetic_score"] = (
210
+ torch.tensor([value_dict["negative_aesthetic_score"]])
211
+ .to(device)
212
+ .repeat(*N, 1)
213
+ )
214
+
215
+ elif key == "target_size_as_tuple":
216
+ batch["target_size_as_tuple"] = (
217
+ torch.tensor([value_dict["target_height"], value_dict["target_width"]])
218
+ .to(device)
219
+ .repeat(*N, 1)
220
+ )
221
+ else:
222
+ batch[key] = value_dict[key]
223
+
224
+ for key in batch.keys():
225
+ if key not in batch_uc and isinstance(batch[key], torch.Tensor):
226
+ batch_uc[key] = torch.clone(batch[key])
227
+ return batch, batch_uc
228
+
229
+
230
+ def get_input_image_tensor(image: Image.Image, device="cuda"):
231
+ w, h = image.size
232
+ print(f"loaded input image of size ({w}, {h})")
233
+ width, height = map(
234
+ lambda x: x - x % 64, (w, h)
235
+ ) # resize to integer multiple of 64
236
+ image = image.resize((width, height))
237
+ image_array = np.array(image.convert("RGB"))
238
+ image_array = image_array[None].transpose(0, 3, 1, 2)
239
+ image_tensor = torch.from_numpy(image_array).to(dtype=torch.float32) / 127.5 - 1.0
240
+ return image_tensor.to(device)
241
+
242
+
243
+ def do_img2img(
244
+ img,
245
+ model,
246
+ sampler,
247
+ value_dict,
248
+ num_samples,
249
+ force_uc_zero_embeddings=[],
250
+ additional_kwargs={},
251
+ offset_noise_level: float = 0.0,
252
+ return_latents=False,
253
+ skip_encode=False,
254
+ filter=None,
255
+ device="cuda",
256
+ ):
257
+ with torch.no_grad():
258
+ with autocast(device) as precision_scope:
259
+ with model.ema_scope():
260
+ batch, batch_uc = get_batch(
261
+ get_unique_embedder_keys_from_conditioner(model.conditioner),
262
+ value_dict,
263
+ [num_samples],
264
+ )
265
+ c, uc = model.conditioner.get_unconditional_conditioning(
266
+ batch,
267
+ batch_uc=batch_uc,
268
+ force_uc_zero_embeddings=force_uc_zero_embeddings,
269
+ )
270
+
271
+ for k in c:
272
+ c[k], uc[k] = map(lambda y: y[k][:num_samples].to(device), (c, uc))
273
+
274
+ for k in additional_kwargs:
275
+ c[k] = uc[k] = additional_kwargs[k]
276
+ if skip_encode:
277
+ z = img
278
+ else:
279
+ z = model.encode_first_stage(img)
280
+ noise = torch.randn_like(z)
281
+ sigmas = sampler.discretization(sampler.num_steps)
282
+ sigma = sigmas[0].to(z.device)
283
+
284
+ if offset_noise_level > 0.0:
285
+ noise = noise + offset_noise_level * append_dims(
286
+ torch.randn(z.shape[0], device=z.device), z.ndim
287
+ )
288
+ noised_z = z + noise * append_dims(sigma, z.ndim)
289
+ noised_z = noised_z / torch.sqrt(
290
+ 1.0 + sigmas[0] ** 2.0
291
+ ) # Note: hardcoded to DDPM-like scaling. need to generalize later.
292
+
293
+ def denoiser(x, sigma, c):
294
+ return model.denoiser(model.model, x, sigma, c)
295
+
296
+ samples_z = sampler(denoiser, noised_z, cond=c, uc=uc)
297
+ samples_x = model.decode_first_stage(samples_z)
298
+ samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
299
+
300
+ if filter is not None:
301
+ samples = filter(samples)
302
+
303
+ if return_latents:
304
+ return samples, samples_z
305
+ return samples
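To make the sigma pruning of `Img2ImgDiscretizationWrapper` concrete, here is a small sketch with a stand-in discretizer; the linear sigma schedule is purely illustrative:

```python
import torch
from sgm.inference.helpers import Img2ImgDiscretizationWrapper


def fake_discretization(n_steps):
    # Stand-in for a real discretizer: sigmas start large and decrease.
    return torch.linspace(10.0, 0.1, n_steps)


wrapper = Img2ImgDiscretizationWrapper(fake_discretization, strength=0.5)
sigmas = wrapper(10)
# Only the 5 smallest sigmas survive, still ordered large -> small, so an
# img2img run starts part-way down the noise schedule.
print(sigmas)
```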
sgm/lr_scheduler.py ADDED
@@ -0,0 +1,135 @@
1
+ import numpy as np
2
+
3
+
4
+ class LambdaWarmUpCosineScheduler:
5
+ """
6
+ note: use with a base_lr of 1.0
7
+ """
8
+
9
+ def __init__(
10
+ self,
11
+ warm_up_steps,
12
+ lr_min,
13
+ lr_max,
14
+ lr_start,
15
+ max_decay_steps,
16
+ verbosity_interval=0,
17
+ ):
18
+ self.lr_warm_up_steps = warm_up_steps
19
+ self.lr_start = lr_start
20
+ self.lr_min = lr_min
21
+ self.lr_max = lr_max
22
+ self.lr_max_decay_steps = max_decay_steps
23
+ self.last_lr = 0.0
24
+ self.verbosity_interval = verbosity_interval
25
+
26
+ def schedule(self, n, **kwargs):
27
+ if self.verbosity_interval > 0:
28
+ if n % self.verbosity_interval == 0:
29
+ print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
30
+ if n < self.lr_warm_up_steps:
31
+ lr = (
32
+ self.lr_max - self.lr_start
33
+ ) / self.lr_warm_up_steps * n + self.lr_start
34
+ self.last_lr = lr
35
+ return lr
36
+ else:
37
+ t = (n - self.lr_warm_up_steps) / (
38
+ self.lr_max_decay_steps - self.lr_warm_up_steps
39
+ )
40
+ t = min(t, 1.0)
41
+ lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
42
+ 1 + np.cos(t * np.pi)
43
+ )
44
+ self.last_lr = lr
45
+ return lr
46
+
47
+ def __call__(self, n, **kwargs):
48
+ return self.schedule(n, **kwargs)
49
+
50
+
51
+ class LambdaWarmUpCosineScheduler2:
52
+ """
53
+ supports repeated iterations, configurable via lists
54
+ note: use with a base_lr of 1.0.
55
+ """
56
+
57
+ def __init__(
58
+ self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0
59
+ ):
60
+ assert (
61
+ len(warm_up_steps)
62
+ == len(f_min)
63
+ == len(f_max)
64
+ == len(f_start)
65
+ == len(cycle_lengths)
66
+ )
67
+ self.lr_warm_up_steps = warm_up_steps
68
+ self.f_start = f_start
69
+ self.f_min = f_min
70
+ self.f_max = f_max
71
+ self.cycle_lengths = cycle_lengths
72
+ self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
73
+ self.last_f = 0.0
74
+ self.verbosity_interval = verbosity_interval
75
+
76
+ def find_in_interval(self, n):
77
+ interval = 0
78
+ for cl in self.cum_cycles[1:]:
79
+ if n <= cl:
80
+ return interval
81
+ interval += 1
82
+
83
+ def schedule(self, n, **kwargs):
84
+ cycle = self.find_in_interval(n)
85
+ n = n - self.cum_cycles[cycle]
86
+ if self.verbosity_interval > 0:
87
+ if n % self.verbosity_interval == 0:
88
+ print(
89
+ f"current step: {n}, recent lr-multiplier: {self.last_f}, "
90
+ f"current cycle {cycle}"
91
+ )
92
+ if n < self.lr_warm_up_steps[cycle]:
93
+ f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[
94
+ cycle
95
+ ] * n + self.f_start[cycle]
96
+ self.last_f = f
97
+ return f
98
+ else:
99
+ t = (n - self.lr_warm_up_steps[cycle]) / (
100
+ self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]
101
+ )
102
+ t = min(t, 1.0)
103
+ f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (
104
+ 1 + np.cos(t * np.pi)
105
+ )
106
+ self.last_f = f
107
+ return f
108
+
109
+ def __call__(self, n, **kwargs):
110
+ return self.schedule(n, **kwargs)
111
+
112
+
113
+ class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
114
+ def schedule(self, n, **kwargs):
115
+ cycle = self.find_in_interval(n)
116
+ n = n - self.cum_cycles[cycle]
117
+ if self.verbosity_interval > 0:
118
+ if n % self.verbosity_interval == 0:
119
+ print(
120
+ f"current step: {n}, recent lr-multiplier: {self.last_f}, "
121
+ f"current cycle {cycle}"
122
+ )
123
+
124
+ if n < self.lr_warm_up_steps[cycle]:
125
+ f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[
126
+ cycle
127
+ ] * n + self.f_start[cycle]
128
+ self.last_f = f
129
+ return f
130
+ else:
131
+ f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (
132
+ self.cycle_lengths[cycle] - n
133
+ ) / (self.cycle_lengths[cycle])
134
+ self.last_f = f
135
+ return f
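As the docstrings note, these schedulers return a learning-rate multiplier and are meant to be used with a base lr of 1.0; a minimal sketch of wiring one into `torch.optim.lr_scheduler.LambdaLR` (model and hyperparameters are illustrative):

```python
import torch
from sgm.lr_scheduler import LambdaWarmUpCosineScheduler

model = torch.nn.Linear(8, 8)
# Base lr of 1.0 as the docstring suggests; the real peak lr lives in lr_max.
optimizer = torch.optim.AdamW(model.parameters(), lr=1.0)

schedule = LambdaWarmUpCosineScheduler(
    warm_up_steps=100,
    lr_min=1e-6,
    lr_max=1e-4,
    lr_start=0.0,
    max_decay_steps=1000,
)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=schedule)

for step in range(1000):
    optimizer.step()
    scheduler.step()  # linear warm-up for 100 steps, cosine decay afterwards
```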
sgm/models/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .autoencoder import AutoencodingEngine
2
+ from .diffusion import DiffusionEngine
sgm/models/autoencoder.py ADDED
@@ -0,0 +1,615 @@
1
+ import logging
2
+ import math
3
+ import re
4
+ from abc import abstractmethod
5
+ from contextlib import contextmanager
6
+ from typing import Any, Dict, List, Optional, Tuple, Union
7
+
8
+ import pytorch_lightning as pl
9
+ import torch
10
+ import torch.nn as nn
11
+ from einops import rearrange
12
+ from packaging import version
13
+
14
+ from ..modules.autoencoding.regularizers import AbstractRegularizer
15
+ from ..modules.ema import LitEma
16
+ from ..util import (default, get_nested_attribute, get_obj_from_str,
17
+ instantiate_from_config)
18
+
19
+ logpy = logging.getLogger(__name__)
20
+
21
+
22
+ class AbstractAutoencoder(pl.LightningModule):
23
+ """
24
+ This is the base class for all autoencoders, including image autoencoders, image autoencoders with discriminators,
25
+ unCLIP models, etc. Hence, it is fairly general, and specific features
26
+ (e.g. discriminator training, encoding, decoding) must be implemented in subclasses.
27
+ """
28
+
29
+ def __init__(
30
+ self,
31
+ ema_decay: Union[None, float] = None,
32
+ monitor: Union[None, str] = None,
33
+ input_key: str = "jpg",
34
+ ):
35
+ super().__init__()
36
+
37
+ self.input_key = input_key
38
+ self.use_ema = ema_decay is not None
39
+ if monitor is not None:
40
+ self.monitor = monitor
41
+
42
+ if self.use_ema:
43
+ self.model_ema = LitEma(self, decay=ema_decay)
44
+ logpy.info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
45
+
46
+ if version.parse(torch.__version__) >= version.parse("2.0.0"):
47
+ self.automatic_optimization = False
48
+
49
+ def apply_ckpt(self, ckpt: Union[None, str, dict]):
50
+ if ckpt is None:
51
+ return
52
+ if isinstance(ckpt, str):
53
+ ckpt = {
54
+ "target": "sgm.modules.checkpoint.CheckpointEngine",
55
+ "params": {"ckpt_path": ckpt},
56
+ }
57
+ engine = instantiate_from_config(ckpt)
58
+ engine(self)
59
+
60
+ @abstractmethod
61
+ def get_input(self, batch) -> Any:
62
+ raise NotImplementedError()
63
+
64
+ def on_train_batch_end(self, *args, **kwargs):
65
+ # for EMA computation
66
+ if self.use_ema:
67
+ self.model_ema(self)
68
+
69
+ @contextmanager
70
+ def ema_scope(self, context=None):
71
+ if self.use_ema:
72
+ self.model_ema.store(self.parameters())
73
+ self.model_ema.copy_to(self)
74
+ if context is not None:
75
+ logpy.info(f"{context}: Switched to EMA weights")
76
+ try:
77
+ yield None
78
+ finally:
79
+ if self.use_ema:
80
+ self.model_ema.restore(self.parameters())
81
+ if context is not None:
82
+ logpy.info(f"{context}: Restored training weights")
83
+
84
+ @abstractmethod
85
+ def encode(self, *args, **kwargs) -> torch.Tensor:
86
+ raise NotImplementedError("encode()-method of abstract base class called")
87
+
88
+ @abstractmethod
89
+ def decode(self, *args, **kwargs) -> torch.Tensor:
90
+ raise NotImplementedError("decode()-method of abstract base class called")
91
+
92
+ def instantiate_optimizer_from_config(self, params, lr, cfg):
93
+ logpy.info(f"loading >>> {cfg['target']} <<< optimizer from config")
94
+ return get_obj_from_str(cfg["target"])(
95
+ params, lr=lr, **cfg.get("params", dict())
96
+ )
97
+
98
+ def configure_optimizers(self) -> Any:
99
+ raise NotImplementedError()
100
+
101
+
102
+ class AutoencodingEngine(AbstractAutoencoder):
103
+ """
104
+ Base class for all image autoencoders that we train, like VQGAN or AutoencoderKL
105
+ (we also restore them explicitly as special cases for legacy reasons).
106
+ Regularizations such as KL or VQ are moved to the regularizer class.
107
+ """
108
+
109
+ def __init__(
110
+ self,
111
+ *args,
112
+ encoder_config: Dict,
113
+ decoder_config: Dict,
114
+ loss_config: Dict,
115
+ regularizer_config: Dict,
116
+ optimizer_config: Union[Dict, None] = None,
117
+ lr_g_factor: float = 1.0,
118
+ trainable_ae_params: Optional[List[List[str]]] = None,
119
+ ae_optimizer_args: Optional[List[dict]] = None,
120
+ trainable_disc_params: Optional[List[List[str]]] = None,
121
+ disc_optimizer_args: Optional[List[dict]] = None,
122
+ disc_start_iter: int = 0,
123
+ diff_boost_factor: float = 3.0,
124
+ ckpt_engine: Union[None, str, dict] = None,
125
+ ckpt_path: Optional[str] = None,
126
+ additional_decode_keys: Optional[List[str]] = None,
127
+ **kwargs,
128
+ ):
129
+ super().__init__(*args, **kwargs)
130
+ self.automatic_optimization = False # pytorch lightning
131
+
132
+ self.encoder: torch.nn.Module = instantiate_from_config(encoder_config)
133
+ self.decoder: torch.nn.Module = instantiate_from_config(decoder_config)
134
+ self.loss: torch.nn.Module = instantiate_from_config(loss_config)
135
+ self.regularization: AbstractRegularizer = instantiate_from_config(
136
+ regularizer_config
137
+ )
138
+ self.optimizer_config = default(
139
+ optimizer_config, {"target": "torch.optim.Adam"}
140
+ )
141
+ self.diff_boost_factor = diff_boost_factor
142
+ self.disc_start_iter = disc_start_iter
143
+ self.lr_g_factor = lr_g_factor
144
+ self.trainable_ae_params = trainable_ae_params
145
+ if self.trainable_ae_params is not None:
146
+ self.ae_optimizer_args = default(
147
+ ae_optimizer_args,
148
+ [{} for _ in range(len(self.trainable_ae_params))],
149
+ )
150
+ assert len(self.ae_optimizer_args) == len(self.trainable_ae_params)
151
+ else:
152
+ self.ae_optimizer_args = [{}] # makes type consistent
153
+
154
+ self.trainable_disc_params = trainable_disc_params
155
+ if self.trainable_disc_params is not None:
156
+ self.disc_optimizer_args = default(
157
+ disc_optimizer_args,
158
+ [{} for _ in range(len(self.trainable_disc_params))],
159
+ )
160
+ assert len(self.disc_optimizer_args) == len(self.trainable_disc_params)
161
+ else:
162
+ self.disc_optimizer_args = [{}] # makes type consistent
163
+
164
+ if ckpt_path is not None:
165
+ assert ckpt_engine is None, "Can't set ckpt_engine and ckpt_path"
166
+ logpy.warn("Checkpoint path is deprecated, use `checkpoint_egnine` instead")
167
+ self.apply_ckpt(default(ckpt_path, ckpt_engine))
168
+ self.additional_decode_keys = set(default(additional_decode_keys, []))
169
+
170
+ def get_input(self, batch: Dict) -> torch.Tensor:
171
+ # assuming unified data format, dataloader returns a dict.
172
+ # image tensors should be scaled to -1 ... 1 and in channels-first
173
+ # format (e.g., bchw instead of bhwc)
174
+ return batch[self.input_key]
175
+
176
+ def get_autoencoder_params(self) -> list:
177
+ params = []
178
+ if hasattr(self.loss, "get_trainable_autoencoder_parameters"):
179
+ params += list(self.loss.get_trainable_autoencoder_parameters())
180
+ if hasattr(self.regularization, "get_trainable_parameters"):
181
+ params += list(self.regularization.get_trainable_parameters())
182
+ params = params + list(self.encoder.parameters())
183
+ params = params + list(self.decoder.parameters())
184
+ return params
185
+
186
+ def get_discriminator_params(self) -> list:
187
+ if hasattr(self.loss, "get_trainable_parameters"):
188
+ params = list(self.loss.get_trainable_parameters()) # e.g., discriminator
189
+ else:
190
+ params = []
191
+ return params
192
+
193
+ def get_last_layer(self):
194
+ return self.decoder.get_last_layer()
195
+
196
+ def encode(
197
+ self,
198
+ x: torch.Tensor,
199
+ return_reg_log: bool = False,
200
+ unregularized: bool = False,
201
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
202
+ z = self.encoder(x)
203
+ if unregularized:
204
+ return z, dict()
205
+ z, reg_log = self.regularization(z)
206
+ if return_reg_log:
207
+ return z, reg_log
208
+ return z
209
+
210
+ def decode(self, z: torch.Tensor, **kwargs) -> torch.Tensor:
211
+ x = self.decoder(z, **kwargs)
212
+ return x
213
+
214
+ def forward(
215
+ self, x: torch.Tensor, **additional_decode_kwargs
216
+ ) -> Tuple[torch.Tensor, torch.Tensor, dict]:
217
+ z, reg_log = self.encode(x, return_reg_log=True)
218
+ dec = self.decode(z, **additional_decode_kwargs)
219
+ return z, dec, reg_log
220
+
221
+ def inner_training_step(
222
+ self, batch: dict, batch_idx: int, optimizer_idx: int = 0
223
+ ) -> torch.Tensor:
224
+ x = self.get_input(batch)
225
+ additional_decode_kwargs = {
226
+ key: batch[key] for key in self.additional_decode_keys.intersection(batch)
227
+ }
228
+ z, xrec, regularization_log = self(x, **additional_decode_kwargs)
229
+ if hasattr(self.loss, "forward_keys"):
230
+ extra_info = {
231
+ "z": z,
232
+ "optimizer_idx": optimizer_idx,
233
+ "global_step": self.global_step,
234
+ "last_layer": self.get_last_layer(),
235
+ "split": "train",
236
+ "regularization_log": regularization_log,
237
+ "autoencoder": self,
238
+ }
239
+ extra_info = {k: extra_info[k] for k in self.loss.forward_keys}
240
+ else:
241
+ extra_info = dict()
242
+
243
+ if optimizer_idx == 0:
244
+ # autoencode
245
+ out_loss = self.loss(x, xrec, **extra_info)
246
+ if isinstance(out_loss, tuple):
247
+ aeloss, log_dict_ae = out_loss
248
+ else:
249
+ # simple loss function
250
+ aeloss = out_loss
251
+ log_dict_ae = {"train/loss/rec": aeloss.detach()}
252
+
253
+ self.log_dict(
254
+ log_dict_ae,
255
+ prog_bar=False,
256
+ logger=True,
257
+ on_step=True,
258
+ on_epoch=True,
259
+ sync_dist=False,
260
+ )
261
+ self.log(
262
+ "loss",
263
+ aeloss.mean().detach(),
264
+ prog_bar=True,
265
+ logger=False,
266
+ on_epoch=False,
267
+ on_step=True,
268
+ )
269
+ return aeloss
270
+ elif optimizer_idx == 1:
271
+ # discriminator
272
+ discloss, log_dict_disc = self.loss(x, xrec, **extra_info)
273
+ # -> discriminator always needs to return a tuple
274
+ self.log_dict(
275
+ log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True
276
+ )
277
+ return discloss
278
+ else:
279
+ raise NotImplementedError(f"Unknown optimizer {optimizer_idx}")
280
+
281
+ def training_step(self, batch: dict, batch_idx: int):
282
+ opts = self.optimizers()
283
+ if not isinstance(opts, list):
284
+ # Non-adversarial case
285
+ opts = [opts]
286
+ optimizer_idx = batch_idx % len(opts)
287
+ if self.global_step < self.disc_start_iter:
288
+ optimizer_idx = 0
289
+ opt = opts[optimizer_idx]
290
+ opt.zero_grad()
291
+ with opt.toggle_model():
292
+ loss = self.inner_training_step(
293
+ batch, batch_idx, optimizer_idx=optimizer_idx
294
+ )
295
+ self.manual_backward(loss)
296
+ opt.step()
297
+
298
+ def validation_step(self, batch: dict, batch_idx: int) -> Dict:
299
+ log_dict = self._validation_step(batch, batch_idx)
300
+ with self.ema_scope():
301
+ log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema")
302
+ log_dict.update(log_dict_ema)
303
+ return log_dict
304
+
305
+ def _validation_step(self, batch: dict, batch_idx: int, postfix: str = "") -> Dict:
306
+ x = self.get_input(batch)
307
+
308
+ z, xrec, regularization_log = self(x)
309
+ if hasattr(self.loss, "forward_keys"):
310
+ extra_info = {
311
+ "z": z,
312
+ "optimizer_idx": 0,
313
+ "global_step": self.global_step,
314
+ "last_layer": self.get_last_layer(),
315
+ "split": "val" + postfix,
316
+ "regularization_log": regularization_log,
317
+ "autoencoder": self,
318
+ }
319
+ extra_info = {k: extra_info[k] for k in self.loss.forward_keys}
320
+ else:
321
+ extra_info = dict()
322
+ out_loss = self.loss(x, xrec, **extra_info)
323
+ if isinstance(out_loss, tuple):
324
+ aeloss, log_dict_ae = out_loss
325
+ else:
326
+ # simple loss function
327
+ aeloss = out_loss
328
+ log_dict_ae = {f"val{postfix}/loss/rec": aeloss.detach()}
329
+ full_log_dict = log_dict_ae
330
+
331
+ if "optimizer_idx" in extra_info:
332
+ extra_info["optimizer_idx"] = 1
333
+ discloss, log_dict_disc = self.loss(x, xrec, **extra_info)
334
+ full_log_dict.update(log_dict_disc)
335
+ self.log(
336
+ f"val{postfix}/loss/rec",
337
+ log_dict_ae[f"val{postfix}/loss/rec"],
338
+ sync_dist=True,
339
+ )
340
+ self.log_dict(full_log_dict, sync_dist=True)
341
+ return full_log_dict
342
+
343
+ def get_param_groups(
344
+ self, parameter_names: List[List[str]], optimizer_args: List[dict]
345
+ ) -> Tuple[List[Dict[str, Any]], int]:
346
+ groups = []
347
+ num_params = 0
348
+ for names, args in zip(parameter_names, optimizer_args):
349
+ params = []
350
+ for pattern_ in names:
351
+ pattern_params = []
352
+ pattern = re.compile(pattern_)
353
+ for p_name, param in self.named_parameters():
354
+ if re.match(pattern, p_name):
355
+ pattern_params.append(param)
356
+ num_params += param.numel()
357
+ if len(pattern_params) == 0:
358
+ logpy.warn(f"Did not find parameters for pattern {pattern_}")
359
+ params.extend(pattern_params)
360
+ groups.append({"params": params, **args})
361
+ return groups, num_params
362
+
363
+ def configure_optimizers(self) -> List[torch.optim.Optimizer]:
364
+ if self.trainable_ae_params is None:
365
+ ae_params = self.get_autoencoder_params()
366
+ else:
367
+ ae_params, num_ae_params = self.get_param_groups(
368
+ self.trainable_ae_params, self.ae_optimizer_args
369
+ )
370
+ logpy.info(f"Number of trainable autoencoder parameters: {num_ae_params:,}")
371
+ if self.trainable_disc_params is None:
372
+ disc_params = self.get_discriminator_params()
373
+ else:
374
+ disc_params, num_disc_params = self.get_param_groups(
375
+ self.trainable_disc_params, self.disc_optimizer_args
376
+ )
377
+ logpy.info(
378
+ f"Number of trainable discriminator parameters: {num_disc_params:,}"
379
+ )
380
+ opt_ae = self.instantiate_optimizer_from_config(
381
+ ae_params,
382
+ default(self.lr_g_factor, 1.0) * self.learning_rate,
383
+ self.optimizer_config,
384
+ )
385
+ opts = [opt_ae]
386
+ if len(disc_params) > 0:
387
+ opt_disc = self.instantiate_optimizer_from_config(
388
+ disc_params, self.learning_rate, self.optimizer_config
389
+ )
390
+ opts.append(opt_disc)
391
+
392
+ return opts
393
+
394
+ @torch.no_grad()
395
+ def log_images(
396
+ self, batch: dict, additional_log_kwargs: Optional[Dict] = None, **kwargs
397
+ ) -> dict:
398
+ log = dict()
399
+ additional_decode_kwargs = {}
400
+ x = self.get_input(batch)
401
+ additional_decode_kwargs.update(
402
+ {key: batch[key] for key in self.additional_decode_keys.intersection(batch)}
403
+ )
404
+
405
+ _, xrec, _ = self(x, **additional_decode_kwargs)
406
+ log["inputs"] = x
407
+ log["reconstructions"] = xrec
408
+ diff = 0.5 * torch.abs(torch.clamp(xrec, -1.0, 1.0) - x)
409
+ diff.clamp_(0, 1.0)
410
+ log["diff"] = 2.0 * diff - 1.0
411
+ # diff_boost shows location of small errors, by boosting their
412
+ # brightness.
413
+ log["diff_boost"] = (
414
+ 2.0 * torch.clamp(self.diff_boost_factor * diff, 0.0, 1.0) - 1
415
+ )
416
+ if hasattr(self.loss, "log_images"):
417
+ log.update(self.loss.log_images(x, xrec))
418
+ with self.ema_scope():
419
+ _, xrec_ema, _ = self(x, **additional_decode_kwargs)
420
+ log["reconstructions_ema"] = xrec_ema
421
+ diff_ema = 0.5 * torch.abs(torch.clamp(xrec_ema, -1.0, 1.0) - x)
422
+ diff_ema.clamp_(0, 1.0)
423
+ log["diff_ema"] = 2.0 * diff_ema - 1.0
424
+ log["diff_boost_ema"] = (
425
+ 2.0 * torch.clamp(self.diff_boost_factor * diff_ema, 0.0, 1.0) - 1
426
+ )
427
+ if additional_log_kwargs:
428
+ additional_decode_kwargs.update(additional_log_kwargs)
429
+ _, xrec_add, _ = self(x, **additional_decode_kwargs)
430
+ log_str = "reconstructions-" + "-".join(
431
+ [f"{key}={additional_log_kwargs[key]}" for key in additional_log_kwargs]
432
+ )
433
+ log[log_str] = xrec_add
434
+ return log
435
+
436
+
437
+ class AutoencodingEngineLegacy(AutoencodingEngine):
438
+ def __init__(self, embed_dim: int, **kwargs):
439
+ self.max_batch_size = kwargs.pop("max_batch_size", None)
440
+ ddconfig = kwargs.pop("ddconfig")
441
+ ckpt_path = kwargs.pop("ckpt_path", None)
442
+ ckpt_engine = kwargs.pop("ckpt_engine", None)
443
+ super().__init__(
444
+ encoder_config={
445
+ "target": "sgm.modules.diffusionmodules.model.Encoder",
446
+ "params": ddconfig,
447
+ },
448
+ decoder_config={
449
+ "target": "sgm.modules.diffusionmodules.model.Decoder",
450
+ "params": ddconfig,
451
+ },
452
+ **kwargs,
453
+ )
454
+ self.quant_conv = torch.nn.Conv2d(
455
+ (1 + ddconfig["double_z"]) * ddconfig["z_channels"],
456
+ (1 + ddconfig["double_z"]) * embed_dim,
457
+ 1,
458
+ )
459
+ self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
460
+ self.embed_dim = embed_dim
461
+
462
+ self.apply_ckpt(default(ckpt_path, ckpt_engine))
463
+
464
+ def get_autoencoder_params(self) -> list:
465
+ params = super().get_autoencoder_params()
466
+ return params
467
+
468
+ def encode(
469
+ self, x: torch.Tensor, return_reg_log: bool = False
470
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
471
+ if self.max_batch_size is None:
472
+ z = self.encoder(x)
473
+ z = self.quant_conv(z)
474
+ else:
475
+ N = x.shape[0]
476
+ bs = self.max_batch_size
477
+ n_batches = int(math.ceil(N / bs))
478
+ z = list()
479
+ for i_batch in range(n_batches):
480
+ z_batch = self.encoder(x[i_batch * bs : (i_batch + 1) * bs])
481
+ z_batch = self.quant_conv(z_batch)
482
+ z.append(z_batch)
483
+ z = torch.cat(z, 0)
484
+
485
+ z, reg_log = self.regularization(z)
486
+ if return_reg_log:
487
+ return z, reg_log
488
+ return z
489
+
490
+ def decode(self, z: torch.Tensor, **decoder_kwargs) -> torch.Tensor:
491
+ if self.max_batch_size is None:
492
+ dec = self.post_quant_conv(z)
493
+ dec = self.decoder(dec, **decoder_kwargs)
494
+ else:
495
+ N = z.shape[0]
496
+ bs = self.max_batch_size
497
+ n_batches = int(math.ceil(N / bs))
498
+ dec = list()
499
+ for i_batch in range(n_batches):
500
+ dec_batch = self.post_quant_conv(z[i_batch * bs : (i_batch + 1) * bs])
501
+ dec_batch = self.decoder(dec_batch, **decoder_kwargs)
502
+ dec.append(dec_batch)
503
+ dec = torch.cat(dec, 0)
504
+
505
+ return dec
506
+
507
+
508
+ class AutoencoderKL(AutoencodingEngineLegacy):
509
+ def __init__(self, **kwargs):
510
+ if "lossconfig" in kwargs:
511
+ kwargs["loss_config"] = kwargs.pop("lossconfig")
512
+ super().__init__(
513
+ regularizer_config={
514
+ "target": (
515
+ "sgm.modules.autoencoding.regularizers"
516
+ ".DiagonalGaussianRegularizer"
517
+ )
518
+ },
519
+ **kwargs,
520
+ )
521
+
522
+
523
+ class AutoencoderLegacyVQ(AutoencodingEngineLegacy):
524
+ def __init__(
525
+ self,
526
+ embed_dim: int,
527
+ n_embed: int,
528
+ sane_index_shape: bool = False,
529
+ **kwargs,
530
+ ):
531
+ if "lossconfig" in kwargs:
532
+ logpy.warn(f"Parameter `lossconfig` is deprecated, use `loss_config`.")
533
+ kwargs["loss_config"] = kwargs.pop("lossconfig")
534
+ super().__init__(
535
+ regularizer_config={
536
+ "target": (
537
+ "sgm.modules.autoencoding.regularizers.quantize" ".VectorQuantizer"
538
+ ),
539
+ "params": {
540
+ "n_e": n_embed,
541
+ "e_dim": embed_dim,
542
+ "sane_index_shape": sane_index_shape,
543
+ },
544
+ },
545
+ **kwargs,
546
+ )
547
+
548
+
549
+ class IdentityFirstStage(AbstractAutoencoder):
550
+ def __init__(self, *args, **kwargs):
551
+ super().__init__(*args, **kwargs)
552
+
553
+ def get_input(self, x: Any) -> Any:
554
+ return x
555
+
556
+ def encode(self, x: Any, *args, **kwargs) -> Any:
557
+ return x
558
+
559
+ def decode(self, x: Any, *args, **kwargs) -> Any:
560
+ return x
561
+
562
+
563
+ class AEIntegerWrapper(nn.Module):
564
+ def __init__(
565
+ self,
566
+ model: nn.Module,
567
+ shape: Union[None, Tuple[int, int], List[int]] = (16, 16),
568
+ regularization_key: str = "regularization",
569
+ encoder_kwargs: Optional[Dict[str, Any]] = None,
570
+ ):
571
+ super().__init__()
572
+ self.model = model
573
+ assert hasattr(model, "encode") and hasattr(
574
+ model, "decode"
575
+ ), "Need AE interface"
576
+ self.regularization = get_nested_attribute(model, regularization_key)
577
+ self.shape = shape
578
+ self.encoder_kwargs = default(encoder_kwargs, {"return_reg_log": True})
579
+
580
+ def encode(self, x) -> torch.Tensor:
581
+ assert (
582
+ not self.training
583
+ ), f"{self.__class__.__name__} only supports inference currently"
584
+ _, log = self.model.encode(x, **self.encoder_kwargs)
585
+ assert isinstance(log, dict)
586
+ inds = log["min_encoding_indices"]
587
+ return rearrange(inds, "b ... -> b (...)")
588
+
589
+ def decode(
590
+ self, inds: torch.Tensor, shape: Union[None, tuple, list] = None
591
+ ) -> torch.Tensor:
592
+ # expect inds shape (b, s) with s = h*w
593
+ shape = default(shape, self.shape) # Optional[(h, w)]
594
+ if shape is not None:
595
+ assert len(shape) == 2, f"Unhandled shape {shape}"
596
+ inds = rearrange(inds, "b (h w) -> b h w", h=shape[0], w=shape[1])
597
+ h = self.regularization.get_codebook_entry(inds) # (b, h, w, c)
598
+ h = rearrange(h, "b h w c -> b c h w")
599
+ return self.model.decode(h)
600
+
601
+
602
+ class AutoencoderKLModeOnly(AutoencodingEngineLegacy):
603
+ def __init__(self, **kwargs):
604
+ if "lossconfig" in kwargs:
605
+ kwargs["loss_config"] = kwargs.pop("lossconfig")
606
+ super().__init__(
607
+ regularizer_config={
608
+ "target": (
609
+ "sgm.modules.autoencoding.regularizers"
610
+ ".DiagonalGaussianRegularizer"
611
+ ),
612
+ "params": {"sample": False},
613
+ },
614
+ **kwargs,
615
+ )
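
The chunked encode path above (used by AutoencodingEngineLegacy when `max_batch_size` is set) is easiest to see in isolation. Below is a minimal, runnable sketch of the same batching pattern; the `nn.Conv2d` encoder, the helper name `encode_in_chunks`, and the tensor sizes are illustrative stand-ins for this demo, not code from the commit.

import math
import torch
import torch.nn as nn

def encode_in_chunks(encoder: nn.Module, x: torch.Tensor, max_batch_size: int) -> torch.Tensor:
    # Split the batch into ceil(N / max_batch_size) chunks so large batches fit in memory,
    # then concatenate the per-chunk outputs along the batch dimension.
    n_batches = int(math.ceil(x.shape[0] / max_batch_size))
    outs = []
    for i in range(n_batches):
        outs.append(encoder(x[i * max_batch_size : (i + 1) * max_batch_size]))
    return torch.cat(outs, dim=0)

if __name__ == "__main__":
    encoder = nn.Conv2d(3, 8, kernel_size=3, padding=1)  # stand-in for the real Encoder
    x = torch.randn(10, 3, 64, 64)
    z = encode_in_chunks(encoder, x, max_batch_size=4)   # processed as chunks of 4 + 4 + 2
    print(z.shape)                                        # torch.Size([10, 8, 64, 64])
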
sgm/models/diffusion.py ADDED
@@ -0,0 +1,358 @@
1
+ import math
2
+ from contextlib import contextmanager
3
+ from typing import Any, Dict, List, Optional, Tuple, Union
4
+
5
+ import pytorch_lightning as pl
6
+ import torch
7
+ from omegaconf import ListConfig, OmegaConf
8
+ from safetensors.torch import load_file as load_safetensors
9
+ from torch.optim.lr_scheduler import LambdaLR
10
+ from einops import rearrange
11
+
12
+ from ..modules import UNCONDITIONAL_CONFIG
13
+ from ..modules.autoencoding.temporal_ae import VideoDecoder
14
+ from ..modules.diffusionmodules.wrappers import OPENAIUNETWRAPPER
15
+ from ..modules.ema import LitEma
16
+ from ..util import (
17
+ default,
18
+ disabled_train,
19
+ get_obj_from_str,
20
+ instantiate_from_config,
21
+ log_txt_as_img,
22
+ )
23
+
24
+
25
+ class DiffusionEngine(pl.LightningModule):
26
+ def __init__(
27
+ self,
28
+ network_config,
29
+ denoiser_config,
30
+ first_stage_config,
31
+ conditioner_config: Union[None, Dict, ListConfig, OmegaConf] = None,
32
+ sampler_config: Union[None, Dict, ListConfig, OmegaConf] = None,
33
+ optimizer_config: Union[None, Dict, ListConfig, OmegaConf] = None,
34
+ scheduler_config: Union[None, Dict, ListConfig, OmegaConf] = None,
35
+ loss_fn_config: Union[None, Dict, ListConfig, OmegaConf] = None,
36
+ network_wrapper: Union[None, str] = None,
37
+ ckpt_path: Union[None, str] = None,
38
+ use_ema: bool = False,
39
+ ema_decay_rate: float = 0.9999,
40
+ scale_factor: float = 1.0,
41
+ disable_first_stage_autocast=False,
42
+ input_key: str = "jpg",
43
+ log_keys: Union[List, None] = None,
44
+ no_cond_log: bool = False,
45
+ compile_model: bool = False,
46
+ en_and_decode_n_samples_a_time: Optional[int] = None,
47
+ ):
48
+ super().__init__()
49
+ self.log_keys = log_keys
50
+ self.input_key = input_key
51
+ self.optimizer_config = default(
52
+ optimizer_config, {"target": "torch.optim.AdamW"}
53
+ )
54
+ model = instantiate_from_config(network_config)
55
+ self.model = get_obj_from_str(default(network_wrapper, OPENAIUNETWRAPPER))(
56
+ model, compile_model=compile_model
57
+ )
58
+
59
+ self.denoiser = instantiate_from_config(denoiser_config)
60
+ self.sampler = (
61
+ instantiate_from_config(sampler_config)
62
+ if sampler_config is not None
63
+ else None
64
+ )
65
+ self.conditioner = instantiate_from_config(
66
+ default(conditioner_config, UNCONDITIONAL_CONFIG)
67
+ )
68
+ self.scheduler_config = scheduler_config
69
+ self._init_first_stage(first_stage_config)
70
+
71
+ self.loss_fn = (
72
+ instantiate_from_config(loss_fn_config)
73
+ if loss_fn_config is not None
74
+ else None
75
+ )
76
+
77
+ self.use_ema = use_ema
78
+ if self.use_ema:
79
+ self.model_ema = LitEma(self.model, decay=ema_decay_rate)
80
+ print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
81
+
82
+ self.scale_factor = scale_factor
83
+ self.disable_first_stage_autocast = disable_first_stage_autocast
84
+ self.no_cond_log = no_cond_log
85
+
86
+ if ckpt_path is not None:
87
+ self.init_from_ckpt(ckpt_path)
88
+
89
+ self.en_and_decode_n_samples_a_time = en_and_decode_n_samples_a_time
90
+
91
+ def init_from_ckpt(
92
+ self,
93
+ path: str,
94
+ ) -> None:
95
+ if path.endswith("ckpt"):
96
+ sd = torch.load(path, map_location="cpu")["state_dict"]
97
+ elif path.endswith("safetensors"):
98
+ sd = load_safetensors(path)
99
+ else:
100
+ raise NotImplementedError
101
+
102
+ missing, unexpected = self.load_state_dict(sd, strict=False)
103
+ print(
104
+ f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
105
+ )
106
+ if len(missing) > 0:
107
+ print(f"Missing Keys: {missing}")
108
+ if len(unexpected) > 0:
109
+ print(f"Unexpected Keys: {unexpected}")
110
+
111
+ def _init_first_stage(self, config):
112
+ model = instantiate_from_config(config).eval()
113
+ model.train = disabled_train
114
+ for param in model.parameters():
115
+ param.requires_grad = False
116
+ self.first_stage_model = model
117
+
118
+ def get_input(self, batch):
119
+ # assuming unified data format, dataloader returns a dict.
120
+ # image tensors should be scaled to -1 ... 1 and in bchw format
121
+ return batch[self.input_key]
122
+
123
+ @torch.no_grad()
124
+ def decode_first_stage(self, z):
125
+ z = 1.0 / self.scale_factor * z
126
+ n_samples = default(self.en_and_decode_n_samples_a_time, z.shape[0])
127
+
128
+ n_rounds = math.ceil(z.shape[0] / n_samples)
129
+ all_out = []
130
+ with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
131
+ for n in range(n_rounds):
132
+ if isinstance(self.first_stage_model.decoder, VideoDecoder):
133
+ kwargs = {"timesteps": len(z[n * n_samples : (n + 1) * n_samples])}
134
+ else:
135
+ kwargs = {}
136
+ out = self.first_stage_model.decode(
137
+ z[n * n_samples : (n + 1) * n_samples], **kwargs
138
+ )
139
+ all_out.append(out)
140
+ out = torch.cat(all_out, dim=0)
141
+ return out
142
+
143
+ @torch.no_grad()
144
+ def encode_first_stage(self, x):
145
+ bs = x.shape[0]
146
+ is_video_input = False
147
+ if x.dim() == 5:
148
+ is_video_input = True
149
+ # for video diffusion
150
+ x = rearrange(x, "b t c h w -> (b t) c h w")
151
+ n_samples = default(self.en_and_decode_n_samples_a_time, x.shape[0])
152
+ n_rounds = math.ceil(x.shape[0] / n_samples)
153
+ all_out = []
154
+ with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
155
+ for n in range(n_rounds):
156
+ out = self.first_stage_model.encode(
157
+ x[n * n_samples : (n + 1) * n_samples]
158
+ )
159
+ all_out.append(out)
160
+ z = torch.cat(all_out, dim=0)
161
+ z = self.scale_factor * z
162
+
163
+ if is_video_input:
164
+ z = rearrange(z, "(b t) c h w -> b t c h w", b=bs)
165
+
166
+ return z
167
+
168
+ def forward(self, x, batch):
169
+ loss = self.loss_fn(self.model, self.denoiser, self.conditioner, x, batch)
170
+ loss_mean = loss.mean()
171
+ loss_dict = {"loss": loss_mean}
172
+ return loss_mean, loss_dict
173
+
174
+ def shared_step(self, batch: Dict) -> Any:
175
+ x = self.get_input(batch)
176
+ # breakpoint()  # debugging leftover, kept disabled so training does not halt here
177
+ x = self.encode_first_stage(x)
178
+ batch["global_step"] = self.global_step
179
+ loss, loss_dict = self(x, batch)
180
+ return loss, loss_dict
181
+
182
+ def training_step(self, batch, batch_idx):
183
+ loss, loss_dict = self.shared_step(batch)
184
+
185
+ self.log_dict(
186
+ loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=False
187
+ )
188
+
189
+ self.log(
190
+ "global_step",
191
+ self.global_step,
192
+ prog_bar=True,
193
+ logger=True,
194
+ on_step=True,
195
+ on_epoch=False,
196
+ )
197
+
198
+ if self.scheduler_config is not None:
199
+ lr = self.optimizers().param_groups[0]["lr"]
200
+ self.log(
201
+ "lr_abs", lr, prog_bar=True, logger=True, on_step=True, on_epoch=False
202
+ )
203
+
204
+ return loss
205
+
206
+ def on_train_start(self, *args, **kwargs):
207
+ if self.sampler is None or self.loss_fn is None:
208
+ raise ValueError("Sampler and loss function need to be set for training.")
209
+
210
+ def on_train_batch_end(self, *args, **kwargs):
211
+ if self.use_ema:
212
+ self.model_ema(self.model)
213
+
214
+ @contextmanager
215
+ def ema_scope(self, context=None):
216
+ if self.use_ema:
217
+ self.model_ema.store(self.model.parameters())
218
+ self.model_ema.copy_to(self.model)
219
+ if context is not None:
220
+ print(f"{context}: Switched to EMA weights")
221
+ try:
222
+ yield None
223
+ finally:
224
+ if self.use_ema:
225
+ self.model_ema.restore(self.model.parameters())
226
+ if context is not None:
227
+ print(f"{context}: Restored training weights")
228
+
229
+ def instantiate_optimizer_from_config(self, params, lr, cfg):
230
+ return get_obj_from_str(cfg["target"])(
231
+ params, lr=lr, **cfg.get("params", dict())
232
+ )
233
+
234
+ def configure_optimizers(self):
235
+ lr = self.learning_rate
236
+ params = list(self.model.parameters())
237
+ for embedder in self.conditioner.embedders:
238
+ if embedder.is_trainable:
239
+ params = params + list(embedder.parameters())
240
+ opt = self.instantiate_optimizer_from_config(params, lr, self.optimizer_config)
241
+ if self.scheduler_config is not None:
242
+ scheduler = instantiate_from_config(self.scheduler_config)
243
+ print("Setting up LambdaLR scheduler...")
244
+ scheduler = [
245
+ {
246
+ "scheduler": LambdaLR(opt, lr_lambda=scheduler.schedule),
247
+ "interval": "step",
248
+ "frequency": 1,
249
+ }
250
+ ]
251
+ return [opt], scheduler
252
+ return opt
253
+
254
+ @torch.no_grad()
255
+ def sample(
256
+ self,
257
+ cond: Dict,
258
+ uc: Union[Dict, None] = None,
259
+ batch_size: int = 16,
260
+ shape: Union[None, Tuple, List] = None,
261
+ **kwargs,
262
+ ):
263
+ randn = torch.randn(batch_size, *shape).to(self.device)
264
+
265
+ denoiser = lambda input, sigma, c: self.denoiser(
266
+ self.model, input, sigma, c, **kwargs
267
+ )
268
+ samples = self.sampler(denoiser, randn, cond, uc=uc)
269
+ return samples
270
+
271
+ @torch.no_grad()
272
+ def log_conditionings(self, batch: Dict, n: int) -> Dict:
273
+ """
274
+ Defines heuristics to log different conditionings.
275
+ These can be lists of strings (text-to-image), tensors, ints, ...
276
+ """
277
+ image_h, image_w = batch[self.input_key].shape[2:]
278
+ log = dict()
279
+
280
+ for embedder in self.conditioner.embedders:
281
+ if (
282
+ (self.log_keys is None) or (embedder.input_key in self.log_keys)
283
+ ) and not self.no_cond_log:
284
+ x = batch[embedder.input_key][:n]
285
+ if isinstance(x, torch.Tensor):
286
+ if x.dim() == 1:
287
+ # class-conditional, convert integer to string
288
+ x = [str(x[i].item()) for i in range(x.shape[0])]
289
+ xc = log_txt_as_img((image_h, image_w), x, size=image_h // 4)
290
+ elif x.dim() == 2:
291
+ # size and crop cond and the like
292
+ x = [
293
+ "x".join([str(xx) for xx in x[i].tolist()])
294
+ for i in range(x.shape[0])
295
+ ]
296
+ xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20)
297
+ else:
298
+ raise NotImplementedError()
299
+ elif isinstance(x, (List, ListConfig)):
300
+ if isinstance(x[0], str):
301
+ # strings
302
+ xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20)
303
+ else:
304
+ raise NotImplementedError()
305
+ else:
306
+ raise NotImplementedError()
307
+ log[embedder.input_key] = xc
308
+ return log
309
+
310
+ @torch.no_grad()
311
+ def log_images(
312
+ self,
313
+ batch: Dict,
314
+ N: int = 8,
315
+ sample: bool = True,
316
+ ucg_keys: List[str] = None,
317
+ **kwargs,
318
+ ) -> Dict:
319
+ conditioner_input_keys = [e.input_key for e in self.conditioner.embedders]
320
+ if ucg_keys:
321
+ assert all(map(lambda x: x in conditioner_input_keys, ucg_keys)), (
322
+ "Each defined ucg key for sampling must be in the provided conditioner input keys,"
323
+ f"but we have {ucg_keys} vs. {conditioner_input_keys}"
324
+ )
325
+ else:
326
+ ucg_keys = conditioner_input_keys
327
+ log = dict()
328
+
329
+ x = self.get_input(batch)
330
+
331
+ c, uc = self.conditioner.get_unconditional_conditioning(
332
+ batch,
333
+ force_uc_zero_embeddings=ucg_keys
334
+ if len(self.conditioner.embedders) > 0
335
+ else [],
336
+ )
337
+
338
+ sampling_kwargs = {}
339
+
340
+ N = min(x.shape[0], N)
341
+ x = x.to(self.device)[:N]
342
+ log["inputs"] = x
343
+ z = self.encode_first_stage(x)
344
+ log["reconstructions"] = self.decode_first_stage(z)
345
+ log.update(self.log_conditionings(batch, N))
346
+
347
+ for k in c:
348
+ if isinstance(c[k], torch.Tensor):
349
+ c[k], uc[k] = map(lambda y: y[k][:N].to(self.device), (c, uc))
350
+
351
+ if sample:
352
+ with self.ema_scope("Plotting"):
353
+ samples = self.sample(
354
+ c, shape=z.shape[1:], uc=uc, batch_size=N, **sampling_kwargs
355
+ )
356
+ samples = self.decode_first_stage(samples)
357
+ log["samples"] = samples
358
+ return log
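
DiffusionEngine.encode_first_stage folds 5-D video batches into the batch axis before running the frame-wise first-stage encoder, scales the latents by `scale_factor`, and restores the time axis afterwards. A minimal sketch of that reshaping follows, assuming a stand-in `nn.Conv2d` encoder and an example scale factor (0.18215 is only a familiar default here, not a value fixed by this commit).

import torch
import torch.nn as nn
from einops import rearrange

def encode_video_batch(encoder: nn.Module, x: torch.Tensor, scale_factor: float) -> torch.Tensor:
    # x: (b, t, c, h, w) video clips in [-1, 1]; fold frames into the batch axis,
    # encode frame by frame, scale the latents, and restore the time axis.
    b = x.shape[0]
    x = rearrange(x, "b t c h w -> (b t) c h w")
    z = encoder(x) * scale_factor
    return rearrange(z, "(b t) c h w -> b t c h w", b=b)

if __name__ == "__main__":
    encoder = nn.Conv2d(3, 4, kernel_size=1)      # stand-in for the first-stage encoder
    clips = torch.randn(2, 5, 3, 32, 32)          # 2 clips of 5 frames each
    z = encode_video_batch(encoder, clips, scale_factor=0.18215)
    print(z.shape)                                 # torch.Size([2, 5, 4, 32, 32])
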
sgm/models/video3d_diffusion.py ADDED
@@ -0,0 +1,524 @@
1
+ import re
2
+ import math
3
+ from contextlib import contextmanager
4
+ from typing import Any, Dict, List, Optional, Tuple, Union
5
+
6
+ import pytorch_lightning as pl
7
+ from pytorch_lightning.loggers import WandbLogger
8
+ import torch
9
+ from omegaconf import ListConfig, OmegaConf
10
+ from safetensors.torch import load_file as load_safetensors
11
+ from torch.optim.lr_scheduler import LambdaLR
12
+ from torchvision.utils import make_grid
13
+ from einops import rearrange, repeat
14
+
15
+ from ..modules import UNCONDITIONAL_CONFIG
16
+ from ..modules.autoencoding.temporal_ae import VideoDecoder
17
+ from ..modules.diffusionmodules.wrappers import OPENAIUNETWRAPPER
18
+ from ..modules.ema import LitEma
19
+ from ..modules.encoders.modules import VideoPredictionEmbedderWithEncoder
20
+ from ..util import (
21
+ default,
22
+ disabled_train,
23
+ get_obj_from_str,
24
+ instantiate_from_config,
25
+ log_txt_as_img,
26
+ video_frames_as_grid,
27
+ )
28
+
29
+
30
+ def flatten_for_video(input):
31
+ return input.flatten()
32
+
33
+
34
+ class Video3DDiffusionEngine(pl.LightningModule):
35
+ def __init__(
36
+ self,
37
+ network_config,
38
+ denoiser_config,
39
+ first_stage_config,
40
+ conditioner_config: Union[None, Dict, ListConfig, OmegaConf] = None,
41
+ sampler_config: Union[None, Dict, ListConfig, OmegaConf] = None,
42
+ optimizer_config: Union[None, Dict, ListConfig, OmegaConf] = None,
43
+ scheduler_config: Union[None, Dict, ListConfig, OmegaConf] = None,
44
+ loss_fn_config: Union[None, Dict, ListConfig, OmegaConf] = None,
45
+ network_wrapper: Union[None, str] = None,
46
+ ckpt_path: Union[None, str] = None,
47
+ use_ema: bool = False,
48
+ ema_decay_rate: float = 0.9999,
49
+ scale_factor: float = 1.0,
50
+ disable_first_stage_autocast=False,
51
+ input_key: str = "frames", # for video inputs
52
+ log_keys: Union[List, None] = None,
53
+ no_cond_log: bool = False,
54
+ compile_model: bool = False,
55
+ en_and_decode_n_samples_a_time: Optional[int] = None,
56
+ ):
57
+ super().__init__()
58
+ self.log_keys = log_keys
59
+ self.input_key = input_key
60
+ self.optimizer_config = default(
61
+ optimizer_config, {"target": "torch.optim.AdamW"}
62
+ )
63
+ model = instantiate_from_config(network_config)
64
+ self.model = get_obj_from_str(default(network_wrapper, OPENAIUNETWRAPPER))(
65
+ model, compile_model=compile_model
66
+ )
67
+
68
+ self.denoiser = instantiate_from_config(denoiser_config)
69
+ self.sampler = (
70
+ instantiate_from_config(sampler_config)
71
+ if sampler_config is not None
72
+ else None
73
+ )
74
+ self.conditioner = instantiate_from_config(
75
+ default(conditioner_config, UNCONDITIONAL_CONFIG)
76
+ )
77
+ self.scheduler_config = scheduler_config
78
+ self._init_first_stage(first_stage_config)
79
+
80
+ self.loss_fn = (
81
+ instantiate_from_config(loss_fn_config)
82
+ if loss_fn_config is not None
83
+ else None
84
+ )
85
+
86
+ self.use_ema = use_ema
87
+ if self.use_ema:
88
+ self.model_ema = LitEma(self.model, decay=ema_decay_rate)
89
+ print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
90
+
91
+ self.scale_factor = scale_factor
92
+ self.disable_first_stage_autocast = disable_first_stage_autocast
93
+ self.no_cond_log = no_cond_log
94
+
95
+ if ckpt_path is not None:
96
+ self.init_from_ckpt(ckpt_path)
97
+
98
+ self.en_and_decode_n_samples_a_time = en_and_decode_n_samples_a_time
99
+
100
+ def _load_last_embedder(self, original_state_dict):
101
+ original_module_name = "conditioner.embedders.3"
102
+ state_dict = dict()
103
+ for k, v in original_state_dict.items():
104
+ m = re.match(rf"^{original_module_name}\.(.*)$", k)
105
+ if m is None:
106
+ continue
107
+ state_dict[m.group(1)] = v
108
+
109
+ idx = -1
110
+ for i in range(len(self.conditioner.embedders)):
111
+ if isinstance(
112
+ self.conditioner.embedders[i], VideoPredictionEmbedderWithEncoder
113
+ ):
114
+ idx = i
115
+
116
+ print(f"Embedder [{idx}] is the frame encoder, make sure this is expected")
117
+
118
+ self.conditioner.embedders[idx].load_state_dict(state_dict)
119
+
120
+ def init_from_ckpt(
121
+ self,
122
+ path: str,
123
+ ) -> None:
124
+ if path.endswith("ckpt"):
125
+ sd = torch.load(path, map_location="cpu")["state_dict"]
126
+ elif path.endswith("safetensors"):
127
+ sd = load_safetensors(path)
128
+ else:
129
+ raise NotImplementedError
130
+
131
+ self_sd = self.state_dict()
132
+ input_keys = [
133
+ "model.diffusion_model.input_blocks.0.0.weight",
134
+ "model_ema.diffusion_modelinput_blocks00weight",
135
+ ]
136
+ for input_key in input_keys:
137
+ if input_key not in sd or input_key not in self_sd:
138
+ continue
139
+
140
+ input_weight = self_sd[input_key]
141
+
142
+ if input_weight.shape != sd[input_key].shape:
143
+ print("Manual init: {}".format(input_key))
144
+ input_weight.zero_()
145
+ input_weight[:, :8, :, :].copy_(sd[input_key])
146
+
147
+ deleted_keys = []
148
+ for k, v in self.state_dict().items():
149
+ # resolve shape mismatch
150
+ if k in sd:
151
+ if v.shape != sd[k].shape:
152
+ del sd[k]
153
+ deleted_keys.append(k)
154
+
155
+ if len(deleted_keys) > 0:
156
+ print(f"Deleted Keys: {deleted_keys}")
157
+
158
+ missing, unexpected = self.load_state_dict(sd, strict=False)
159
+ print(
160
+ f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
161
+ )
162
+ if len(missing) > 0:
163
+ print(f"Missing Keys: {missing}")
164
+ if len(unexpected) > 0:
165
+ print(f"Unexpected Keys: {unexpected}")
166
+ if len(deleted_keys) > 0:
167
+ print(f"Deleted Keys: {deleted_keys}")
168
+
169
+ if len(missing) > 0 or len(unexpected) > 0:
170
+ # means we are loading from a checkpoint that has the old embedder (motion bucket id and fps id)
171
+ print("Modified embedder to support 3d spiral video inputs")
172
+ try:
173
+ self._load_last_embedder(sd)
174
+ except Exception:
175
+ print("Failed to load last embedder, make sure this is expected")
176
+
177
+ def _init_first_stage(self, config):
178
+ model = instantiate_from_config(config).eval()
179
+ model.train = disabled_train
180
+ for param in model.parameters():
181
+ param.requires_grad = False
182
+ self.first_stage_model = model
183
+
184
+ def get_input(self, batch):
185
+ # assuming unified data format, dataloader returns a dict.
186
+ # image tensors should be scaled to -1 ... 1 and in bchw format
187
+ return batch[self.input_key]
188
+
189
+ @torch.no_grad()
190
+ def decode_first_stage(self, z):
191
+ z = 1.0 / self.scale_factor * z
192
+ is_video_input = False
193
+ bs = z.shape[0]
194
+ if z.dim() == 5:
195
+ is_video_input = True
196
+ # for video diffusion
197
+ z = rearrange(z, "b t c h w -> (b t) c h w")
198
+ n_samples = default(self.en_and_decode_n_samples_a_time, z.shape[0])
199
+
200
+ n_rounds = math.ceil(z.shape[0] / n_samples)
201
+ all_out = []
202
+ with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
203
+ for n in range(n_rounds):
204
+ if isinstance(self.first_stage_model.decoder, VideoDecoder):
205
+ kwargs = {"timesteps": len(z[n * n_samples : (n + 1) * n_samples])}
206
+ else:
207
+ kwargs = {}
208
+ out = self.first_stage_model.decode(
209
+ z[n * n_samples : (n + 1) * n_samples], **kwargs
210
+ )
211
+ all_out.append(out)
212
+ out = torch.cat(all_out, dim=0)
213
+
214
+ if is_video_input:
215
+ out = rearrange(out, "(b t) c h w -> b t c h w", b=bs)
216
+
217
+ return out
218
+
219
+ @torch.no_grad()
220
+ def encode_first_stage(self, x):
221
+ if self.input_key == "latents":
222
+ return x
223
+
224
+ bs = x.shape[0]
225
+ is_video_input = False
226
+ if x.dim() == 5:
227
+ is_video_input = True
228
+ # for video diffusion
229
+ x = rearrange(x, "b t c h w -> (b t) c h w")
230
+ n_samples = default(self.en_and_decode_n_samples_a_time, x.shape[0])
231
+ n_rounds = math.ceil(x.shape[0] / n_samples)
232
+ all_out = []
233
+ with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
234
+ for n in range(n_rounds):
235
+ out = self.first_stage_model.encode(
236
+ x[n * n_samples : (n + 1) * n_samples]
237
+ )
238
+ all_out.append(out)
239
+ z = torch.cat(all_out, dim=0)
240
+ z = self.scale_factor * z
241
+
242
+ # if is_video_input:
243
+ # z = rearrange(z, "(b t) c h w -> b t c h w", b=bs)
244
+
245
+ return z
246
+
247
+ def forward(self, x, batch):
248
+ loss, model_output = self.loss_fn(
249
+ self.model,
250
+ self.denoiser,
251
+ self.conditioner,
252
+ x,
253
+ batch,
254
+ return_model_output=True,
255
+ )
256
+ loss_mean = loss.mean()
257
+ loss_dict = {"loss": loss_mean, "model_output": model_output}
258
+ return loss_mean, loss_dict
259
+
260
+ def shared_step(self, batch: Dict) -> Any:
261
+ # TODO: move this shit to collate_fn in dataloader
262
+ # if "fps_id" in batch:
263
+ # batch["fps_id"] = flatten_for_video(batch["fps_id"])
264
+ # if "motion_bucket_id" in batch:
265
+ # batch["motion_bucket_id"] = flatten_for_video(batch["motion_bucket_id"])
266
+ # if "cond_aug" in batch:
267
+ # batch["cond_aug"] = flatten_for_video(batch["cond_aug"])
268
+ x = self.get_input(batch)
269
+ x = self.encode_first_stage(x)
270
+ # ## debug
271
+ # x_recon = self.decode_first_stage(x)
272
+ # video_frames_as_grid((batch["frames"][0] + 1.0) / 2.0, "./tmp/origin.jpg")
273
+ # video_frames_as_grid((x_recon[0] + 1.0) / 2.0, "./tmp/recon.jpg")
274
+ # ## debug
275
+ batch["global_step"] = self.global_step
276
+ loss, loss_dict = self(x, batch)
277
+ return loss, loss_dict
278
+
279
+ def training_step(self, batch, batch_idx):
280
+ loss, loss_dict = self.shared_step(batch)
281
+
282
+ with torch.no_grad():
283
+ if "model_output" in loss_dict:
284
+ if batch_idx % 100 == 0:
285
+ if isinstance(self.logger, WandbLogger):
286
+ model_output = loss_dict["model_output"].detach()[
287
+ : batch["num_video_frames"]
288
+ ]
289
+ recons = (
290
+ (self.decode_first_stage(model_output) + 1.0) / 2.0
291
+ ).clamp(0.0, 1.0)
292
+ recon_grid = make_grid(recons, nrow=4)
293
+ self.logger.log_image(
294
+ key=f"train/model_output_recon",
295
+ images=[recon_grid],
296
+ step=self.global_step,
297
+ )
298
+ del loss_dict["model_output"]
299
+
300
+ if torch.isnan(loss).any():
301
+ print("Nan detected")
302
+ loss = None
303
+
304
+ self.log_dict(
305
+ loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=False
306
+ )
307
+
308
+ self.log(
309
+ "global_step",
310
+ self.global_step,
311
+ prog_bar=True,
312
+ logger=True,
313
+ on_step=True,
314
+ on_epoch=False,
315
+ )
316
+
317
+ if self.scheduler_config is not None:
318
+ lr = self.optimizers().param_groups[0]["lr"]
319
+ self.log(
320
+ "lr_abs", lr, prog_bar=True, logger=True, on_step=True, on_epoch=False
321
+ )
322
+
323
+ return loss
324
+
325
+ def on_train_start(self, *args, **kwargs):
326
+ if self.sampler is None or self.loss_fn is None:
327
+ raise ValueError("Sampler and loss function need to be set for training.")
328
+
329
+ def on_train_batch_end(self, *args, **kwargs):
330
+ if self.use_ema:
331
+ self.model_ema(self.model)
332
+
333
+ @contextmanager
334
+ def ema_scope(self, context=None):
335
+ if self.use_ema:
336
+ self.model_ema.store(self.model.parameters())
337
+ self.model_ema.copy_to(self.model)
338
+ if context is not None:
339
+ print(f"{context}: Switched to EMA weights")
340
+ try:
341
+ yield None
342
+ finally:
343
+ if self.use_ema:
344
+ self.model_ema.restore(self.model.parameters())
345
+ if context is not None:
346
+ print(f"{context}: Restored training weights")
347
+
348
+ def instantiate_optimizer_from_config(self, params, lr, cfg):
349
+ return get_obj_from_str(cfg["target"])(
350
+ params, lr=lr, **cfg.get("params", dict())
351
+ )
352
+
353
+ def configure_optimizers(self):
354
+ lr = self.learning_rate
355
+ params = list(self.model.parameters())
356
+ for embedder in self.conditioner.embedders:
357
+ if embedder.is_trainable:
358
+ params = params + list(embedder.parameters())
359
+ opt = self.instantiate_optimizer_from_config(params, lr, self.optimizer_config)
360
+ if self.scheduler_config is not None:
361
+ scheduler = instantiate_from_config(self.scheduler_config)
362
+ print("Setting up LambdaLR scheduler...")
363
+ scheduler = [
364
+ {
365
+ "scheduler": LambdaLR(opt, lr_lambda=scheduler.schedule),
366
+ "interval": "step",
367
+ "frequency": 1,
368
+ }
369
+ ]
370
+ return [opt], scheduler
371
+ return opt
372
+
373
+ @torch.no_grad()
374
+ def sample(
375
+ self,
376
+ cond: Dict,
377
+ uc: Union[Dict, None] = None,
378
+ batch_size: int = 16,
379
+ shape: Union[None, Tuple, List] = None,
380
+ **kwargs,
381
+ ):
382
+ randn = torch.randn(batch_size, *shape).to(self.device)
383
+
384
+ denoiser = lambda input, sigma, c: self.denoiser(
385
+ self.model, input, sigma, c, **kwargs
386
+ )
387
+ samples = self.sampler(denoiser, randn, cond, uc=uc)
388
+ return samples
389
+
390
+ @torch.no_grad()
391
+ def log_conditionings(self, batch: Dict, n: int) -> Dict:
392
+ """
393
+ Defines heuristics to log different conditionings.
394
+ These can be lists of strings (text-to-image), tensors, ints, ...
395
+ """
396
+ image_h, image_w = batch[self.input_key].shape[-2:]
397
+ log = dict()
398
+
399
+ for embedder in self.conditioner.embedders:
400
+ if (
401
+ (self.log_keys is None) or (embedder.input_key in self.log_keys)
402
+ ) and not self.no_cond_log:
403
+ x = batch[embedder.input_key][:n]
404
+ if isinstance(x, torch.Tensor):
405
+ if x.dim() == 1:
406
+ # class-conditional, convert integer to string
407
+ x = [str(x[i].item()) for i in range(x.shape[0])]
408
+ xc = log_txt_as_img((image_h, image_w), x, size=image_h // 4)
409
+ elif x.dim() == 2:
410
+ # size and crop cond and the like
411
+ x = [
412
+ "x".join([str(xx) for xx in x[i].tolist()])
413
+ for i in range(x.shape[0])
414
+ ]
415
+ xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20)
416
+ elif x.dim() == 4:
417
+ # image
418
+ xc = x
419
+ else:
420
+ raise NotImplementedError()
421
+ elif isinstance(x, (List, ListConfig)):
422
+ if isinstance(x[0], str):
423
+ # strings
424
+ xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20)
425
+ else:
426
+ raise NotImplementedError()
427
+ else:
428
+ raise NotImplementedError()
429
+ log[embedder.input_key] = xc
430
+
431
+ return log
432
+
433
+ # for video diffusions will be logging frames of a video
434
+ @torch.no_grad()
435
+ def log_images(
436
+ self,
437
+ batch: Dict,
438
+ N: int = 1,
439
+ sample: bool = True,
440
+ ucg_keys: List[str] = None,
441
+ **kwargs,
442
+ ) -> Dict:
443
+ # # debug
444
+ # return {}
445
+ # # debug
446
+ assert "num_video_frames" in batch, "num_video_frames must be in batch"
447
+ num_video_frames = batch["num_video_frames"]
448
+ conditioner_input_keys = [e.input_key for e in self.conditioner.embedders]
449
+ conditioner_input_keys = []
450
+ for e in self.conditioner.embedders:
451
+ if e.input_key is not None:
452
+ conditioner_input_keys.append(e.input_key)
453
+ else:
454
+ conditioner_input_keys.extend(e.input_keys)
455
+ if ucg_keys:
456
+ assert all(map(lambda x: x in conditioner_input_keys, ucg_keys)), (
457
+ "Each defined ucg key for sampling must be in the provided conditioner input keys,"
458
+ f"but we have {ucg_keys} vs. {conditioner_input_keys}"
459
+ )
460
+ else:
461
+ ucg_keys = conditioner_input_keys
462
+ log = dict()
463
+
464
+ x = self.get_input(batch)
465
+
466
+ c, uc = self.conditioner.get_unconditional_conditioning(
467
+ batch,
468
+ force_uc_zero_embeddings=ucg_keys
469
+ if len(self.conditioner.embedders) > 0
470
+ else [],
471
+ )
472
+
473
+ sampling_kwargs = {"num_video_frames": num_video_frames}
474
+ n = min(x.shape[0] // num_video_frames, N)
475
+ sampling_kwargs["image_only_indicator"] = torch.cat(
476
+ [batch["image_only_indicator"][:n]] * 2
477
+ )
478
+
479
+ N = min(x.shape[0] // num_video_frames, N) * num_video_frames
480
+ x = x.to(self.device)[:N]
481
+ # log["inputs"] = rearrange(x, "(b t) c h w -> b c h (t w)", t=num_video_frames)
482
+ log["inputs"] = x
483
+ z = self.encode_first_stage(x)
484
+ recon = self.decode_first_stage(z)
485
+ # log["reconstructions"] = rearrange(
486
+ # recon, "(b t) c h w -> b c h (t w)", t=num_video_frames
487
+ # )
488
+ log["reconstructions"] = recon
489
+ log.update(self.log_conditionings(batch, N))
490
+ log["pixelnerf_rgb"] = c["rgb"]
491
+
492
+ for k in ["crossattn", "concat", "vector"]:
493
+ if k in c:
494
+ c[k] = c[k][:N]
495
+ uc[k] = uc[k][:N]
496
+
497
+ # for k in c:
498
+ # if isinstance(c[k], torch.Tensor):
499
+ # if k == "vector":
500
+ # end = N
501
+ # else:
502
+ # end = n
503
+ # c[k], uc[k] = map(lambda y: y[k][:end].to(self.device), (c, uc))
504
+
505
+ # # for k in c:
506
+ # # print(c[k].shape)
507
+
508
+ # breakpoint()
509
+ # for k in ["crossattn", "concat"]:
510
+ # c[k] = repeat(c[k], "b ... -> b t ...", t=num_video_frames)
511
+ # c[k] = rearrange(c[k], "b t ... -> (b t) ...", t=num_video_frames)
512
+ # uc[k] = repeat(uc[k], "b ... -> b t ...", t=num_video_frames)
513
+ # uc[k] = rearrange(uc[k], "b t ... -> (b t) ...", t=num_video_frames)
514
+
515
+ # for k in c:
516
+ # print(c[k].shape)
517
+ if sample:
518
+ with self.ema_scope("Plotting"):
519
+ samples = self.sample(
520
+ c, shape=z.shape[1:], uc=uc, batch_size=N, **sampling_kwargs
521
+ )
522
+ samples = self.decode_first_stage(samples)
523
+ log["samples"] = samples
524
+ return log
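
The `_load_last_embedder` helper above recovers a single conditioner embedder's weights from a full checkpoint by matching keys under the `conditioner.embedders.3` prefix and stripping that prefix before calling `load_state_dict`. A small, self-contained sketch of the same key remapping, with a toy state dict as an assumed example (`re.escape` is added here for robustness and is not in the commit's version):

import re

def extract_submodule_state_dict(full_sd: dict, module_name: str) -> dict:
    # Keep only keys that live under `module_name` and strip that prefix so the
    # result can be passed to submodule.load_state_dict(...).
    out = {}
    for key, value in full_sd.items():
        m = re.match(rf"^{re.escape(module_name)}\.(.*)$", key)
        if m is not None:
            out[m.group(1)] = value
    return out

if __name__ == "__main__":
    toy_sd = {
        "model.diffusion_model.out.0.weight": 0,
        "conditioner.embedders.3.encoder.weight": 1,
        "conditioner.embedders.3.encoder.bias": 2,
    }
    print(extract_submodule_state_dict(toy_sd, "conditioner.embedders.3"))
    # -> {'encoder.weight': 1, 'encoder.bias': 2}
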
sgm/models/video_diffusion.py ADDED
@@ -0,0 +1,503 @@
1
+ import re
2
+ import math
3
+ from contextlib import contextmanager
4
+ from typing import Any, Dict, List, Optional, Tuple, Union
5
+
6
+ import pytorch_lightning as pl
7
+ from pytorch_lightning.loggers import WandbLogger
8
+ import torch
9
+ from omegaconf import ListConfig, OmegaConf
10
+ from safetensors.torch import load_file as load_safetensors
11
+ from torch.optim.lr_scheduler import LambdaLR
12
+ from torchvision.utils import make_grid
13
+ from einops import rearrange, repeat
14
+
15
+ from ..modules import UNCONDITIONAL_CONFIG
16
+ from ..modules.autoencoding.temporal_ae import VideoDecoder
17
+ from ..modules.diffusionmodules.wrappers import OPENAIUNETWRAPPER
18
+ from ..modules.ema import LitEma
19
+ from ..modules.encoders.modules import VideoPredictionEmbedderWithEncoder
20
+ from ..util import (
21
+ default,
22
+ disabled_train,
23
+ get_obj_from_str,
24
+ instantiate_from_config,
25
+ log_txt_as_img,
26
+ video_frames_as_grid,
27
+ )
28
+
29
+
30
+ def flatten_for_video(input):
31
+ return input.flatten()
32
+
33
+
34
+ class DiffusionEngine(pl.LightningModule):
35
+ def __init__(
36
+ self,
37
+ network_config,
38
+ denoiser_config,
39
+ first_stage_config,
40
+ conditioner_config: Union[None, Dict, ListConfig, OmegaConf] = None,
41
+ sampler_config: Union[None, Dict, ListConfig, OmegaConf] = None,
42
+ optimizer_config: Union[None, Dict, ListConfig, OmegaConf] = None,
43
+ scheduler_config: Union[None, Dict, ListConfig, OmegaConf] = None,
44
+ loss_fn_config: Union[None, Dict, ListConfig, OmegaConf] = None,
45
+ network_wrapper: Union[None, str] = None,
46
+ ckpt_path: Union[None, str] = None,
47
+ use_ema: bool = False,
48
+ ema_decay_rate: float = 0.9999,
49
+ scale_factor: float = 1.0,
50
+ disable_first_stage_autocast=False,
51
+ input_key: str = "frames", # for video inputs
52
+ log_keys: Union[List, None] = None,
53
+ no_cond_log: bool = False,
54
+ compile_model: bool = False,
55
+ en_and_decode_n_samples_a_time: Optional[int] = None,
56
+ load_last_embedder: bool = False,
57
+ from_scratch: bool = False,
58
+ ):
59
+ super().__init__()
60
+ self.log_keys = log_keys
61
+ self.input_key = input_key
62
+ self.optimizer_config = default(
63
+ optimizer_config, {"target": "torch.optim.AdamW"}
64
+ )
65
+ model = instantiate_from_config(network_config)
66
+ self.model = get_obj_from_str(default(network_wrapper, OPENAIUNETWRAPPER))(
67
+ model, compile_model=compile_model
68
+ )
69
+
70
+ self.denoiser = instantiate_from_config(denoiser_config)
71
+ self.sampler = (
72
+ instantiate_from_config(sampler_config)
73
+ if sampler_config is not None
74
+ else None
75
+ )
76
+ self.conditioner = instantiate_from_config(
77
+ default(conditioner_config, UNCONDITIONAL_CONFIG)
78
+ )
79
+ self.scheduler_config = scheduler_config
80
+ self._init_first_stage(first_stage_config)
81
+
82
+ self.loss_fn = (
83
+ instantiate_from_config(loss_fn_config)
84
+ if loss_fn_config is not None
85
+ else None
86
+ )
87
+
88
+ self.use_ema = use_ema
89
+ if self.use_ema:
90
+ self.model_ema = LitEma(self.model, decay=ema_decay_rate)
91
+ print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
92
+
93
+ self.scale_factor = scale_factor
94
+ self.disable_first_stage_autocast = disable_first_stage_autocast
95
+ self.no_cond_log = no_cond_log
96
+
97
+ self.load_last_embedder = load_last_embedder
98
+ if ckpt_path is not None:
99
+ self.init_from_ckpt(ckpt_path, from_scratch)
100
+
101
+ self.en_and_decode_n_samples_a_time = en_and_decode_n_samples_a_time
102
+
103
+ def _load_last_embedder(self, original_state_dict):
104
+ original_module_name = "conditioner.embedders.3"
105
+ state_dict = dict()
106
+ for k, v in original_state_dict.items():
107
+ m = re.match(rf"^{original_module_name}\.(.*)$", k)
108
+ if m is None:
109
+ continue
110
+ state_dict[m.group(1)] = v
111
+
112
+ idx = -1
113
+ for i in range(len(self.conditioner.embedders)):
114
+ if isinstance(
115
+ self.conditioner.embedders[i], VideoPredictionEmbedderWithEncoder
116
+ ):
117
+ idx = i
118
+
119
+ print(f"Embedder [{idx}] is the frame encoder, make sure this is expected")
120
+
121
+ self.conditioner.embedders[idx].load_state_dict(state_dict)
122
+
123
+ def init_from_ckpt(
124
+ self,
125
+ path: str,
126
+ from_scratch: bool = False,
127
+ ) -> None:
128
+ if path.endswith("ckpt"):
129
+ sd = torch.load(path, map_location="cpu")["state_dict"]
130
+ elif path.endswith("safetensors"):
131
+ sd = load_safetensors(path)
132
+ else:
133
+ raise NotImplementedError
134
+
135
+ deleted_keys = []
136
+ for k, v in self.state_dict().items():
137
+ # resolve shape mismatch
138
+ if k in sd:
139
+ if v.shape != sd[k].shape:
140
+ del sd[k]
141
+ deleted_keys.append(k)
142
+
143
+ if from_scratch:
144
+ new_sd = {}
145
+ for k in sd:
146
+ if "first_stage_model" in k:
147
+ new_sd[k] = sd[k]
148
+ sd = new_sd
149
+ print(sd.keys())
150
+
151
+ if len(deleted_keys) > 0:
152
+ print(f"Deleted Keys: {deleted_keys}")
153
+
154
+ missing, unexpected = self.load_state_dict(sd, strict=False)
155
+ print(
156
+ f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
157
+ )
158
+ if len(missing) > 0:
159
+ print(f"Missing Keys: {missing}")
160
+ if len(unexpected) > 0:
161
+ print(f"Unexpected Keys: {unexpected}")
162
+ if len(deleted_keys) > 0:
163
+ print(f"Deleted Keys: {deleted_keys}")
164
+
165
+ if (len(missing) > 0 or len(unexpected) > 0) and self.load_last_embedder:
166
+ # means we are loading from a checkpoint that has the old embedder (motion bucket id and fps id)
167
+ print("Modified embedder to support 3d spiral video inputs")
168
+ self._load_last_embedder(sd)
169
+
170
+ def _init_first_stage(self, config):
171
+ model = instantiate_from_config(config).eval()
172
+ model.train = disabled_train
173
+ for param in model.parameters():
174
+ param.requires_grad = False
175
+ self.first_stage_model = model
176
+
177
+ def get_input(self, batch):
178
+ # assuming unified data format, dataloader returns a dict.
179
+ # image tensors should be scaled to -1 ... 1 and in bchw format
180
+ return batch[self.input_key]
181
+
182
+ @torch.no_grad()
183
+ def decode_first_stage(self, z):
184
+ z = 1.0 / self.scale_factor * z
185
+ is_video_input = False
186
+ bs = z.shape[0]
187
+ if z.dim() == 5:
188
+ is_video_input = True
189
+ # for video diffusion
190
+ z = rearrange(z, "b t c h w -> (b t) c h w")
191
+ n_samples = default(self.en_and_decode_n_samples_a_time, z.shape[0])
192
+
193
+ n_rounds = math.ceil(z.shape[0] / n_samples)
194
+ all_out = []
195
+ with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
196
+ for n in range(n_rounds):
197
+ if isinstance(self.first_stage_model.decoder, VideoDecoder):
198
+ kwargs = {"timesteps": len(z[n * n_samples : (n + 1) * n_samples])}
199
+ else:
200
+ kwargs = {}
201
+ out = self.first_stage_model.decode(
202
+ z[n * n_samples : (n + 1) * n_samples], **kwargs
203
+ )
204
+ all_out.append(out)
205
+ out = torch.cat(all_out, dim=0)
206
+
207
+ if is_video_input:
208
+ out = rearrange(out, "(b t) c h w -> b t c h w", b=bs)
209
+
210
+ return out
211
+
212
+ @torch.no_grad()
213
+ def encode_first_stage(self, x):
214
+ if self.input_key == "latents":
215
+ return x * self.scale_factor
216
+
217
+ bs = x.shape[0]
218
+ is_video_input = False
219
+ if x.dim() == 5:
220
+ is_video_input = True
221
+ # for video diffusion
222
+ x = rearrange(x, "b t c h w -> (b t) c h w")
223
+ n_samples = default(self.en_and_decode_n_samples_a_time, x.shape[0])
224
+ n_rounds = math.ceil(x.shape[0] / n_samples)
225
+ all_out = []
226
+ with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
227
+ for n in range(n_rounds):
228
+ out = self.first_stage_model.encode(
229
+ x[n * n_samples : (n + 1) * n_samples]
230
+ )
231
+ all_out.append(out)
232
+ z = torch.cat(all_out, dim=0)
233
+ z = self.scale_factor * z
234
+
235
+ # if is_video_input:
236
+ # z = rearrange(z, "(b t) c h w -> b t c h w", b=bs)
237
+
238
+ return z
239
+
240
+ def forward(self, x, batch):
241
+ loss, model_output = self.loss_fn(
242
+ self.model,
243
+ self.denoiser,
244
+ self.conditioner,
245
+ x,
246
+ batch,
247
+ return_model_output=True,
248
+ )
249
+ loss_mean = loss.mean()
250
+ loss_dict = {"loss": loss_mean, "model_output": model_output}
251
+ return loss_mean, loss_dict
252
+
253
+ def shared_step(self, batch: Dict) -> Any:
254
+ # TODO: move this to collate_fn in the dataloader
255
+ # if "fps_id" in batch:
256
+ # batch["fps_id"] = flatten_for_video(batch["fps_id"])
257
+ # if "motion_bucket_id" in batch:
258
+ # batch["motion_bucket_id"] = flatten_for_video(batch["motion_bucket_id"])
259
+ # if "cond_aug" in batch:
260
+ # batch["cond_aug"] = flatten_for_video(batch["cond_aug"])
261
+ x = self.get_input(batch)
262
+ x = self.encode_first_stage(x)
263
+ # ## debug
264
+ # x_recon = self.decode_first_stage(x)
265
+ # video_frames_as_grid((batch["frames"][0] + 1.0) / 2.0, "./tmp/origin.jpg")
266
+ # video_frames_as_grid((x_recon[0] + 1.0) / 2.0, "./tmp/recon.jpg")
267
+ # ## debug
268
+ batch["global_step"] = self.global_step
269
+ # breakpoint()
270
+ loss, loss_dict = self(x, batch)
271
+ return loss, loss_dict
272
+
273
+ def training_step(self, batch, batch_idx):
274
+ loss, loss_dict = self.shared_step(batch)
275
+
276
+ with torch.no_grad():
277
+ if "model_output" in loss_dict:
278
+ if batch_idx % 100 == 0:
279
+ if isinstance(self.logger, WandbLogger):
280
+ model_output = loss_dict["model_output"].detach()[
281
+ : batch["num_video_frames"]
282
+ ]
283
+ recons = (
284
+ (self.decode_first_stage(model_output) + 1.0) / 2.0
285
+ ).clamp(0.0, 1.0)
286
+ recon_grid = make_grid(recons, nrow=4)
287
+ self.logger.log_image(
288
+ key=f"train/model_output_recon",
289
+ images=[recon_grid],
290
+ step=self.global_step,
291
+ )
292
+ del loss_dict["model_output"]
293
+
294
+ self.log_dict(
295
+ loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=False
296
+ )
297
+
298
+ self.log(
299
+ "global_step",
300
+ self.global_step,
301
+ prog_bar=True,
302
+ logger=True,
303
+ on_step=True,
304
+ on_epoch=False,
305
+ )
306
+
307
+ if self.scheduler_config is not None:
308
+ lr = self.optimizers().param_groups[0]["lr"]
309
+ self.log(
310
+ "lr_abs", lr, prog_bar=True, logger=True, on_step=True, on_epoch=False
311
+ )
312
+
313
+ return loss
314
+
315
+ def on_train_start(self, *args, **kwargs):
316
+ if self.sampler is None or self.loss_fn is None:
317
+ raise ValueError("Sampler and loss function need to be set for training.")
318
+
319
+ def on_train_batch_end(self, *args, **kwargs):
320
+ if self.use_ema:
321
+ self.model_ema(self.model)
322
+
323
+ @contextmanager
324
+ def ema_scope(self, context=None):
325
+ if self.use_ema:
326
+ self.model_ema.store(self.model.parameters())
327
+ self.model_ema.copy_to(self.model)
328
+ if context is not None:
329
+ print(f"{context}: Switched to EMA weights")
330
+ try:
331
+ yield None
332
+ finally:
333
+ if self.use_ema:
334
+ self.model_ema.restore(self.model.parameters())
335
+ if context is not None:
336
+ print(f"{context}: Restored training weights")
337
+
338
+ def instantiate_optimizer_from_config(self, params, lr, cfg):
339
+ return get_obj_from_str(cfg["target"])(
340
+ params, lr=lr, **cfg.get("params", dict())
341
+ )
342
+
343
+ def configure_optimizers(self):
344
+ lr = self.learning_rate
345
+ params = list(self.model.parameters())
346
+ for embedder in self.conditioner.embedders:
347
+ if embedder.is_trainable:
348
+ params = params + list(embedder.parameters())
349
+ opt = self.instantiate_optimizer_from_config(params, lr, self.optimizer_config)
350
+ if self.scheduler_config is not None:
351
+ scheduler = instantiate_from_config(self.scheduler_config)
352
+ print("Setting up LambdaLR scheduler...")
353
+ scheduler = [
354
+ {
355
+ "scheduler": LambdaLR(opt, lr_lambda=scheduler.schedule),
356
+ "interval": "step",
357
+ "frequency": 1,
358
+ }
359
+ ]
360
+ return [opt], scheduler
361
+ return opt
362
+
363
+ @torch.no_grad()
364
+ def sample(
365
+ self,
366
+ cond: Dict,
367
+ uc: Union[Dict, None] = None,
368
+ batch_size: int = 16,
369
+ shape: Union[None, Tuple, List] = None,
370
+ **kwargs,
371
+ ):
372
+ randn = torch.randn(batch_size, *shape).to(self.device)
373
+
374
+ denoiser = lambda input, sigma, c: self.denoiser(
375
+ self.model, input, sigma, c, **kwargs
376
+ )
377
+ samples = self.sampler(denoiser, randn, cond, uc=uc)
378
+ return samples
379
+
380
+ @torch.no_grad()
381
+ def log_conditionings(self, batch: Dict, n: int) -> Dict:
382
+ """
383
+ Defines heuristics to log different conditionings.
384
+ These can be lists of strings (text-to-image), tensors, ints, ...
385
+ """
386
+ image_h, image_w = batch[self.input_key].shape[-2:]
387
+ log = dict()
388
+
389
+ for embedder in self.conditioner.embedders:
390
+ if (
391
+ (self.log_keys is None) or (embedder.input_key in self.log_keys)
392
+ ) and not self.no_cond_log:
393
+ x = batch[embedder.input_key][:n]
394
+ if isinstance(x, torch.Tensor):
395
+ if x.dim() == 1:
396
+ # class-conditional, convert integer to string
397
+ x = [str(x[i].item()) for i in range(x.shape[0])]
398
+ xc = log_txt_as_img((image_h, image_w), x, size=image_h // 4)
399
+ elif x.dim() == 2:
400
+ # size and crop cond and the like
401
+ x = [
402
+ "x".join([str(xx) for xx in x[i].tolist()])
403
+ for i in range(x.shape[0])
404
+ ]
405
+ xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20)
406
+ elif x.dim() == 4:
407
+ # image
408
+ xc = x
409
+ else:
410
+ pass
411
+ # breakpoint()
412
+ # raise NotImplementedError()
413
+ elif isinstance(x, (List, ListConfig)):
414
+ if isinstance(x[0], str):
415
+ # strings
416
+ xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20)
417
+ else:
418
+ raise NotImplementedError()
419
+ else:
420
+ raise NotImplementedError()
421
+ log[embedder.input_key] = xc
422
+ return log
423
+
424
+ # for video diffusions will be logging frames of a video
425
+ @torch.no_grad()
426
+ def log_images(
427
+ self,
428
+ batch: Dict,
429
+ N: int = 1,
430
+ sample: bool = True,
431
+ ucg_keys: List[str] = None,
432
+ **kwargs,
433
+ ) -> Dict:
434
+ # # debug
435
+ # return {}
436
+ # # debug
437
+ assert "num_video_frames" in batch, "num_video_frames must be in batch"
438
+ num_video_frames = batch["num_video_frames"]
439
+ conditioner_input_keys = [e.input_key for e in self.conditioner.embedders]
440
+ if ucg_keys:
441
+ assert all(map(lambda x: x in conditioner_input_keys, ucg_keys)), (
442
+ "Each defined ucg key for sampling must be in the provided conditioner input keys,"
443
+ f"but we have {ucg_keys} vs. {conditioner_input_keys}"
444
+ )
445
+ else:
446
+ ucg_keys = conditioner_input_keys
447
+ log = dict()
448
+
449
+ x = self.get_input(batch)
450
+
451
+ c, uc = self.conditioner.get_unconditional_conditioning(
452
+ batch,
453
+ force_uc_zero_embeddings=ucg_keys
454
+ if len(self.conditioner.embedders) > 0
455
+ else [],
456
+ )
457
+
458
+ sampling_kwargs = {"num_video_frames": num_video_frames}
459
+ n = min(x.shape[0] // num_video_frames, N)
460
+ sampling_kwargs["image_only_indicator"] = torch.cat(
461
+ [batch["image_only_indicator"][:n]] * 2
462
+ )
463
+
464
+ N = min(x.shape[0] // num_video_frames, N) * num_video_frames
465
+ x = x.to(self.device)[:N]
466
+ # log["inputs"] = rearrange(x, "(b t) c h w -> b c h (t w)", t=num_video_frames)
467
+ if self.input_key != "latents":
468
+ log["inputs"] = x
469
+ z = self.encode_first_stage(x)
470
+ recon = self.decode_first_stage(z)
471
+ # log["reconstructions"] = rearrange(
472
+ # recon, "(b t) c h w -> b c h (t w)", t=num_video_frames
473
+ # )
474
+ log["reconstructions"] = recon
475
+ log.update(self.log_conditionings(batch, N))
476
+
477
+ for k in c:
478
+ if isinstance(c[k], torch.Tensor):
479
+ if k == "vector":
480
+ end = N
481
+ else:
482
+ end = n
483
+ c[k], uc[k] = map(lambda y: y[k][:end].to(self.device), (c, uc))
484
+
485
+ # for k in c:
486
+ # print(c[k].shape)
487
+
488
+ for k in ["crossattn", "concat"]:
489
+ c[k] = repeat(c[k], "b ... -> b t ...", t=num_video_frames)
490
+ c[k] = rearrange(c[k], "b t ... -> (b t) ...", t=num_video_frames)
491
+ uc[k] = repeat(uc[k], "b ... -> b t ...", t=num_video_frames)
492
+ uc[k] = rearrange(uc[k], "b t ... -> (b t) ...", t=num_video_frames)
493
+
494
+ # for k in c:
495
+ # print(c[k].shape)
496
+ if sample:
497
+ with self.ema_scope("Plotting"):
498
+ samples = self.sample(
499
+ c, shape=z.shape[1:], uc=uc, batch_size=N, **sampling_kwargs
500
+ )
501
+ samples = self.decode_first_stage(samples)
502
+ log["samples"] = samples
503
+ return log
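Side note (editor's sketch, not part of the diff): decode_first_stage / encode_first_stage above bound peak memory by splitting the (b t)-flattened batch into sub-batches of size en_and_decode_n_samples_a_time. A minimal standalone illustration of that chunking pattern, with a toy stand-in for the decoder:

import math
import torch

def chunked_apply(fn, x, n_samples):
    # apply fn to x in sub-batches of n_samples and concatenate the results,
    # mirroring the n_rounds loop in decode_first_stage / encode_first_stage
    n_rounds = math.ceil(x.shape[0] / n_samples)
    return torch.cat(
        [fn(x[n * n_samples : (n + 1) * n_samples]) for n in range(n_rounds)], dim=0
    )

decode = lambda z: z * 2.0  # toy stand-in for first_stage_model.decode (assumption)
z = torch.randn(10, 4, 8, 8)
assert chunked_apply(decode, z, n_samples=3).shape == z.shape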
sgm/modules/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ from .encoders.modules import GeneralConditioner, ExtraConditioner
2
+
3
+ UNCONDITIONAL_CONFIG = {
4
+ "target": "sgm.modules.GeneralConditioner",
5
+ "params": {"emb_models": []},
6
+ }
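For context, a minimal sketch of how a {"target": ..., "params": ...} dict such as UNCONDITIONAL_CONFIG above is typically resolved into an object; the project's actual helpers live in sgm.util and may differ in detail:

import importlib

def get_obj_from_str(string):
    # "pkg.module.ClassName" -> the ClassName attribute of pkg.module
    module, cls = string.rsplit(".", 1)
    return getattr(importlib.import_module(module), cls)

def instantiate_from_config(config):
    # construct the target with its (possibly empty) params dict
    return get_obj_from_str(config["target"])(**config.get("params", dict()))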
sgm/modules/attention.py ADDED
@@ -0,0 +1,764 @@
1
+ import logging
2
+ import math
3
+ from inspect import isfunction
4
+ from typing import Any, Optional
5
+ from functools import partial
6
+
7
+ import torch
8
+ import torch.nn.functional as F
9
+ from einops import rearrange, repeat
10
+ from packaging import version
11
+ from torch import nn
12
+
13
+ # from torch.utils.checkpoint import checkpoint
14
+
15
+ checkpoint = partial(torch.utils.checkpoint.checkpoint, use_reentrant=False)
16
+
17
+
18
+ logpy = logging.getLogger(__name__)
19
+
20
+ if version.parse(torch.__version__) >= version.parse("2.0.0"):
21
+ SDP_IS_AVAILABLE = True
22
+ from torch.backends.cuda import SDPBackend, sdp_kernel
23
+
24
+ BACKEND_MAP = {
25
+ SDPBackend.MATH: {
26
+ "enable_math": True,
27
+ "enable_flash": False,
28
+ "enable_mem_efficient": False,
29
+ },
30
+ SDPBackend.FLASH_ATTENTION: {
31
+ "enable_math": False,
32
+ "enable_flash": True,
33
+ "enable_mem_efficient": False,
34
+ },
35
+ SDPBackend.EFFICIENT_ATTENTION: {
36
+ "enable_math": False,
37
+ "enable_flash": False,
38
+ "enable_mem_efficient": True,
39
+ },
40
+ None: {"enable_math": True, "enable_flash": True, "enable_mem_efficient": True},
41
+ }
42
+ else:
43
+ from contextlib import nullcontext
44
+
45
+ SDP_IS_AVAILABLE = False
46
+ sdp_kernel = nullcontext
47
+ BACKEND_MAP = {}
48
+ logpy.warn(
49
+ f"No SDP backend available, likely because you are running in pytorch "
50
+ f"versions < 2.0. In fact, you are using PyTorch {torch.__version__}. "
51
+ f"You might want to consider upgrading."
52
+ )
53
+
54
+ try:
55
+ import xformers
56
+ import xformers.ops
57
+
58
+ XFORMERS_IS_AVAILABLE = True
59
+ except ImportError:
60
+ XFORMERS_IS_AVAILABLE = False
61
+ logpy.warn("no module 'xformers'. Processing without...")
62
+
63
+ # from .diffusionmodules.util import mixed_checkpoint as checkpoint
64
+
65
+
66
+ def exists(val):
67
+ return val is not None
68
+
69
+
70
+ def uniq(arr):
71
+ return {el: True for el in arr}.keys()
72
+
73
+
74
+ def default(val, d):
75
+ if exists(val):
76
+ return val
77
+ return d() if isfunction(d) else d
78
+
79
+
80
+ def max_neg_value(t):
81
+ return -torch.finfo(t.dtype).max
82
+
83
+
84
+ def init_(tensor):
85
+ dim = tensor.shape[-1]
86
+ std = 1 / math.sqrt(dim)
87
+ tensor.uniform_(-std, std)
88
+ return tensor
89
+
90
+
91
+ # feedforward
92
+ class GEGLU(nn.Module):
93
+ def __init__(self, dim_in, dim_out):
94
+ super().__init__()
95
+ self.proj = nn.Linear(dim_in, dim_out * 2)
96
+
97
+ def forward(self, x):
98
+ x, gate = self.proj(x).chunk(2, dim=-1)
99
+ return x * F.gelu(gate)
100
+
101
+
102
+ class FeedForward(nn.Module):
103
+ def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
104
+ super().__init__()
105
+ inner_dim = int(dim * mult)
106
+ dim_out = default(dim_out, dim)
107
+ project_in = (
108
+ nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
109
+ if not glu
110
+ else GEGLU(dim, inner_dim)
111
+ )
112
+
113
+ self.net = nn.Sequential(
114
+ project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
115
+ )
116
+
117
+ def forward(self, x):
118
+ return self.net(x)
119
+
120
+
121
+ def zero_module(module):
122
+ """
123
+ Zero out the parameters of a module and return it.
124
+ """
125
+ for p in module.parameters():
126
+ p.detach().zero_()
127
+ return module
128
+
129
+
130
+ def Normalize(in_channels):
131
+ return torch.nn.GroupNorm(
132
+ num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
133
+ )
134
+
135
+
136
+ class LinearAttention(nn.Module):
137
+ def __init__(self, dim, heads=4, dim_head=32):
138
+ super().__init__()
139
+ self.heads = heads
140
+ hidden_dim = dim_head * heads
141
+ self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
142
+ self.to_out = nn.Conv2d(hidden_dim, dim, 1)
143
+
144
+ def forward(self, x):
145
+ b, c, h, w = x.shape
146
+ qkv = self.to_qkv(x)
147
+ q, k, v = rearrange(
148
+ qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3
149
+ )
150
+ k = k.softmax(dim=-1)
151
+ context = torch.einsum("bhdn,bhen->bhde", k, v)
152
+ out = torch.einsum("bhde,bhdn->bhen", context, q)
153
+ out = rearrange(
154
+ out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w
155
+ )
156
+ return self.to_out(out)
157
+
158
+
159
+ class SelfAttention(nn.Module):
160
+ ATTENTION_MODES = ("xformers", "torch", "math")
161
+
162
+ def __init__(
163
+ self,
164
+ dim: int,
165
+ num_heads: int = 8,
166
+ qkv_bias: bool = False,
167
+ qk_scale: Optional[float] = None,
168
+ attn_drop: float = 0.0,
169
+ proj_drop: float = 0.0,
170
+ attn_mode: str = "xformers",
171
+ ):
172
+ super().__init__()
173
+ self.num_heads = num_heads
174
+ head_dim = dim // num_heads
175
+ self.scale = qk_scale or head_dim**-0.5
176
+
177
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
178
+ self.attn_drop = nn.Dropout(attn_drop)
179
+ self.proj = nn.Linear(dim, dim)
180
+ self.proj_drop = nn.Dropout(proj_drop)
181
+ assert attn_mode in self.ATTENTION_MODES
182
+ self.attn_mode = attn_mode
183
+
184
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
185
+ B, L, C = x.shape
186
+
187
+ qkv = self.qkv(x)
188
+ if self.attn_mode == "torch":
189
+ qkv = rearrange(
190
+ qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads
191
+ ).float()
192
+ q, k, v = qkv[0], qkv[1], qkv[2] # B H L D
193
+ x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
194
+ x = rearrange(x, "B H L D -> B L (H D)")
195
+ elif self.attn_mode == "xformers":
196
+ qkv = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.num_heads)
197
+ q, k, v = qkv[0], qkv[1], qkv[2] # B L H D
198
+ x = xformers.ops.memory_efficient_attention(q, k, v)
199
+ x = rearrange(x, "B L H D -> B L (H D)", H=self.num_heads)
200
+ elif self.attn_mode == "math":
201
+ qkv = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
202
+ q, k, v = qkv[0], qkv[1], qkv[2] # B H L D
203
+ attn = (q @ k.transpose(-2, -1)) * self.scale
204
+ attn = attn.softmax(dim=-1)
205
+ attn = self.attn_drop(attn)
206
+ x = (attn @ v).transpose(1, 2).reshape(B, L, C)
207
+ else:
208
+ raise NotImplementedError
209
+
210
+ x = self.proj(x)
211
+ x = self.proj_drop(x)
212
+ return x
213
+
214
+
215
+ class SpatialSelfAttention(nn.Module):
216
+ def __init__(self, in_channels):
217
+ super().__init__()
218
+ self.in_channels = in_channels
219
+
220
+ self.norm = Normalize(in_channels)
221
+ self.q = torch.nn.Conv2d(
222
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
223
+ )
224
+ self.k = torch.nn.Conv2d(
225
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
226
+ )
227
+ self.v = torch.nn.Conv2d(
228
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
229
+ )
230
+ self.proj_out = torch.nn.Conv2d(
231
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
232
+ )
233
+
234
+ def forward(self, x):
235
+ h_ = x
236
+ h_ = self.norm(h_)
237
+ q = self.q(h_)
238
+ k = self.k(h_)
239
+ v = self.v(h_)
240
+
241
+ # compute attention
242
+ b, c, h, w = q.shape
243
+ q = rearrange(q, "b c h w -> b (h w) c")
244
+ k = rearrange(k, "b c h w -> b c (h w)")
245
+ w_ = torch.einsum("bij,bjk->bik", q, k)
246
+
247
+ w_ = w_ * (int(c) ** (-0.5))
248
+ w_ = torch.nn.functional.softmax(w_, dim=2)
249
+
250
+ # attend to values
251
+ v = rearrange(v, "b c h w -> b c (h w)")
252
+ w_ = rearrange(w_, "b i j -> b j i")
253
+ h_ = torch.einsum("bij,bjk->bik", v, w_)
254
+ h_ = rearrange(h_, "b c (h w) -> b c h w", h=h)
255
+ h_ = self.proj_out(h_)
256
+
257
+ return x + h_
258
+
259
+
260
+ class CrossAttention(nn.Module):
261
+ def __init__(
262
+ self,
263
+ query_dim,
264
+ context_dim=None,
265
+ heads=8,
266
+ dim_head=64,
267
+ dropout=0.0,
268
+ backend=None,
269
+ ):
270
+ super().__init__()
271
+ inner_dim = dim_head * heads
272
+ context_dim = default(context_dim, query_dim)
273
+
274
+ self.scale = dim_head**-0.5
275
+ self.heads = heads
276
+
277
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
278
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
279
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
280
+
281
+ self.to_out = nn.Sequential(
282
+ nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
283
+ )
284
+ self.backend = backend
285
+
286
+ def forward(
287
+ self,
288
+ x,
289
+ context=None,
290
+ mask=None,
291
+ additional_tokens=None,
292
+ n_times_crossframe_attn_in_self=0,
293
+ ):
294
+ h = self.heads
295
+
296
+ if additional_tokens is not None:
297
+ # get the number of masked tokens at the beginning of the output sequence
298
+ n_tokens_to_mask = additional_tokens.shape[1]
299
+ # add additional token
300
+ x = torch.cat([additional_tokens, x], dim=1)
301
+
302
+ q = self.to_q(x)
303
+ context = default(context, x)
304
+ k = self.to_k(context)
305
+ v = self.to_v(context)
306
+
307
+ if n_times_crossframe_attn_in_self:
308
+ # reprogramming cross-frame attention as in https://arxiv.org/abs/2303.13439
309
+ assert x.shape[0] % n_times_crossframe_attn_in_self == 0
310
+ n_cp = x.shape[0] // n_times_crossframe_attn_in_self
311
+ k = repeat(
312
+ k[::n_times_crossframe_attn_in_self], "b ... -> (b n) ...", n=n_cp
313
+ )
314
+ v = repeat(
315
+ v[::n_times_crossframe_attn_in_self], "b ... -> (b n) ...", n=n_cp
316
+ )
317
+
318
+ q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
319
+
320
+ ## old
321
+ """
322
+ sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
323
+ del q, k
324
+
325
+ if exists(mask):
326
+ mask = rearrange(mask, 'b ... -> b (...)')
327
+ max_neg_value = -torch.finfo(sim.dtype).max
328
+ mask = repeat(mask, 'b j -> (b h) () j', h=h)
329
+ sim.masked_fill_(~mask, max_neg_value)
330
+
331
+ # attention, what we cannot get enough of
332
+ sim = sim.softmax(dim=-1)
333
+
334
+ out = einsum('b i j, b j d -> b i d', sim, v)
335
+ """
336
+ ## new
337
+ with sdp_kernel(**BACKEND_MAP[self.backend]):
338
+ # print("dispatching into backend", self.backend, "q/k/v shape: ", q.shape, k.shape, v.shape)
339
+ out = F.scaled_dot_product_attention(
340
+ q, k, v, attn_mask=mask
341
+ ) # scale is dim_head ** -0.5 per default
342
+
343
+ del q, k, v
344
+ out = rearrange(out, "b h n d -> b n (h d)", h=h)
345
+
346
+ if additional_tokens is not None:
347
+ # remove additional token
348
+ out = out[:, n_tokens_to_mask:]
349
+ return self.to_out(out)
350
+
351
+
352
+ class MemoryEfficientCrossAttention(nn.Module):
353
+ # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
354
+ def __init__(
355
+ self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, **kwargs
356
+ ):
357
+ super().__init__()
358
+ logpy.debug(
359
+ f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, "
360
+ f"context_dim is {context_dim} and using {heads} heads with a "
361
+ f"dimension of {dim_head}."
362
+ )
363
+ inner_dim = dim_head * heads
364
+ context_dim = default(context_dim, query_dim)
365
+
366
+ self.heads = heads
367
+ self.dim_head = dim_head
368
+
369
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
370
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
371
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
372
+
373
+ self.to_out = nn.Sequential(
374
+ nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
375
+ )
376
+ self.attention_op: Optional[Any] = None
377
+
378
+ def forward(
379
+ self,
380
+ x,
381
+ context=None,
382
+ mask=None,
383
+ additional_tokens=None,
384
+ n_times_crossframe_attn_in_self=0,
385
+ ):
386
+ if additional_tokens is not None:
387
+ # get the number of masked tokens at the beginning of the output sequence
388
+ n_tokens_to_mask = additional_tokens.shape[1]
389
+ # add additional token
390
+ x = torch.cat([additional_tokens, x], dim=1)
391
+ q = self.to_q(x)
392
+ context = default(context, x)
393
+ k = self.to_k(context)
394
+ v = self.to_v(context)
395
+
396
+ if n_times_crossframe_attn_in_self:
397
+ # reprogramming cross-frame attention as in https://arxiv.org/abs/2303.13439
398
+ assert x.shape[0] % n_times_crossframe_attn_in_self == 0
399
+ # n_cp = x.shape[0]//n_times_crossframe_attn_in_self
400
+ k = repeat(
401
+ k[::n_times_crossframe_attn_in_self],
402
+ "b ... -> (b n) ...",
403
+ n=n_times_crossframe_attn_in_self,
404
+ )
405
+ v = repeat(
406
+ v[::n_times_crossframe_attn_in_self],
407
+ "b ... -> (b n) ...",
408
+ n=n_times_crossframe_attn_in_self,
409
+ )
410
+
411
+ b, _, _ = q.shape
412
+ q, k, v = map(
413
+ lambda t: t.unsqueeze(3)
414
+ .reshape(b, t.shape[1], self.heads, self.dim_head)
415
+ .permute(0, 2, 1, 3)
416
+ .reshape(b * self.heads, t.shape[1], self.dim_head)
417
+ .contiguous(),
418
+ (q, k, v),
419
+ )
420
+
421
+ # actually compute the attention, what we cannot get enough of
422
+ if version.parse(xformers.__version__) >= version.parse("0.0.21"):
423
+ # NOTE: workaround for
424
+ # https://github.com/facebookresearch/xformers/issues/845
425
+ max_bs = 32768
426
+ N = q.shape[0]
427
+ n_batches = math.ceil(N / max_bs)
428
+ out = list()
429
+ for i_batch in range(n_batches):
430
+ batch = slice(i_batch * max_bs, (i_batch + 1) * max_bs)
431
+ out.append(
432
+ xformers.ops.memory_efficient_attention(
433
+ q[batch],
434
+ k[batch],
435
+ v[batch],
436
+ attn_bias=None,
437
+ op=self.attention_op,
438
+ )
439
+ )
440
+ out = torch.cat(out, 0)
441
+ else:
442
+ out = xformers.ops.memory_efficient_attention(
443
+ q, k, v, attn_bias=None, op=self.attention_op
444
+ )
445
+
446
+ # TODO: Use this directly in the attention operation, as a bias
447
+ if exists(mask):
448
+ raise NotImplementedError
449
+ out = (
450
+ out.unsqueeze(0)
451
+ .reshape(b, self.heads, out.shape[1], self.dim_head)
452
+ .permute(0, 2, 1, 3)
453
+ .reshape(b, out.shape[1], self.heads * self.dim_head)
454
+ )
455
+ if additional_tokens is not None:
456
+ # remove additional token
457
+ out = out[:, n_tokens_to_mask:]
458
+ return self.to_out(out)
459
+
460
+
461
+ class BasicTransformerBlock(nn.Module):
462
+ ATTENTION_MODES = {
463
+ "softmax": CrossAttention, # vanilla attention
464
+ "softmax-xformers": MemoryEfficientCrossAttention, # ampere
465
+ }
466
+
467
+ def __init__(
468
+ self,
469
+ dim,
470
+ n_heads,
471
+ d_head,
472
+ dropout=0.0,
473
+ context_dim=None,
474
+ gated_ff=True,
475
+ checkpoint=True,
476
+ disable_self_attn=False,
477
+ attn_mode="softmax",
478
+ sdp_backend=None,
479
+ ):
480
+ super().__init__()
481
+ assert attn_mode in self.ATTENTION_MODES
482
+ if attn_mode != "softmax" and not XFORMERS_IS_AVAILABLE:
483
+ logpy.warn(
484
+ f"Attention mode '{attn_mode}' is not available. Falling "
485
+ f"back to native attention. This is not a problem in "
486
+ f"Pytorch >= 2.0. FYI, you are running with PyTorch "
487
+ f"version {torch.__version__}."
488
+ )
489
+ attn_mode = "softmax"
490
+ elif attn_mode == "softmax" and not SDP_IS_AVAILABLE:
491
+ logpy.warn(
492
+ "We do not support vanilla attention anymore, as it is too "
493
+ "expensive. Sorry."
494
+ )
495
+ if not XFORMERS_IS_AVAILABLE:
496
+ assert (
497
+ False
498
+ ), "Please install xformers via e.g. 'pip install xformers==0.0.16'"
499
+ else:
500
+ logpy.info("Falling back to xformers efficient attention.")
501
+ attn_mode = "softmax-xformers"
502
+ attn_cls = self.ATTENTION_MODES[attn_mode]
503
+ if version.parse(torch.__version__) >= version.parse("2.0.0"):
504
+ assert sdp_backend is None or isinstance(sdp_backend, SDPBackend)
505
+ else:
506
+ assert sdp_backend is None
507
+ self.disable_self_attn = disable_self_attn
508
+ self.attn1 = attn_cls(
509
+ query_dim=dim,
510
+ heads=n_heads,
511
+ dim_head=d_head,
512
+ dropout=dropout,
513
+ context_dim=context_dim if self.disable_self_attn else None,
514
+ backend=sdp_backend,
515
+ ) # is a self-attention if not self.disable_self_attn
516
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
517
+ self.attn2 = attn_cls(
518
+ query_dim=dim,
519
+ context_dim=context_dim,
520
+ heads=n_heads,
521
+ dim_head=d_head,
522
+ dropout=dropout,
523
+ backend=sdp_backend,
524
+ ) # is self-attn if context is none
525
+ self.norm1 = nn.LayerNorm(dim)
526
+ self.norm2 = nn.LayerNorm(dim)
527
+ self.norm3 = nn.LayerNorm(dim)
528
+ self.checkpoint = checkpoint
529
+ if self.checkpoint:
530
+ logpy.debug(f"{self.__class__.__name__} is using checkpointing")
531
+
532
+ def forward(
533
+ self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0
534
+ ):
535
+ kwargs = {"x": x}
536
+
537
+ if context is not None:
538
+ kwargs.update({"context": context})
539
+
540
+ if additional_tokens is not None:
541
+ kwargs.update({"additional_tokens": additional_tokens})
542
+
543
+ if n_times_crossframe_attn_in_self:
544
+ kwargs.update(
545
+ {"n_times_crossframe_attn_in_self": n_times_crossframe_attn_in_self}
546
+ )
547
+
548
+ # return mixed_checkpoint(self._forward, kwargs, self.parameters(), self.checkpoint)
549
+ if self.checkpoint:
550
+ # inputs = {"x": x, "context": context}
551
+ return checkpoint(self._forward, x, context)
552
+ # return checkpoint(self._forward, inputs, self.parameters(), self.checkpoint)
553
+ else:
554
+ return self._forward(**kwargs)
555
+
556
+ def _forward(
557
+ self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0
558
+ ):
559
+ x = (
560
+ self.attn1(
561
+ self.norm1(x),
562
+ context=context if self.disable_self_attn else None,
563
+ additional_tokens=additional_tokens,
564
+ n_times_crossframe_attn_in_self=n_times_crossframe_attn_in_self
565
+ if not self.disable_self_attn
566
+ else 0,
567
+ )
568
+ + x
569
+ )
570
+ x = (
571
+ self.attn2(
572
+ self.norm2(x), context=context, additional_tokens=additional_tokens
573
+ )
574
+ + x
575
+ )
576
+ x = self.ff(self.norm3(x)) + x
577
+ return x
578
+
579
+
580
+ class BasicTransformerSingleLayerBlock(nn.Module):
581
+ ATTENTION_MODES = {
582
+ "softmax": CrossAttention, # vanilla attention
583
+ "softmax-xformers": MemoryEfficientCrossAttention # on the A100s not quite as fast as the above version
584
+ # (todo might depend on head_dim, check, falls back to semi-optimized kernels for dim!=[16,32,64,128])
585
+ }
586
+
587
+ def __init__(
588
+ self,
589
+ dim,
590
+ n_heads,
591
+ d_head,
592
+ dropout=0.0,
593
+ context_dim=None,
594
+ gated_ff=True,
595
+ checkpoint=True,
596
+ attn_mode="softmax",
597
+ ):
598
+ super().__init__()
599
+ assert attn_mode in self.ATTENTION_MODES
600
+ attn_cls = self.ATTENTION_MODES[attn_mode]
601
+ self.attn1 = attn_cls(
602
+ query_dim=dim,
603
+ heads=n_heads,
604
+ dim_head=d_head,
605
+ dropout=dropout,
606
+ context_dim=context_dim,
607
+ )
608
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
609
+ self.norm1 = nn.LayerNorm(dim)
610
+ self.norm2 = nn.LayerNorm(dim)
611
+ self.checkpoint = checkpoint
612
+
613
+ def forward(self, x, context=None):
614
+ # inputs = {"x": x, "context": context}
615
+ # return checkpoint(self._forward, inputs, self.parameters(), self.checkpoint)
616
+ return checkpoint(self._forward, x, context)
617
+
618
+ def _forward(self, x, context=None):
619
+ x = self.attn1(self.norm1(x), context=context) + x
620
+ x = self.ff(self.norm2(x)) + x
621
+ return x
622
+
623
+
624
+ class SpatialTransformer(nn.Module):
625
+ """
626
+ Transformer block for image-like data.
627
+ First, project the input (aka embedding)
628
+ and reshape to b, t, d.
629
+ Then apply standard transformer action.
630
+ Finally, reshape to image
631
+ NEW: use_linear for more efficiency instead of the 1x1 convs
632
+ """
633
+
634
+ def __init__(
635
+ self,
636
+ in_channels,
637
+ n_heads,
638
+ d_head,
639
+ depth=1,
640
+ dropout=0.0,
641
+ context_dim=None,
642
+ disable_self_attn=False,
643
+ use_linear=False,
644
+ attn_type="softmax",
645
+ use_checkpoint=True,
646
+ # sdp_backend=SDPBackend.FLASH_ATTENTION
647
+ sdp_backend=None,
648
+ ):
649
+ super().__init__()
650
+ logpy.debug(
651
+ f"constructing {self.__class__.__name__} of depth {depth} w/ "
652
+ f"{in_channels} channels and {n_heads} heads."
653
+ )
654
+
655
+ if exists(context_dim) and not isinstance(context_dim, list):
656
+ context_dim = [context_dim]
657
+ if exists(context_dim) and isinstance(context_dim, list):
658
+ if depth != len(context_dim):
659
+ logpy.warn(
660
+ f"{self.__class__.__name__}: Found context dims "
661
+ f"{context_dim} of depth {len(context_dim)}, which does not "
662
+ f"match the specified 'depth' of {depth}. Setting context_dim "
663
+ f"to {depth * [context_dim[0]]} now."
664
+ )
665
+ # depth does not match context dims.
666
+ assert all(
667
+ map(lambda x: x == context_dim[0], context_dim)
668
+ ), "need homogenous context_dim to match depth automatically"
669
+ context_dim = depth * [context_dim[0]]
670
+ elif context_dim is None:
671
+ context_dim = [None] * depth
672
+ self.in_channels = in_channels
673
+ inner_dim = n_heads * d_head
674
+ self.norm = Normalize(in_channels)
675
+ if not use_linear:
676
+ self.proj_in = nn.Conv2d(
677
+ in_channels, inner_dim, kernel_size=1, stride=1, padding=0
678
+ )
679
+ else:
680
+ self.proj_in = nn.Linear(in_channels, inner_dim)
681
+
682
+ self.transformer_blocks = nn.ModuleList(
683
+ [
684
+ BasicTransformerBlock(
685
+ inner_dim,
686
+ n_heads,
687
+ d_head,
688
+ dropout=dropout,
689
+ context_dim=context_dim[d],
690
+ disable_self_attn=disable_self_attn,
691
+ attn_mode=attn_type,
692
+ checkpoint=use_checkpoint,
693
+ sdp_backend=sdp_backend,
694
+ )
695
+ for d in range(depth)
696
+ ]
697
+ )
698
+ if not use_linear:
699
+ self.proj_out = zero_module(
700
+ nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
701
+ )
702
+ else:
703
+ # self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
704
+ self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
705
+ self.use_linear = use_linear
706
+
707
+ def forward(self, x, context=None):
708
+ # note: if no context is given, cross-attention defaults to self-attention
709
+ if not isinstance(context, list):
710
+ context = [context]
711
+ b, c, h, w = x.shape
712
+ x_in = x
713
+ x = self.norm(x)
714
+ if not self.use_linear:
715
+ x = self.proj_in(x)
716
+ x = rearrange(x, "b c h w -> b (h w) c").contiguous()
717
+ if self.use_linear:
718
+ x = self.proj_in(x)
719
+ for i, block in enumerate(self.transformer_blocks):
720
+ if i > 0 and len(context) == 1:
721
+ i = 0 # use same context for each block
722
+ x = block(x, context=context[i])
723
+ if self.use_linear:
724
+ x = self.proj_out(x)
725
+ x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
726
+ if not self.use_linear:
727
+ x = self.proj_out(x)
728
+ return x + x_in
729
+
730
+
731
+ class SimpleTransformer(nn.Module):
732
+ def __init__(
733
+ self,
734
+ dim: int,
735
+ depth: int,
736
+ heads: int,
737
+ dim_head: int,
738
+ context_dim: Optional[int] = None,
739
+ dropout: float = 0.0,
740
+ checkpoint: bool = True,
741
+ ):
742
+ super().__init__()
743
+ self.layers = nn.ModuleList([])
744
+ for _ in range(depth):
745
+ self.layers.append(
746
+ BasicTransformerBlock(
747
+ dim,
748
+ heads,
749
+ dim_head,
750
+ dropout=dropout,
751
+ context_dim=context_dim,
752
+ attn_mode="softmax-xformers",
753
+ checkpoint=checkpoint,
754
+ )
755
+ )
756
+
757
+ def forward(
758
+ self,
759
+ x: torch.Tensor,
760
+ context: Optional[torch.Tensor] = None,
761
+ ) -> torch.Tensor:
762
+ for layer in self.layers:
763
+ x = layer(x, context)
764
+ return x
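A minimal usage sketch for the CrossAttention module defined above (shapes are illustrative assumptions; the default SDP path needs PyTorch >= 2.0):

import torch

attn = CrossAttention(query_dim=320, context_dim=768, heads=8, dim_head=40)
x = torch.randn(2, 64, 320)    # (batch, query tokens, query_dim)
ctx = torch.randn(2, 77, 768)  # (batch, context tokens, context_dim)
out = attn(x, context=ctx)     # attention output projected back to query_dim
assert out.shape == x.shape    # -> (2, 64, 320)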
sgm/modules/autoencoding/__init__.py ADDED
File without changes
sgm/modules/autoencoding/losses/__init__.py ADDED
@@ -0,0 +1,7 @@
1
+ __all__ = [
2
+ "GeneralLPIPSWithDiscriminator",
3
+ "LatentLPIPS",
4
+ ]
5
+
6
+ from .discriminator_loss import GeneralLPIPSWithDiscriminator
7
+ from .lpips import LatentLPIPS
sgm/modules/autoencoding/losses/discriminator_loss.py ADDED
@@ -0,0 +1,306 @@
1
+ from typing import Dict, Iterator, List, Optional, Tuple, Union
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ import torchvision
7
+ from einops import rearrange
8
+ from matplotlib import colormaps
9
+ from matplotlib import pyplot as plt
10
+
11
+ from ....util import default, instantiate_from_config
12
+ from ..lpips.loss.lpips import LPIPS
13
+ from ..lpips.model.model import weights_init
14
+ from ..lpips.vqperceptual import hinge_d_loss, vanilla_d_loss
15
+
16
+
17
+ class GeneralLPIPSWithDiscriminator(nn.Module):
18
+ def __init__(
19
+ self,
20
+ disc_start: int,
21
+ logvar_init: float = 0.0,
22
+ disc_num_layers: int = 3,
23
+ disc_in_channels: int = 3,
24
+ disc_factor: float = 1.0,
25
+ disc_weight: float = 1.0,
26
+ perceptual_weight: float = 1.0,
27
+ disc_loss: str = "hinge",
28
+ scale_input_to_tgt_size: bool = False,
29
+ dims: int = 2,
30
+ learn_logvar: bool = False,
31
+ regularization_weights: Union[None, Dict[str, float]] = None,
32
+ additional_log_keys: Optional[List[str]] = None,
33
+ discriminator_config: Optional[Dict] = None,
34
+ ):
35
+ super().__init__()
36
+ self.dims = dims
37
+ if self.dims > 2:
38
+ print(
39
+ f"running with dims={dims}. This means that for perceptual loss "
40
+ f"calculation, the LPIPS loss will be applied to each frame "
41
+ f"independently."
42
+ )
43
+ self.scale_input_to_tgt_size = scale_input_to_tgt_size
44
+ assert disc_loss in ["hinge", "vanilla"]
45
+ self.perceptual_loss = LPIPS().eval()
46
+ self.perceptual_weight = perceptual_weight
47
+ # output log variance
48
+ self.logvar = nn.Parameter(
49
+ torch.full((), logvar_init), requires_grad=learn_logvar
50
+ )
51
+ self.learn_logvar = learn_logvar
52
+
53
+ discriminator_config = default(
54
+ discriminator_config,
55
+ {
56
+ "target": "sgm.modules.autoencoding.lpips.model.model.NLayerDiscriminator",
57
+ "params": {
58
+ "input_nc": disc_in_channels,
59
+ "n_layers": disc_num_layers,
60
+ "use_actnorm": False,
61
+ },
62
+ },
63
+ )
64
+
65
+ self.discriminator = instantiate_from_config(discriminator_config).apply(
66
+ weights_init
67
+ )
68
+ self.discriminator_iter_start = disc_start
69
+ self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
70
+ self.disc_factor = disc_factor
71
+ self.discriminator_weight = disc_weight
72
+ self.regularization_weights = default(regularization_weights, {})
73
+
74
+ self.forward_keys = [
75
+ "optimizer_idx",
76
+ "global_step",
77
+ "last_layer",
78
+ "split",
79
+ "regularization_log",
80
+ ]
81
+
82
+ self.additional_log_keys = set(default(additional_log_keys, []))
83
+ self.additional_log_keys.update(set(self.regularization_weights.keys()))
84
+
85
+ def get_trainable_parameters(self) -> Iterator[nn.Parameter]:
86
+ return self.discriminator.parameters()
87
+
88
+ def get_trainable_autoencoder_parameters(self) -> Iterator[nn.Parameter]:
89
+ if self.learn_logvar:
90
+ yield self.logvar
91
+ yield from ()
92
+
93
+ @torch.no_grad()
94
+ def log_images(
95
+ self, inputs: torch.Tensor, reconstructions: torch.Tensor
96
+ ) -> Dict[str, torch.Tensor]:
97
+ # calc logits of real/fake
98
+ logits_real = self.discriminator(inputs.contiguous().detach())
99
+ if len(logits_real.shape) < 4:
100
+ # Non patch-discriminator
101
+ return dict()
102
+ logits_fake = self.discriminator(reconstructions.contiguous().detach())
103
+ # -> (b, 1, h, w)
104
+
105
+ # parameters for colormapping
106
+ high = max(logits_fake.abs().max(), logits_real.abs().max()).item()
107
+ cmap = colormaps["PiYG"] # diverging colormap
108
+
109
+ def to_colormap(logits: torch.Tensor) -> torch.Tensor:
110
+ """(b, 1, ...) -> (b, 3, ...)"""
111
+ logits = (logits + high) / (2 * high)
112
+ logits_np = cmap(logits.cpu().numpy())[..., :3] # truncate alpha channel
113
+ # -> (b, 1, ..., 3)
114
+ logits = torch.from_numpy(logits_np).to(logits.device)
115
+ return rearrange(logits, "b 1 ... c -> b c ...")
116
+
117
+ logits_real = torch.nn.functional.interpolate(
118
+ logits_real,
119
+ size=inputs.shape[-2:],
120
+ mode="nearest",
121
+ antialias=False,
122
+ )
123
+ logits_fake = torch.nn.functional.interpolate(
124
+ logits_fake,
125
+ size=reconstructions.shape[-2:],
126
+ mode="nearest",
127
+ antialias=False,
128
+ )
129
+
130
+ # alpha value of logits for overlay
131
+ alpha_real = torch.abs(logits_real) / high
132
+ alpha_fake = torch.abs(logits_fake) / high
133
+ # -> (b, 1, h, w) in range [0, 0.5]
134
+ # alpha value of lines don't really matter, since the values are the same
135
+ # for both images and logits anyway
136
+ grid_alpha_real = torchvision.utils.make_grid(alpha_real, nrow=4)
137
+ grid_alpha_fake = torchvision.utils.make_grid(alpha_fake, nrow=4)
138
+ grid_alpha = 0.8 * torch.cat((grid_alpha_real, grid_alpha_fake), dim=1)
139
+ # -> (1, h, w)
140
+ # blend logits and images together
141
+
142
+ # prepare logits for plotting
143
+ logits_real = to_colormap(logits_real)
144
+ logits_fake = to_colormap(logits_fake)
145
+ # resize logits
146
+ # -> (b, 3, h, w)
147
+
148
+ # make some grids
149
+ # add all logits to one plot
150
+ logits_real = torchvision.utils.make_grid(logits_real, nrow=4)
151
+ logits_fake = torchvision.utils.make_grid(logits_fake, nrow=4)
152
+ # I just love how torchvision calls the number of columns `nrow`
153
+ grid_logits = torch.cat((logits_real, logits_fake), dim=1)
154
+ # -> (3, h, w)
155
+
156
+ grid_images_real = torchvision.utils.make_grid(0.5 * inputs + 0.5, nrow=4)
157
+ grid_images_fake = torchvision.utils.make_grid(
158
+ 0.5 * reconstructions + 0.5, nrow=4
159
+ )
160
+ grid_images = torch.cat((grid_images_real, grid_images_fake), dim=1)
161
+ # -> (3, h, w) in range [0, 1]
162
+
163
+ grid_blend = grid_alpha * grid_logits + (1 - grid_alpha) * grid_images
164
+
165
+ # Create labeled colorbar
166
+ dpi = 100
167
+ height = 128 / dpi
168
+ width = grid_logits.shape[2] / dpi
169
+ fig, ax = plt.subplots(figsize=(width, height), dpi=dpi)
170
+ img = ax.imshow(np.array([[-high, high]]), cmap=cmap)
171
+ plt.colorbar(
172
+ img,
173
+ cax=ax,
174
+ orientation="horizontal",
175
+ fraction=0.9,
176
+ aspect=width / height,
177
+ pad=0.0,
178
+ )
179
+ img.set_visible(False)
180
+ fig.tight_layout()
181
+ fig.canvas.draw()
182
+ # manually convert figure to numpy
183
+ cbar_np = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
184
+ cbar_np = cbar_np.reshape(fig.canvas.get_width_height()[::-1] + (3,))
185
+ cbar = torch.from_numpy(cbar_np.copy()).to(grid_logits.dtype) / 255.0
186
+ cbar = rearrange(cbar, "h w c -> c h w").to(grid_logits.device)
187
+
188
+ # Add colorbar to plot
189
+ annotated_grid = torch.cat((grid_logits, cbar), dim=1)
190
+ blended_grid = torch.cat((grid_blend, cbar), dim=1)
191
+ return {
192
+ "vis_logits": 2 * annotated_grid[None, ...] - 1,
193
+ "vis_logits_blended": 2 * blended_grid[None, ...] - 1,
194
+ }
195
+
196
+ def calculate_adaptive_weight(
197
+ self, nll_loss: torch.Tensor, g_loss: torch.Tensor, last_layer: torch.Tensor
198
+ ) -> torch.Tensor:
199
+ nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
200
+ g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
201
+
202
+ d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
203
+ d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
204
+ d_weight = d_weight * self.discriminator_weight
205
+ return d_weight
206
+
207
+ def forward(
208
+ self,
209
+ inputs: torch.Tensor,
210
+ reconstructions: torch.Tensor,
211
+ *, # added because I changed the order here
212
+ regularization_log: Dict[str, torch.Tensor],
213
+ optimizer_idx: int,
214
+ global_step: int,
215
+ last_layer: torch.Tensor,
216
+ split: str = "train",
217
+ weights: Union[None, float, torch.Tensor] = None,
218
+ ) -> Tuple[torch.Tensor, dict]:
219
+ if self.scale_input_to_tgt_size:
220
+ inputs = torch.nn.functional.interpolate(
221
+ inputs, reconstructions.shape[2:], mode="bicubic", antialias=True
222
+ )
223
+
224
+ if self.dims > 2:
225
+ inputs, reconstructions = map(
226
+ lambda x: rearrange(x, "b c t h w -> (b t) c h w"),
227
+ (inputs, reconstructions),
228
+ )
229
+
230
+ rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
231
+ if self.perceptual_weight > 0:
232
+ p_loss = self.perceptual_loss(
233
+ inputs.contiguous(), reconstructions.contiguous()
234
+ )
235
+ rec_loss = rec_loss + self.perceptual_weight * p_loss
236
+
237
+ nll_loss, weighted_nll_loss = self.get_nll_loss(rec_loss, weights)
238
+
239
+ # now the GAN part
240
+ if optimizer_idx == 0:
241
+ # generator update
242
+ if global_step >= self.discriminator_iter_start or not self.training:
243
+ logits_fake = self.discriminator(reconstructions.contiguous())
244
+ g_loss = -torch.mean(logits_fake)
245
+ if self.training:
246
+ d_weight = self.calculate_adaptive_weight(
247
+ nll_loss, g_loss, last_layer=last_layer
248
+ )
249
+ else:
250
+ d_weight = torch.tensor(1.0)
251
+ else:
252
+ d_weight = torch.tensor(0.0)
253
+ g_loss = torch.tensor(0.0, requires_grad=True)
254
+
255
+ loss = weighted_nll_loss + d_weight * self.disc_factor * g_loss
256
+ log = dict()
257
+ for k in regularization_log:
258
+ if k in self.regularization_weights:
259
+ loss = loss + self.regularization_weights[k] * regularization_log[k]
260
+ if k in self.additional_log_keys:
261
+ log[f"{split}/{k}"] = regularization_log[k].detach().float().mean()
262
+
263
+ log.update(
264
+ {
265
+ f"{split}/loss/total": loss.clone().detach().mean(),
266
+ f"{split}/loss/nll": nll_loss.detach().mean(),
267
+ f"{split}/loss/rec": rec_loss.detach().mean(),
268
+ f"{split}/loss/g": g_loss.detach().mean(),
269
+ f"{split}/scalars/logvar": self.logvar.detach(),
270
+ f"{split}/scalars/d_weight": d_weight.detach(),
271
+ }
272
+ )
273
+
274
+ return loss, log
275
+ elif optimizer_idx == 1:
276
+ # second pass for discriminator update
277
+ logits_real = self.discriminator(inputs.contiguous().detach())
278
+ logits_fake = self.discriminator(reconstructions.contiguous().detach())
279
+
280
+ if global_step >= self.discriminator_iter_start or not self.training:
281
+ d_loss = self.disc_factor * self.disc_loss(logits_real, logits_fake)
282
+ else:
283
+ d_loss = torch.tensor(0.0, requires_grad=True)
284
+
285
+ log = {
286
+ f"{split}/loss/disc": d_loss.clone().detach().mean(),
287
+ f"{split}/logits/real": logits_real.detach().mean(),
288
+ f"{split}/logits/fake": logits_fake.detach().mean(),
289
+ }
290
+ return d_loss, log
291
+ else:
292
+ raise NotImplementedError(f"Unknown optimizer_idx {optimizer_idx}")
293
+
294
+ def get_nll_loss(
295
+ self,
296
+ rec_loss: torch.Tensor,
297
+ weights: Optional[Union[float, torch.Tensor]] = None,
298
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
299
+ nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
300
+ weighted_nll_loss = nll_loss
301
+ if weights is not None:
302
+ weighted_nll_loss = weights * nll_loss
303
+ weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
304
+ nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
305
+
306
+ return nll_loss, weighted_nll_loss
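A standalone numeric sketch (not part of the diff) of the two scalar rules above, the heteroscedastic NLL weighting and the adaptive balancing weight from calculate_adaptive_weight; the gradient norms are illustrative assumptions:

import torch

rec_loss = torch.tensor(0.25)
logvar = torch.tensor(0.0)
nll_loss = rec_loss / torch.exp(logvar) + logvar  # == 0.25 when logvar == 0

# adaptive weight: scale the GAN term so its gradient magnitude on the last
# decoder layer roughly matches that of the reconstruction term
nll_grad_norm, g_grad_norm, disc_weight = 2.0, 0.5, 1.0
d_weight = min(max(nll_grad_norm / (g_grad_norm + 1e-4), 0.0), 1e4) * disc_weight
# -> roughly 4.0, so the adversarial term is scaled up by about 4x here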
sgm/modules/autoencoding/losses/lpips.py ADDED
@@ -0,0 +1,73 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from ....util import default, instantiate_from_config
5
+ from ..lpips.loss.lpips import LPIPS
6
+
7
+
8
+ class LatentLPIPS(nn.Module):
9
+ def __init__(
10
+ self,
11
+ decoder_config,
12
+ perceptual_weight=1.0,
13
+ latent_weight=1.0,
14
+ scale_input_to_tgt_size=False,
15
+ scale_tgt_to_input_size=False,
16
+ perceptual_weight_on_inputs=0.0,
17
+ ):
18
+ super().__init__()
19
+ self.scale_input_to_tgt_size = scale_input_to_tgt_size
20
+ self.scale_tgt_to_input_size = scale_tgt_to_input_size
21
+ self.init_decoder(decoder_config)
22
+ self.perceptual_loss = LPIPS().eval()
23
+ self.perceptual_weight = perceptual_weight
24
+ self.latent_weight = latent_weight
25
+ self.perceptual_weight_on_inputs = perceptual_weight_on_inputs
26
+
27
+ def init_decoder(self, config):
28
+ self.decoder = instantiate_from_config(config)
29
+ if hasattr(self.decoder, "encoder"):
30
+ del self.decoder.encoder
31
+
32
+ def forward(self, latent_inputs, latent_predictions, image_inputs, split="train"):
33
+ log = dict()
34
+ loss = (latent_inputs - latent_predictions) ** 2
35
+ log[f"{split}/latent_l2_loss"] = loss.mean().detach()
36
+ image_reconstructions = None
37
+ if self.perceptual_weight > 0.0:
38
+ image_reconstructions = self.decoder.decode(latent_predictions)
39
+ image_targets = self.decoder.decode(latent_inputs)
40
+ perceptual_loss = self.perceptual_loss(
41
+ image_targets.contiguous(), image_reconstructions.contiguous()
42
+ )
43
+ loss = (
44
+ self.latent_weight * loss.mean()
45
+ + self.perceptual_weight * perceptual_loss.mean()
46
+ )
47
+ log[f"{split}/perceptual_loss"] = perceptual_loss.mean().detach()
48
+
49
+ if self.perceptual_weight_on_inputs > 0.0:
50
+ image_reconstructions = default(
51
+ image_reconstructions, self.decoder.decode(latent_predictions)
52
+ )
53
+ if self.scale_input_to_tgt_size:
54
+ image_inputs = torch.nn.functional.interpolate(
55
+ image_inputs,
56
+ image_reconstructions.shape[2:],
57
+ mode="bicubic",
58
+ antialias=True,
59
+ )
60
+ elif self.scale_tgt_to_input_size:
61
+ image_reconstructions = torch.nn.functional.interpolate(
62
+ image_reconstructions,
63
+ image_inputs.shape[2:],
64
+ mode="bicubic",
65
+ antialias=True,
66
+ )
67
+
68
+ perceptual_loss2 = self.perceptual_loss(
69
+ image_inputs.contiguous(), image_reconstructions.contiguous()
70
+ )
71
+ loss = loss + self.perceptual_weight_on_inputs * perceptual_loss2.mean()
72
+ log[f"{split}/perceptual_loss_on_inputs"] = perceptual_loss2.mean().detach()
73
+ return loss, log
sgm/modules/autoencoding/lpips/__init__.py ADDED
File without changes
sgm/modules/autoencoding/lpips/loss/.gitignore ADDED
@@ -0,0 +1 @@
1
+ vgg.pth
sgm/modules/autoencoding/lpips/loss/LICENSE ADDED
@@ -0,0 +1,23 @@
1
+ Copyright (c) 2018, Richard Zhang, Phillip Isola, Alexei A. Efros, Eli Shechtman, Oliver Wang
2
+ All rights reserved.
3
+
4
+ Redistribution and use in source and binary forms, with or without
5
+ modification, are permitted provided that the following conditions are met:
6
+
7
+ * Redistributions of source code must retain the above copyright notice, this
8
+ list of conditions and the following disclaimer.
9
+
10
+ * Redistributions in binary form must reproduce the above copyright notice,
11
+ this list of conditions and the following disclaimer in the documentation
12
+ and/or other materials provided with the distribution.
13
+
14
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
15
+ AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
16
+ IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
17
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
18
+ FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
19
+ DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
20
+ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
21
+ CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
22
+ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
23
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
sgm/modules/autoencoding/lpips/loss/__init__.py ADDED
File without changes
sgm/modules/autoencoding/lpips/loss/lpips.py ADDED
@@ -0,0 +1,147 @@
1
+ """Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models"""
2
+
3
+ from collections import namedtuple
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ from torchvision import models
8
+
9
+ from ..util import get_ckpt_path
10
+
11
+
12
+ class LPIPS(nn.Module):
13
+ # Learned perceptual metric
14
+ def __init__(self, use_dropout=True):
15
+ super().__init__()
16
+ self.scaling_layer = ScalingLayer()
17
+ self.chns = [64, 128, 256, 512, 512] # vgg16 features
18
+ self.net = vgg16(pretrained=True, requires_grad=False)
19
+ self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
20
+         self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
+         self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
+         self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
+         self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
+         self.load_from_pretrained()
+         for param in self.parameters():
+             param.requires_grad = False
+
+     def load_from_pretrained(self, name="vgg_lpips"):
+         ckpt = get_ckpt_path(name, "sgm/modules/autoencoding/lpips/loss")
+         self.load_state_dict(
+             torch.load(ckpt, map_location=torch.device("cpu")), strict=False
+         )
+         print("loaded pretrained LPIPS loss from {}".format(ckpt))
+
+     @classmethod
+     def from_pretrained(cls, name="vgg_lpips"):
+         if name != "vgg_lpips":
+             raise NotImplementedError
+         model = cls()
+         ckpt = get_ckpt_path(name)
+         model.load_state_dict(
+             torch.load(ckpt, map_location=torch.device("cpu")), strict=False
+         )
+         return model
+
+     def forward(self, input, target):
+         in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target))
+         outs0, outs1 = self.net(in0_input), self.net(in1_input)
+         feats0, feats1, diffs = {}, {}, {}
+         lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
+         for kk in range(len(self.chns)):
+             feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(
+                 outs1[kk]
+             )
+             diffs[kk] = (feats0[kk] - feats1[kk]) ** 2
+
+         res = [
+             spatial_average(lins[kk].model(diffs[kk]), keepdim=True)
+             for kk in range(len(self.chns))
+         ]
+         val = res[0]
+         for l in range(1, len(self.chns)):
+             val += res[l]
+         return val
+
+
+ class ScalingLayer(nn.Module):
+     def __init__(self):
+         super(ScalingLayer, self).__init__()
+         self.register_buffer(
+             "shift", torch.Tensor([-0.030, -0.088, -0.188])[None, :, None, None]
+         )
+         self.register_buffer(
+             "scale", torch.Tensor([0.458, 0.448, 0.450])[None, :, None, None]
+         )
+
+     def forward(self, inp):
+         return (inp - self.shift) / self.scale
+
+
+ class NetLinLayer(nn.Module):
+     """A single linear layer which does a 1x1 conv"""
+
+     def __init__(self, chn_in, chn_out=1, use_dropout=False):
+         super(NetLinLayer, self).__init__()
+         layers = (
+             [
+                 nn.Dropout(),
+             ]
+             if (use_dropout)
+             else []
+         )
+         layers += [
+             nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False),
+         ]
+         self.model = nn.Sequential(*layers)
+
+
+ class vgg16(torch.nn.Module):
+     def __init__(self, requires_grad=False, pretrained=True):
+         super(vgg16, self).__init__()
+         vgg_pretrained_features = models.vgg16(pretrained=pretrained).features
+         self.slice1 = torch.nn.Sequential()
+         self.slice2 = torch.nn.Sequential()
+         self.slice3 = torch.nn.Sequential()
+         self.slice4 = torch.nn.Sequential()
+         self.slice5 = torch.nn.Sequential()
+         self.N_slices = 5
+         for x in range(4):
+             self.slice1.add_module(str(x), vgg_pretrained_features[x])
+         for x in range(4, 9):
+             self.slice2.add_module(str(x), vgg_pretrained_features[x])
+         for x in range(9, 16):
+             self.slice3.add_module(str(x), vgg_pretrained_features[x])
+         for x in range(16, 23):
+             self.slice4.add_module(str(x), vgg_pretrained_features[x])
+         for x in range(23, 30):
+             self.slice5.add_module(str(x), vgg_pretrained_features[x])
+         if not requires_grad:
+             for param in self.parameters():
+                 param.requires_grad = False
+
+     def forward(self, X):
+         h = self.slice1(X)
+         h_relu1_2 = h
+         h = self.slice2(h)
+         h_relu2_2 = h
+         h = self.slice3(h)
+         h_relu3_3 = h
+         h = self.slice4(h)
+         h_relu4_3 = h
+         h = self.slice5(h)
+         h_relu5_3 = h
+         vgg_outputs = namedtuple(
+             "VggOutputs", ["relu1_2", "relu2_2", "relu3_3", "relu4_3", "relu5_3"]
+         )
+         out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
+         return out
+
+
+ def normalize_tensor(x, eps=1e-10):
+     norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True))
+     return x / (norm_factor + eps)
+
+
+ def spatial_average(x, keepdim=True):
+     return x.mean([2, 3], keepdim=keepdim)
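A minimal usage sketch for the LPIPS metric defined above. The import path mirrors this repo's layout and the [-1, 1] input range is an assumption based on the ScalingLayer constants, not something stated in the diff:

import torch
from sgm.modules.autoencoding.lpips.loss.lpips import LPIPS  # assumed module path

lpips = LPIPS().eval()                       # loads the VGG-16 backbone and pretrained linear heads
x = torch.rand(2, 3, 256, 256) * 2.0 - 1.0   # "reconstruction" batch, assumed scaled to [-1, 1]
y = torch.rand(2, 3, 256, 256) * 2.0 - 1.0   # "target" batch, same scaling
with torch.no_grad():
    dist = lpips(x, y)                       # shape (2, 1, 1, 1): one perceptual distance per sample
print(dist.flatten())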
sgm/modules/autoencoding/lpips/model/LICENSE ADDED
@@ -0,0 +1,58 @@
+ Copyright (c) 2017, Jun-Yan Zhu and Taesung Park
+ All rights reserved.
+
+ Redistribution and use in source and binary forms, with or without
+ modification, are permitted provided that the following conditions are met:
+
+ * Redistributions of source code must retain the above copyright notice, this
+   list of conditions and the following disclaimer.
+
+ * Redistributions in binary form must reproduce the above copyright notice,
+   this list of conditions and the following disclaimer in the documentation
+   and/or other materials provided with the distribution.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+
+ --------------------------- LICENSE FOR pix2pix --------------------------------
+ BSD License
+
+ For pix2pix software
+ Copyright (c) 2016, Phillip Isola and Jun-Yan Zhu
+ All rights reserved.
+
+ Redistribution and use in source and binary forms, with or without
+ modification, are permitted provided that the following conditions are met:
+
+ * Redistributions of source code must retain the above copyright notice, this
+   list of conditions and the following disclaimer.
+
+ * Redistributions in binary form must reproduce the above copyright notice,
+   this list of conditions and the following disclaimer in the documentation
+   and/or other materials provided with the distribution.
+
+ ----------------------------- LICENSE FOR DCGAN --------------------------------
+ BSD License
+
+ For dcgan.torch software
+
+ Copyright (c) 2015, Facebook, Inc. All rights reserved.
+
+ Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
+
+ Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
+
+ Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
+
+ Neither the name Facebook nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
sgm/modules/autoencoding/lpips/model/__init__.py ADDED
File without changes
sgm/modules/autoencoding/lpips/model/model.py ADDED
@@ -0,0 +1,88 @@
+ import functools
+
+ import torch.nn as nn
+
+ from ..util import ActNorm
+
+
+ def weights_init(m):
+     classname = m.__class__.__name__
+     if classname.find("Conv") != -1:
+         nn.init.normal_(m.weight.data, 0.0, 0.02)
+     elif classname.find("BatchNorm") != -1:
+         nn.init.normal_(m.weight.data, 1.0, 0.02)
+         nn.init.constant_(m.bias.data, 0)
+
+
+ class NLayerDiscriminator(nn.Module):
+     """Defines a PatchGAN discriminator as in Pix2Pix
+     --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
+     """
+
+     def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
+         """Construct a PatchGAN discriminator
+         Parameters:
+             input_nc (int)  -- the number of channels in input images
+             ndf (int)       -- the number of filters in the last conv layer
+             n_layers (int)  -- the number of conv layers in the discriminator
+             norm_layer      -- normalization layer
+         """
+         super(NLayerDiscriminator, self).__init__()
+         if not use_actnorm:
+             norm_layer = nn.BatchNorm2d
+         else:
+             norm_layer = ActNorm
+         if (
+             type(norm_layer) == functools.partial
+         ):  # no need to use bias as BatchNorm2d has affine parameters
+             use_bias = norm_layer.func != nn.BatchNorm2d
+         else:
+             use_bias = norm_layer != nn.BatchNorm2d
+
+         kw = 4
+         padw = 1
+         sequence = [
+             nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
+             nn.LeakyReLU(0.2, True),
+         ]
+         nf_mult = 1
+         nf_mult_prev = 1
+         for n in range(1, n_layers):  # gradually increase the number of filters
+             nf_mult_prev = nf_mult
+             nf_mult = min(2**n, 8)
+             sequence += [
+                 nn.Conv2d(
+                     ndf * nf_mult_prev,
+                     ndf * nf_mult,
+                     kernel_size=kw,
+                     stride=2,
+                     padding=padw,
+                     bias=use_bias,
+                 ),
+                 norm_layer(ndf * nf_mult),
+                 nn.LeakyReLU(0.2, True),
+             ]
+
+         nf_mult_prev = nf_mult
+         nf_mult = min(2**n_layers, 8)
+         sequence += [
+             nn.Conv2d(
+                 ndf * nf_mult_prev,
+                 ndf * nf_mult,
+                 kernel_size=kw,
+                 stride=1,
+                 padding=padw,
+                 bias=use_bias,
+             ),
+             norm_layer(ndf * nf_mult),
+             nn.LeakyReLU(0.2, True),
+         ]
+
+         sequence += [
+             nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)
+         ]  # output 1 channel prediction map
+         self.main = nn.Sequential(*sequence)
+
+     def forward(self, input):
+         """Standard forward."""
+         return self.main(input)
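As a rough sketch (import path and input size are assumptions, not part of the diff), the PatchGAN discriminator above is typically built, initialized with weights_init, and applied to an image batch like so:

import torch
from sgm.modules.autoencoding.lpips.model.model import NLayerDiscriminator, weights_init  # assumed path

disc = NLayerDiscriminator(input_nc=3, ndf=64, n_layers=3, use_actnorm=False)
disc.apply(weights_init)                    # normal(0.0, 0.02) init for the conv weights
logits = disc(torch.randn(4, 3, 256, 256))  # random images only for illustration
print(logits.shape)                         # (4, 1, 30, 30): a patch-wise real/fake prediction map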
sgm/modules/autoencoding/lpips/util.py ADDED
@@ -0,0 +1,128 @@
+ import hashlib
+ import os
+
+ import requests
+ import torch
+ import torch.nn as nn
+ from tqdm import tqdm
+
+ URL_MAP = {"vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"}
+
+ CKPT_MAP = {"vgg_lpips": "vgg.pth"}
+
+ MD5_MAP = {"vgg_lpips": "d507d7349b931f0638a25a48a722f98a"}
+
+
+ def download(url, local_path, chunk_size=1024):
+     os.makedirs(os.path.split(local_path)[0], exist_ok=True)
+     with requests.get(url, stream=True) as r:
+         total_size = int(r.headers.get("content-length", 0))
+         with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
+             with open(local_path, "wb") as f:
+                 for data in r.iter_content(chunk_size=chunk_size):
+                     if data:
+                         f.write(data)
+                         pbar.update(chunk_size)
+
+
+ def md5_hash(path):
+     with open(path, "rb") as f:
+         content = f.read()
+     return hashlib.md5(content).hexdigest()
+
+
+ def get_ckpt_path(name, root, check=False):
+     assert name in URL_MAP
+     path = os.path.join(root, CKPT_MAP[name])
+     if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]):
+         print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path))
+         download(URL_MAP[name], path)
+         md5 = md5_hash(path)
+         assert md5 == MD5_MAP[name], md5
+     return path
+
+
+ class ActNorm(nn.Module):
+     def __init__(
+         self, num_features, logdet=False, affine=True, allow_reverse_init=False
+     ):
+         assert affine
+         super().__init__()
+         self.logdet = logdet
+         self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1))
+         self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1))
+         self.allow_reverse_init = allow_reverse_init
+
+         self.register_buffer("initialized", torch.tensor(0, dtype=torch.uint8))
+
+     def initialize(self, input):
+         with torch.no_grad():
+             flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1)
+             mean = (
+                 flatten.mean(1)
+                 .unsqueeze(1)
+                 .unsqueeze(2)
+                 .unsqueeze(3)
+                 .permute(1, 0, 2, 3)
+             )
+             std = (
+                 flatten.std(1)
+                 .unsqueeze(1)
+                 .unsqueeze(2)
+                 .unsqueeze(3)
+                 .permute(1, 0, 2, 3)
+             )
+
+             self.loc.data.copy_(-mean)
+             self.scale.data.copy_(1 / (std + 1e-6))
+
+     def forward(self, input, reverse=False):
+         if reverse:
+             return self.reverse(input)
+         if len(input.shape) == 2:
+             input = input[:, :, None, None]
+             squeeze = True
+         else:
+             squeeze = False
+
+         _, _, height, width = input.shape
+
+         if self.training and self.initialized.item() == 0:
+             self.initialize(input)
+             self.initialized.fill_(1)
+
+         h = self.scale * (input + self.loc)
+
+         if squeeze:
+             h = h.squeeze(-1).squeeze(-1)
+
+         if self.logdet:
+             log_abs = torch.log(torch.abs(self.scale))
+             logdet = height * width * torch.sum(log_abs)
+             logdet = logdet * torch.ones(input.shape[0]).to(input)
+             return h, logdet
+
+         return h
+
+     def reverse(self, output):
+         if self.training and self.initialized.item() == 0:
+             if not self.allow_reverse_init:
+                 raise RuntimeError(
+                     "Initializing ActNorm in reverse direction is "
+                     "disabled by default. Use allow_reverse_init=True to enable."
+                 )
+             else:
+                 self.initialize(output)
+                 self.initialized.fill_(1)
+
+         if len(output.shape) == 2:
+             output = output[:, :, None, None]
+             squeeze = True
+         else:
+             squeeze = False
+
+         h = output / self.scale - self.loc
+
+         if squeeze:
+             h = h.squeeze(-1).squeeze(-1)
+         return h
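A small sketch of how ActNorm behaves (module path, shapes, and the tolerance are assumptions): the first forward pass in training mode does data-dependent initialization, and the reverse pass inverts the affine map.

import torch
from sgm.modules.autoencoding.lpips.util import ActNorm  # assumed module path

norm = ActNorm(num_features=8, logdet=True)
norm.train()
x = torch.randn(4, 8, 16, 16)
h, logdet = norm(x)                          # first training call initializes loc/scale from this batch
x_rec = norm(h, reverse=True)                # inverse transform: h / scale - loc
print(torch.allclose(x, x_rec, atol=1e-4))   # expected True, up to floating-point error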
sgm/modules/autoencoding/lpips/vqperceptual.py ADDED
@@ -0,0 +1,17 @@
+ import torch
+ import torch.nn.functional as F
+
+
+ def hinge_d_loss(logits_real, logits_fake):
+     loss_real = torch.mean(F.relu(1.0 - logits_real))
+     loss_fake = torch.mean(F.relu(1.0 + logits_fake))
+     d_loss = 0.5 * (loss_real + loss_fake)
+     return d_loss
+
+
+ def vanilla_d_loss(logits_real, logits_fake):
+     d_loss = 0.5 * (
+         torch.mean(torch.nn.functional.softplus(-logits_real))
+         + torch.mean(torch.nn.functional.softplus(logits_fake))
+     )
+     return d_loss
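For intuition, a toy comparison of the two discriminator objectives above (the logits are made-up values and the import path is an assumption):

import torch
from sgm.modules.autoencoding.lpips.vqperceptual import hinge_d_loss, vanilla_d_loss  # assumed path

logits_real = torch.tensor([2.0, 0.5, -0.2])
logits_fake = torch.tensor([-1.5, 0.1, 0.8])
print(hinge_d_loss(logits_real, logits_fake))    # hinge: penalizes real logits below 1 and fake logits above -1
print(vanilla_d_loss(logits_real, logits_fake))  # softplus-based (vanilla GAN) discriminator loss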
sgm/modules/autoencoding/regularizers/__init__.py ADDED
@@ -0,0 +1,31 @@
+ from abc import abstractmethod
+ from typing import Any, Tuple
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from ....modules.distributions.distributions import \
+     DiagonalGaussianDistribution
+ from .base import AbstractRegularizer
+
+
+ class DiagonalGaussianRegularizer(AbstractRegularizer):
+     def __init__(self, sample: bool = True):
+         super().__init__()
+         self.sample = sample
+
+     def get_trainable_parameters(self) -> Any:
+         yield from ()
+
+     def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
+         log = dict()
+         posterior = DiagonalGaussianDistribution(z)
+         if self.sample:
+             z = posterior.sample()
+         else:
+             z = posterior.mode()
+         kl_loss = posterior.kl()
+         kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
+         log["kl_loss"] = kl_loss
+         return z, log
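A sketch of the regularizer's contract, assuming DiagonalGaussianDistribution splits the channel dimension into mean and log-variance halves as in the taming-transformers implementation (the import path and shapes are also assumptions):

import torch
from sgm.modules.autoencoding.regularizers import DiagonalGaussianRegularizer  # assumed path

reg = DiagonalGaussianRegularizer(sample=True)
z_params = torch.randn(2, 8, 32, 32)  # 2 * latent channels: [mean | logvar] along dim 1
z, log = reg(z_params)
print(z.shape)          # expected (2, 4, 32, 32) after sampling from the posterior
print(log["kl_loss"])   # scalar KL term, summed per sample and averaged over the batch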
sgm/modules/autoencoding/regularizers/base.py ADDED
@@ -0,0 +1,40 @@
+ from abc import abstractmethod
+ from typing import Any, Tuple
+
+ import torch
+ import torch.nn.functional as F
+ from torch import nn
+
+
+ class AbstractRegularizer(nn.Module):
+     def __init__(self):
+         super().__init__()
+
+     def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
+         raise NotImplementedError()
+
+     @abstractmethod
+     def get_trainable_parameters(self) -> Any:
+         raise NotImplementedError()
+
+
+ class IdentityRegularizer(AbstractRegularizer):
+     def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
+         return z, dict()
+
+     def get_trainable_parameters(self) -> Any:
+         yield from ()
+
+
+ def measure_perplexity(
+     predicted_indices: torch.Tensor, num_centroids: int
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+     # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py
+     # eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally
+     encodings = (
+         F.one_hot(predicted_indices, num_centroids).float().reshape(-1, num_centroids)
+     )
+     avg_probs = encodings.mean(0)
+     perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp()
+     cluster_use = torch.sum(avg_probs > 0)
+     return perplexity, cluster_use
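A quick sanity check of measure_perplexity (illustrative only; the import path is an assumption): when every centroid is used equally often, the perplexity equals the number of centroids and cluster_use counts all of them.

import torch
from sgm.modules.autoencoding.regularizers.base import measure_perplexity  # assumed path

codes = torch.arange(8).repeat(16)  # 128 indices, each of 8 centroids used 16 times
perplexity, cluster_use = measure_perplexity(codes, num_centroids=8)
print(perplexity)    # approximately 8.0
print(cluster_use)   # tensor(8)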