MaxMilan1
commited on
Commit
·
09339b5
1
Parent(s):
63f29cf
possible working changes for V3D?
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- app.py +60 -1
- requirements.txt +44 -7
- scripts/__init__.py +0 -0
- scripts/pub/V3D_512.py +317 -0
- scripts/pub/configs/V3D_512.yaml +161 -0
- scripts/tests/attention.py +319 -0
- scripts/util/__init__.py +0 -0
- scripts/util/detection/__init__.py +0 -0
- scripts/util/detection/nsfw_and_watermark_dectection.py +110 -0
- scripts/util/detection/p_head_v1.npz +3 -0
- scripts/util/detection/w_head_v1.npz +3 -0
- sgm/__init__.py +4 -0
- sgm/data/__init__.py +1 -0
- sgm/data/cam_utils.py +1253 -0
- sgm/data/cifar10.py +67 -0
- sgm/data/co3d.py +1367 -0
- sgm/data/colmap.py +605 -0
- sgm/data/dataset.py +80 -0
- sgm/data/joint3d.py +10 -0
- sgm/data/json_index_dataset.py +1080 -0
- sgm/data/latent_objaverse.py +52 -0
- sgm/data/mnist.py +85 -0
- sgm/data/mvimagenet.py +408 -0
- sgm/data/objaverse.py +882 -0
- sgm/inference/api.py +385 -0
- sgm/inference/helpers.py +305 -0
- sgm/lr_scheduler.py +135 -0
- sgm/models/__init__.py +2 -0
- sgm/models/autoencoder.py +615 -0
- sgm/models/diffusion.py +358 -0
- sgm/models/video3d_diffusion.py +524 -0
- sgm/models/video_diffusion.py +503 -0
- sgm/modules/__init__.py +6 -0
- sgm/modules/attention.py +764 -0
- sgm/modules/autoencoding/__init__.py +0 -0
- sgm/modules/autoencoding/losses/__init__.py +7 -0
- sgm/modules/autoencoding/losses/discriminator_loss.py +306 -0
- sgm/modules/autoencoding/losses/lpips.py +73 -0
- sgm/modules/autoencoding/lpips/__init__.py +0 -0
- sgm/modules/autoencoding/lpips/loss/.gitignore +1 -0
- sgm/modules/autoencoding/lpips/loss/LICENSE +23 -0
- sgm/modules/autoencoding/lpips/loss/__init__.py +0 -0
- sgm/modules/autoencoding/lpips/loss/lpips.py +147 -0
- sgm/modules/autoencoding/lpips/model/LICENSE +58 -0
- sgm/modules/autoencoding/lpips/model/__init__.py +0 -0
- sgm/modules/autoencoding/lpips/model/model.py +88 -0
- sgm/modules/autoencoding/lpips/util.py +128 -0
- sgm/modules/autoencoding/lpips/vqperceptual.py +17 -0
- sgm/modules/autoencoding/regularizers/__init__.py +31 -0
- sgm/modules/autoencoding/regularizers/base.py +40 -0
app.py
CHANGED
@@ -1,5 +1,9 @@
|
|
1 |
import gradio as gr
|
2 |
from util.text_img import generate_image
|
|
|
|
|
|
|
|
|
3 |
|
4 |
_TITLE = "Shoe Generator"
|
5 |
with gr.Blocks(_TITLE) as ShoeGen:
|
@@ -18,6 +22,61 @@ with gr.Blocks(_TITLE) as ShoeGen:
|
|
18 |
button_gen.click(generate_image, inputs=[prompt], outputs=[image, image_nobg])
|
19 |
|
20 |
with gr.Tab("Image to Video Generator (V3D)"):
|
21 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
22 |
|
23 |
ShoeGen.launch()
|
|
|
1 |
import gradio as gr
|
2 |
from util.text_img import generate_image
|
3 |
+
from util.v3d import generate_v3d, prep
|
4 |
+
|
5 |
+
# Prepare the V3D model
|
6 |
+
model, clip_model, ae_model, device, num_frames, num_steps, rembg_session, output_folder = prep()
|
7 |
|
8 |
_TITLE = "Shoe Generator"
|
9 |
with gr.Blocks(_TITLE) as ShoeGen:
|
|
|
22 |
button_gen.click(generate_image, inputs=[prompt], outputs=[image, image_nobg])
|
23 |
|
24 |
with gr.Tab("Image to Video Generator (V3D)"):
|
25 |
+
with gr.Row(equal_height=True):
|
26 |
+
with gr.Column():
|
27 |
+
input_image = gr.Image(value=None, label="Input Image")
|
28 |
+
|
29 |
+
border_ratio_slider = gr.Slider(
|
30 |
+
value=0.3,
|
31 |
+
label="Border Ratio",
|
32 |
+
minimum=0.05,
|
33 |
+
maximum=0.5,
|
34 |
+
step=0.05,
|
35 |
+
)
|
36 |
+
decoding_t_slider = gr.Slider(
|
37 |
+
value=1,
|
38 |
+
label="Number of Decoding frames",
|
39 |
+
minimum=1,
|
40 |
+
maximum=num_frames,
|
41 |
+
step=1,
|
42 |
+
)
|
43 |
+
min_guidance_slider = gr.Slider(
|
44 |
+
value=3.5,
|
45 |
+
label="Min CFG Value",
|
46 |
+
minimum=0.05,
|
47 |
+
maximum=0.5,
|
48 |
+
step=0.05,
|
49 |
+
)
|
50 |
+
max_guidance_slider = gr.Slider(
|
51 |
+
value=3.5,
|
52 |
+
label="Max CFG Value",
|
53 |
+
minimum=0.05,
|
54 |
+
maximum=0.5,
|
55 |
+
step=0.05,
|
56 |
+
)
|
57 |
+
run_button = gr.Button(value="Run V3D")
|
58 |
+
|
59 |
+
with gr.Column():
|
60 |
+
output_video = gr.Video(value=None, label="Output Orbit Video")
|
61 |
+
|
62 |
+
run_button.click(generate_v3d,
|
63 |
+
inputs=[
|
64 |
+
input_image,
|
65 |
+
model,
|
66 |
+
clip_model,
|
67 |
+
ae_model,
|
68 |
+
num_frames,
|
69 |
+
num_steps,
|
70 |
+
int(decoding_t_slider),
|
71 |
+
border_ratio_slider,
|
72 |
+
False,
|
73 |
+
rembg_session,
|
74 |
+
output_folder,
|
75 |
+
min_guidance_slider,
|
76 |
+
max_guidance_slider,
|
77 |
+
device,
|
78 |
+
],
|
79 |
+
outputs=[output_video],
|
80 |
+
)
|
81 |
|
82 |
ShoeGen.launch()
|
requirements.txt
CHANGED
@@ -1,12 +1,49 @@
|
|
1 |
-
torch
|
2 |
gradio
|
3 |
diffusers==0.26.3
|
4 |
-
transformers==4.38.1
|
5 |
accelerate==0.27.2
|
6 |
-
xformers
|
7 |
rembg
|
8 |
-
Pillow
|
9 |
Python-IO
|
10 |
-
|
11 |
-
|
12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
gradio
|
2 |
diffusers==0.26.3
|
|
|
3 |
accelerate==0.27.2
|
|
|
4 |
rembg
|
|
|
5 |
Python-IO
|
6 |
+
huggingface-hub
|
7 |
+
black==23.7.0
|
8 |
+
chardet==5.1.0
|
9 |
+
clip @ git+https://github.com/openai/CLIP.git
|
10 |
+
einops>=0.6.1
|
11 |
+
fairscale>=0.4.13
|
12 |
+
fire>=0.5.0
|
13 |
+
fsspec>=2023.6.0
|
14 |
+
invisible-watermark>=0.2.0
|
15 |
+
kornia==0.6.9
|
16 |
+
matplotlib>=3.7.2
|
17 |
+
natsort>=8.4.0
|
18 |
+
ninja>=1.11.1
|
19 |
+
numpy>=1.24.4
|
20 |
+
omegaconf>=2.3.0
|
21 |
+
open-clip-torch>=2.20.0
|
22 |
+
opencv-python==4.6.0.66
|
23 |
+
pandas>=2.0.3
|
24 |
+
pillow>=9.5.0
|
25 |
+
pudb>=2022.1.3
|
26 |
+
pytorch-lightning==2.0.1
|
27 |
+
pyyaml>=6.0.1
|
28 |
+
scipy>=1.10.1
|
29 |
+
streamlit>=0.73.1
|
30 |
+
tensorboardx==2.6
|
31 |
+
timm>=0.9.2
|
32 |
+
tokenizers==0.12.1
|
33 |
+
torch>=2.0.1
|
34 |
+
torchaudio>=2.0.2
|
35 |
+
torchdata==0.6.1
|
36 |
+
torchmetrics>=1.0.1
|
37 |
+
torchvision>=0.15.2
|
38 |
+
tqdm>=4.65.0
|
39 |
+
transformers==4.19.1
|
40 |
+
triton==2.0.0
|
41 |
+
urllib3<1.27,>=1.25.4
|
42 |
+
wandb>=0.15.6
|
43 |
+
webdataset>=0.2.33
|
44 |
+
wheel>=0.41.0
|
45 |
+
xformers>=0.0.20
|
46 |
+
streamlit-keyup==0.2.0
|
47 |
+
mediapy
|
48 |
+
tyro
|
49 |
+
wget
|
scripts/__init__.py
ADDED
File without changes
|
scripts/pub/V3D_512.py
ADDED
@@ -0,0 +1,317 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import os
|
3 |
+
from glob import glob
|
4 |
+
from pathlib import Path
|
5 |
+
from typing import Optional
|
6 |
+
|
7 |
+
import cv2
|
8 |
+
import numpy as np
|
9 |
+
import torch
|
10 |
+
from einops import rearrange, repeat
|
11 |
+
from fire import Fire
|
12 |
+
import tyro
|
13 |
+
from omegaconf import OmegaConf
|
14 |
+
from PIL import Image
|
15 |
+
from torchvision.transforms import ToTensor
|
16 |
+
from mediapy import write_video
|
17 |
+
import rembg
|
18 |
+
from kiui.op import recenter
|
19 |
+
from safetensors.torch import load_file as load_safetensors
|
20 |
+
from typing import Any
|
21 |
+
|
22 |
+
from scripts.util.detection.nsfw_and_watermark_dectection import DeepFloydDataFiltering
|
23 |
+
from sgm.inference.helpers import embed_watermark
|
24 |
+
from sgm.util import default, instantiate_from_config
|
25 |
+
|
26 |
+
|
27 |
+
def get_unique_embedder_keys_from_conditioner(conditioner):
|
28 |
+
return list(set([x.input_key for x in conditioner.embedders]))
|
29 |
+
|
30 |
+
|
31 |
+
def get_batch(keys, value_dict, N, T, device):
|
32 |
+
batch = {}
|
33 |
+
batch_uc = {}
|
34 |
+
|
35 |
+
for key in keys:
|
36 |
+
if key == "fps_id":
|
37 |
+
batch[key] = (
|
38 |
+
torch.tensor([value_dict["fps_id"]])
|
39 |
+
.to(device)
|
40 |
+
.repeat(int(math.prod(N)))
|
41 |
+
)
|
42 |
+
elif key == "motion_bucket_id":
|
43 |
+
batch[key] = (
|
44 |
+
torch.tensor([value_dict["motion_bucket_id"]])
|
45 |
+
.to(device)
|
46 |
+
.repeat(int(math.prod(N)))
|
47 |
+
)
|
48 |
+
elif key == "cond_aug":
|
49 |
+
batch[key] = repeat(
|
50 |
+
torch.tensor([value_dict["cond_aug"]]).to(device),
|
51 |
+
"1 -> b",
|
52 |
+
b=math.prod(N),
|
53 |
+
)
|
54 |
+
elif key == "cond_frames":
|
55 |
+
batch[key] = repeat(value_dict["cond_frames"], "1 ... -> b ...", b=N[0])
|
56 |
+
elif key == "cond_frames_without_noise":
|
57 |
+
batch[key] = repeat(
|
58 |
+
value_dict["cond_frames_without_noise"], "1 ... -> b ...", b=N[0]
|
59 |
+
)
|
60 |
+
else:
|
61 |
+
batch[key] = value_dict[key]
|
62 |
+
|
63 |
+
if T is not None:
|
64 |
+
batch["num_video_frames"] = T
|
65 |
+
|
66 |
+
for key in batch.keys():
|
67 |
+
if key not in batch_uc and isinstance(batch[key], torch.Tensor):
|
68 |
+
batch_uc[key] = torch.clone(batch[key])
|
69 |
+
return batch, batch_uc
|
70 |
+
|
71 |
+
|
72 |
+
def load_model(
|
73 |
+
config: str,
|
74 |
+
device: str,
|
75 |
+
num_frames: int,
|
76 |
+
num_steps: int,
|
77 |
+
ckpt_path: Optional[str] = None,
|
78 |
+
min_cfg: Optional[float] = None,
|
79 |
+
max_cfg: Optional[float] = None,
|
80 |
+
sigma_max: Optional[float] = None,
|
81 |
+
):
|
82 |
+
config = OmegaConf.load(config)
|
83 |
+
|
84 |
+
config.model.params.sampler_config.params.num_steps = num_steps
|
85 |
+
config.model.params.sampler_config.params.guider_config.params.num_frames = (
|
86 |
+
num_frames
|
87 |
+
)
|
88 |
+
if max_cfg is not None:
|
89 |
+
config.model.params.sampler_config.params.guider_config.params.max_scale = (
|
90 |
+
max_cfg
|
91 |
+
)
|
92 |
+
if min_cfg is not None:
|
93 |
+
config.model.params.sampler_config.params.guider_config.params.min_scale = (
|
94 |
+
min_cfg
|
95 |
+
)
|
96 |
+
if sigma_max is not None:
|
97 |
+
print("Overriding sigma_max to ", sigma_max)
|
98 |
+
config.model.params.sampler_config.params.discretization_config.params.sigma_max = (
|
99 |
+
sigma_max
|
100 |
+
)
|
101 |
+
|
102 |
+
config.model.params.from_scratch = False
|
103 |
+
|
104 |
+
if ckpt_path is not None:
|
105 |
+
config.model.params.ckpt_path = str(ckpt_path)
|
106 |
+
if device == "cuda":
|
107 |
+
with torch.device(device):
|
108 |
+
model = instantiate_from_config(config.model).to(device).eval()
|
109 |
+
else:
|
110 |
+
model = instantiate_from_config(config.model).to(device).eval()
|
111 |
+
|
112 |
+
return model, None
|
113 |
+
|
114 |
+
|
115 |
+
def sample_one(
|
116 |
+
input_path: str = "assets/test_image.png", # Can either be image file or folder with image files
|
117 |
+
checkpoint_path: Optional[str] = None,
|
118 |
+
num_frames: Optional[int] = None,
|
119 |
+
num_steps: Optional[int] = None,
|
120 |
+
fps_id: int = 1,
|
121 |
+
motion_bucket_id: int = 300,
|
122 |
+
cond_aug: float = 0.02,
|
123 |
+
seed: int = 23,
|
124 |
+
decoding_t: int = 24, # Number of frames decoded at a time! This eats most VRAM. Reduce if necessary.
|
125 |
+
device: str = "cuda",
|
126 |
+
output_folder: Optional[str] = None,
|
127 |
+
noise: torch.Tensor = None,
|
128 |
+
save: bool = False,
|
129 |
+
cached_model: Any = None,
|
130 |
+
border_ratio: float = 0.3,
|
131 |
+
min_guidance_scale: float = 3.5,
|
132 |
+
max_guidance_scale: float = 3.5,
|
133 |
+
sigma_max: float = None,
|
134 |
+
ignore_alpha: bool = False,
|
135 |
+
):
|
136 |
+
model_config = "scripts/pub/configs/V3D_512.yaml"
|
137 |
+
num_frames = OmegaConf.load(
|
138 |
+
model_config
|
139 |
+
).model.params.sampler_config.params.guider_config.params.num_frames
|
140 |
+
print("Detected num_frames:", num_frames)
|
141 |
+
num_steps = default(num_steps, 25)
|
142 |
+
output_folder = default(output_folder, f"outputs/V3D_512")
|
143 |
+
decoding_t = min(decoding_t, num_frames)
|
144 |
+
|
145 |
+
sd = load_safetensors("./ckpts/svd_xt.safetensors")
|
146 |
+
clip_model_config = OmegaConf.load("configs/embedder/clip_image.yaml")
|
147 |
+
clip_model = instantiate_from_config(clip_model_config).eval()
|
148 |
+
clip_sd = dict()
|
149 |
+
for k, v in sd.items():
|
150 |
+
if "conditioner.embedders.0" in k:
|
151 |
+
clip_sd[k.replace("conditioner.embedders.0.", "")] = v
|
152 |
+
clip_model.load_state_dict(clip_sd)
|
153 |
+
clip_model = clip_model.to(device)
|
154 |
+
|
155 |
+
ae_model_config = OmegaConf.load("configs/ae/video.yaml")
|
156 |
+
ae_model = instantiate_from_config(ae_model_config).eval()
|
157 |
+
encoder_sd = dict()
|
158 |
+
for k, v in sd.items():
|
159 |
+
if "first_stage_model" in k:
|
160 |
+
encoder_sd[k.replace("first_stage_model.", "")] = v
|
161 |
+
ae_model.load_state_dict(encoder_sd)
|
162 |
+
ae_model = ae_model.to(device)
|
163 |
+
|
164 |
+
if cached_model is None:
|
165 |
+
model, filter = load_model(
|
166 |
+
model_config,
|
167 |
+
device,
|
168 |
+
num_frames,
|
169 |
+
num_steps,
|
170 |
+
ckpt_path=checkpoint_path,
|
171 |
+
min_cfg=min_guidance_scale,
|
172 |
+
max_cfg=max_guidance_scale,
|
173 |
+
sigma_max=sigma_max,
|
174 |
+
)
|
175 |
+
else:
|
176 |
+
model = cached_model
|
177 |
+
torch.manual_seed(seed)
|
178 |
+
|
179 |
+
need_return = True
|
180 |
+
path = Path(input_path)
|
181 |
+
if path.is_file():
|
182 |
+
if any([input_path.endswith(x) for x in ["jpg", "jpeg", "png"]]):
|
183 |
+
all_img_paths = [input_path]
|
184 |
+
else:
|
185 |
+
raise ValueError("Path is not valid image file.")
|
186 |
+
elif path.is_dir():
|
187 |
+
all_img_paths = sorted(
|
188 |
+
[
|
189 |
+
f
|
190 |
+
for f in path.iterdir()
|
191 |
+
if f.is_file() and f.suffix.lower() in [".jpg", ".jpeg", ".png"]
|
192 |
+
]
|
193 |
+
)
|
194 |
+
need_return = False
|
195 |
+
if len(all_img_paths) == 0:
|
196 |
+
raise ValueError("Folder does not contain any images.")
|
197 |
+
else:
|
198 |
+
raise ValueError
|
199 |
+
|
200 |
+
for input_path in all_img_paths:
|
201 |
+
with Image.open(input_path) as image:
|
202 |
+
# if image.mode == "RGBA":
|
203 |
+
# image = image.convert("RGB")
|
204 |
+
w, h = image.size
|
205 |
+
|
206 |
+
if border_ratio > 0:
|
207 |
+
if image.mode != "RGBA" or ignore_alpha:
|
208 |
+
image = image.convert("RGB")
|
209 |
+
image = np.asarray(image)
|
210 |
+
carved_image = rembg.remove(image) # [H, W, 4]
|
211 |
+
else:
|
212 |
+
image = np.asarray(image)
|
213 |
+
carved_image = image
|
214 |
+
mask = carved_image[..., -1] > 0
|
215 |
+
image = recenter(carved_image, mask, border_ratio=border_ratio)
|
216 |
+
image = image.astype(np.float32) / 255.0
|
217 |
+
if image.shape[-1] == 4:
|
218 |
+
image = image[..., :3] * image[..., 3:4] + (1 - image[..., 3:4])
|
219 |
+
image = Image.fromarray((image * 255).astype(np.uint8))
|
220 |
+
else:
|
221 |
+
print("Ignore border ratio")
|
222 |
+
image = image.resize((512, 512))
|
223 |
+
|
224 |
+
image = ToTensor()(image)
|
225 |
+
image = image * 2.0 - 1.0
|
226 |
+
|
227 |
+
image = image.unsqueeze(0).to(device)
|
228 |
+
H, W = image.shape[2:]
|
229 |
+
assert image.shape[1] == 3
|
230 |
+
F = 8
|
231 |
+
C = 4
|
232 |
+
shape = (num_frames, C, H // F, W // F)
|
233 |
+
|
234 |
+
value_dict = {}
|
235 |
+
value_dict["motion_bucket_id"] = motion_bucket_id
|
236 |
+
value_dict["fps_id"] = fps_id
|
237 |
+
value_dict["cond_aug"] = cond_aug
|
238 |
+
value_dict["cond_frames_without_noise"] = clip_model(image)
|
239 |
+
value_dict["cond_frames"] = ae_model.encode(image)
|
240 |
+
value_dict["cond_frames"] += cond_aug * torch.randn_like(
|
241 |
+
value_dict["cond_frames"]
|
242 |
+
)
|
243 |
+
value_dict["cond_aug"] = cond_aug
|
244 |
+
|
245 |
+
with torch.no_grad():
|
246 |
+
with torch.autocast(device):
|
247 |
+
batch, batch_uc = get_batch(
|
248 |
+
get_unique_embedder_keys_from_conditioner(model.conditioner),
|
249 |
+
value_dict,
|
250 |
+
[1, num_frames],
|
251 |
+
T=num_frames,
|
252 |
+
device=device,
|
253 |
+
)
|
254 |
+
c, uc = model.conditioner.get_unconditional_conditioning(
|
255 |
+
batch,
|
256 |
+
batch_uc=batch_uc,
|
257 |
+
force_uc_zero_embeddings=[
|
258 |
+
"cond_frames",
|
259 |
+
"cond_frames_without_noise",
|
260 |
+
],
|
261 |
+
)
|
262 |
+
|
263 |
+
for k in ["crossattn", "concat"]:
|
264 |
+
uc[k] = repeat(uc[k], "b ... -> b t ...", t=num_frames)
|
265 |
+
uc[k] = rearrange(uc[k], "b t ... -> (b t) ...", t=num_frames)
|
266 |
+
c[k] = repeat(c[k], "b ... -> b t ...", t=num_frames)
|
267 |
+
c[k] = rearrange(c[k], "b t ... -> (b t) ...", t=num_frames)
|
268 |
+
|
269 |
+
randn = torch.randn(shape, device=device) if noise is None else noise
|
270 |
+
randn = randn.to(device)
|
271 |
+
|
272 |
+
additional_model_inputs = {}
|
273 |
+
additional_model_inputs["image_only_indicator"] = torch.zeros(
|
274 |
+
2, num_frames
|
275 |
+
).to(device)
|
276 |
+
additional_model_inputs["num_video_frames"] = batch["num_video_frames"]
|
277 |
+
|
278 |
+
def denoiser(input, sigma, c):
|
279 |
+
return model.denoiser(
|
280 |
+
model.model, input, sigma, c, **additional_model_inputs
|
281 |
+
)
|
282 |
+
|
283 |
+
samples_z = model.sampler(denoiser, randn, cond=c, uc=uc)
|
284 |
+
model.en_and_decode_n_samples_a_time = decoding_t
|
285 |
+
samples_x = model.decode_first_stage(samples_z)
|
286 |
+
samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
|
287 |
+
|
288 |
+
os.makedirs(output_folder, exist_ok=True)
|
289 |
+
base_count = len(glob(os.path.join(output_folder, "*.mp4")))
|
290 |
+
video_path = os.path.join(output_folder, f"{base_count:06d}.mp4")
|
291 |
+
# writer = cv2.VideoWriter(
|
292 |
+
# video_path,
|
293 |
+
# cv2.VideoWriter_fourcc(*"MP4V"),
|
294 |
+
# fps_id + 1,
|
295 |
+
# (samples.shape[-1], samples.shape[-2]),
|
296 |
+
# )
|
297 |
+
|
298 |
+
frames = (
|
299 |
+
(rearrange(samples, "t c h w -> t h w c") * 255)
|
300 |
+
.cpu()
|
301 |
+
.numpy()
|
302 |
+
.astype(np.uint8)
|
303 |
+
)
|
304 |
+
|
305 |
+
if save:
|
306 |
+
write_video(video_path, frames, fps=3)
|
307 |
+
|
308 |
+
images = []
|
309 |
+
for frame in frames:
|
310 |
+
images.append(Image.fromarray(frame))
|
311 |
+
|
312 |
+
if need_return:
|
313 |
+
return images, model
|
314 |
+
|
315 |
+
|
316 |
+
if __name__ == "__main__":
|
317 |
+
tyro.cli(sample_one)
|
scripts/pub/configs/V3D_512.yaml
ADDED
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
model:
|
2 |
+
base_learning_rate: 1.0e-04
|
3 |
+
target: sgm.models.video_diffusion.DiffusionEngine
|
4 |
+
params:
|
5 |
+
ckpt_path: ckpts/V3D_512.ckpt
|
6 |
+
scale_factor: 0.18215
|
7 |
+
disable_first_stage_autocast: true
|
8 |
+
input_key: latents
|
9 |
+
log_keys: []
|
10 |
+
scheduler_config:
|
11 |
+
target: sgm.lr_scheduler.LambdaLinearScheduler
|
12 |
+
params:
|
13 |
+
warm_up_steps:
|
14 |
+
- 1
|
15 |
+
cycle_lengths:
|
16 |
+
- 10000000000000
|
17 |
+
f_start:
|
18 |
+
- 1.0e-06
|
19 |
+
f_max:
|
20 |
+
- 1.0
|
21 |
+
f_min:
|
22 |
+
- 1.0
|
23 |
+
denoiser_config:
|
24 |
+
target: sgm.modules.diffusionmodules.denoiser.Denoiser
|
25 |
+
params:
|
26 |
+
scaling_config:
|
27 |
+
target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise
|
28 |
+
network_config:
|
29 |
+
target: sgm.modules.diffusionmodules.video_model.VideoUNet
|
30 |
+
params:
|
31 |
+
adm_in_channels: 768
|
32 |
+
num_classes: sequential
|
33 |
+
use_checkpoint: true
|
34 |
+
in_channels: 8
|
35 |
+
out_channels: 4
|
36 |
+
model_channels: 320
|
37 |
+
attention_resolutions:
|
38 |
+
- 4
|
39 |
+
- 2
|
40 |
+
- 1
|
41 |
+
num_res_blocks: 2
|
42 |
+
channel_mult:
|
43 |
+
- 1
|
44 |
+
- 2
|
45 |
+
- 4
|
46 |
+
- 4
|
47 |
+
num_head_channels: 64
|
48 |
+
use_linear_in_transformer: true
|
49 |
+
transformer_depth: 1
|
50 |
+
context_dim: 1024
|
51 |
+
spatial_transformer_attn_type: softmax-xformers
|
52 |
+
extra_ff_mix_layer: true
|
53 |
+
use_spatial_context: true
|
54 |
+
merge_strategy: learned_with_images
|
55 |
+
video_kernel_size:
|
56 |
+
- 3
|
57 |
+
- 1
|
58 |
+
- 1
|
59 |
+
conditioner_config:
|
60 |
+
target: sgm.modules.GeneralConditioner
|
61 |
+
params:
|
62 |
+
emb_models:
|
63 |
+
- is_trainable: false
|
64 |
+
ucg_rate: 0.2
|
65 |
+
input_key: cond_frames_without_noise
|
66 |
+
target: sgm.modules.encoders.modules.IdentityEncoder
|
67 |
+
- input_key: fps_id
|
68 |
+
is_trainable: true
|
69 |
+
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
|
70 |
+
params:
|
71 |
+
outdim: 256
|
72 |
+
- input_key: motion_bucket_id
|
73 |
+
is_trainable: true
|
74 |
+
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
|
75 |
+
params:
|
76 |
+
outdim: 256
|
77 |
+
- input_key: cond_frames
|
78 |
+
is_trainable: false
|
79 |
+
ucg_rate: 0.2
|
80 |
+
target: sgm.modules.encoders.modules.IdentityEncoder
|
81 |
+
- input_key: cond_aug
|
82 |
+
is_trainable: true
|
83 |
+
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
|
84 |
+
params:
|
85 |
+
outdim: 256
|
86 |
+
first_stage_config:
|
87 |
+
target: sgm.models.autoencoder.AutoencodingEngine
|
88 |
+
params:
|
89 |
+
loss_config:
|
90 |
+
target: torch.nn.Identity
|
91 |
+
regularizer_config:
|
92 |
+
target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer
|
93 |
+
encoder_config:
|
94 |
+
target: sgm.modules.diffusionmodules.model.Encoder
|
95 |
+
params:
|
96 |
+
attn_type: vanilla
|
97 |
+
double_z: true
|
98 |
+
z_channels: 4
|
99 |
+
resolution: 256
|
100 |
+
in_channels: 3
|
101 |
+
out_ch: 3
|
102 |
+
ch: 128
|
103 |
+
ch_mult:
|
104 |
+
- 1
|
105 |
+
- 2
|
106 |
+
- 4
|
107 |
+
- 4
|
108 |
+
num_res_blocks: 2
|
109 |
+
attn_resolutions: []
|
110 |
+
dropout: 0.0
|
111 |
+
decoder_config:
|
112 |
+
target: sgm.modules.autoencoding.temporal_ae.VideoDecoder
|
113 |
+
params:
|
114 |
+
attn_type: vanilla
|
115 |
+
double_z: true
|
116 |
+
z_channels: 4
|
117 |
+
resolution: 256
|
118 |
+
in_channels: 3
|
119 |
+
out_ch: 3
|
120 |
+
ch: 128
|
121 |
+
ch_mult:
|
122 |
+
- 1
|
123 |
+
- 2
|
124 |
+
- 4
|
125 |
+
- 4
|
126 |
+
num_res_blocks: 2
|
127 |
+
attn_resolutions: []
|
128 |
+
dropout: 0.0
|
129 |
+
video_kernel_size:
|
130 |
+
- 3
|
131 |
+
- 1
|
132 |
+
- 1
|
133 |
+
sampler_config:
|
134 |
+
target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
|
135 |
+
params:
|
136 |
+
num_steps: 30
|
137 |
+
discretization_config:
|
138 |
+
target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization
|
139 |
+
params:
|
140 |
+
sigma_max: 700.0
|
141 |
+
guider_config:
|
142 |
+
target: sgm.modules.diffusionmodules.guiders.LinearPredictionGuider
|
143 |
+
params:
|
144 |
+
max_scale: 3.5
|
145 |
+
min_scale: 3.5
|
146 |
+
num_frames: 18
|
147 |
+
loss_fn_config:
|
148 |
+
target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
|
149 |
+
params:
|
150 |
+
batch2model_keys:
|
151 |
+
- num_video_frames
|
152 |
+
- image_only_indicator
|
153 |
+
loss_weighting_config:
|
154 |
+
target: sgm.modules.diffusionmodules.loss_weighting.EDMWeighting
|
155 |
+
params:
|
156 |
+
sigma_data: 1.0
|
157 |
+
sigma_sampler_config:
|
158 |
+
target: sgm.modules.diffusionmodules.sigma_sampling.EDMSampling
|
159 |
+
params:
|
160 |
+
p_mean: 1.5
|
161 |
+
p_std: 2.0
|
scripts/tests/attention.py
ADDED
@@ -0,0 +1,319 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import einops
|
2 |
+
import torch
|
3 |
+
import torch.nn.functional as F
|
4 |
+
import torch.utils.benchmark as benchmark
|
5 |
+
from torch.backends.cuda import SDPBackend
|
6 |
+
|
7 |
+
from sgm.modules.attention import BasicTransformerBlock, SpatialTransformer
|
8 |
+
|
9 |
+
|
10 |
+
def benchmark_attn():
|
11 |
+
# Lets define a helpful benchmarking function:
|
12 |
+
# https://pytorch.org/tutorials/intermediate/scaled_dot_product_attention_tutorial.html
|
13 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
14 |
+
|
15 |
+
def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
|
16 |
+
t0 = benchmark.Timer(
|
17 |
+
stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
|
18 |
+
)
|
19 |
+
return t0.blocked_autorange().mean * 1e6
|
20 |
+
|
21 |
+
# Lets define the hyper-parameters of our input
|
22 |
+
batch_size = 32
|
23 |
+
max_sequence_len = 1024
|
24 |
+
num_heads = 32
|
25 |
+
embed_dimension = 32
|
26 |
+
|
27 |
+
dtype = torch.float16
|
28 |
+
|
29 |
+
query = torch.rand(
|
30 |
+
batch_size,
|
31 |
+
num_heads,
|
32 |
+
max_sequence_len,
|
33 |
+
embed_dimension,
|
34 |
+
device=device,
|
35 |
+
dtype=dtype,
|
36 |
+
)
|
37 |
+
key = torch.rand(
|
38 |
+
batch_size,
|
39 |
+
num_heads,
|
40 |
+
max_sequence_len,
|
41 |
+
embed_dimension,
|
42 |
+
device=device,
|
43 |
+
dtype=dtype,
|
44 |
+
)
|
45 |
+
value = torch.rand(
|
46 |
+
batch_size,
|
47 |
+
num_heads,
|
48 |
+
max_sequence_len,
|
49 |
+
embed_dimension,
|
50 |
+
device=device,
|
51 |
+
dtype=dtype,
|
52 |
+
)
|
53 |
+
|
54 |
+
print(f"q/k/v shape:", query.shape, key.shape, value.shape)
|
55 |
+
|
56 |
+
# Lets explore the speed of each of the 3 implementations
|
57 |
+
from torch.backends.cuda import SDPBackend, sdp_kernel
|
58 |
+
|
59 |
+
# Helpful arguments mapper
|
60 |
+
backend_map = {
|
61 |
+
SDPBackend.MATH: {
|
62 |
+
"enable_math": True,
|
63 |
+
"enable_flash": False,
|
64 |
+
"enable_mem_efficient": False,
|
65 |
+
},
|
66 |
+
SDPBackend.FLASH_ATTENTION: {
|
67 |
+
"enable_math": False,
|
68 |
+
"enable_flash": True,
|
69 |
+
"enable_mem_efficient": False,
|
70 |
+
},
|
71 |
+
SDPBackend.EFFICIENT_ATTENTION: {
|
72 |
+
"enable_math": False,
|
73 |
+
"enable_flash": False,
|
74 |
+
"enable_mem_efficient": True,
|
75 |
+
},
|
76 |
+
}
|
77 |
+
|
78 |
+
from torch.profiler import ProfilerActivity, profile, record_function
|
79 |
+
|
80 |
+
activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA]
|
81 |
+
|
82 |
+
print(
|
83 |
+
f"The default implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds"
|
84 |
+
)
|
85 |
+
with profile(
|
86 |
+
activities=activities, record_shapes=False, profile_memory=True
|
87 |
+
) as prof:
|
88 |
+
with record_function("Default detailed stats"):
|
89 |
+
for _ in range(25):
|
90 |
+
o = F.scaled_dot_product_attention(query, key, value)
|
91 |
+
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
|
92 |
+
|
93 |
+
print(
|
94 |
+
f"The math implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds"
|
95 |
+
)
|
96 |
+
with sdp_kernel(**backend_map[SDPBackend.MATH]):
|
97 |
+
with profile(
|
98 |
+
activities=activities, record_shapes=False, profile_memory=True
|
99 |
+
) as prof:
|
100 |
+
with record_function("Math implmentation stats"):
|
101 |
+
for _ in range(25):
|
102 |
+
o = F.scaled_dot_product_attention(query, key, value)
|
103 |
+
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
|
104 |
+
|
105 |
+
with sdp_kernel(**backend_map[SDPBackend.FLASH_ATTENTION]):
|
106 |
+
try:
|
107 |
+
print(
|
108 |
+
f"The flash attention implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds"
|
109 |
+
)
|
110 |
+
except RuntimeError:
|
111 |
+
print("FlashAttention is not supported. See warnings for reasons.")
|
112 |
+
with profile(
|
113 |
+
activities=activities, record_shapes=False, profile_memory=True
|
114 |
+
) as prof:
|
115 |
+
with record_function("FlashAttention stats"):
|
116 |
+
for _ in range(25):
|
117 |
+
o = F.scaled_dot_product_attention(query, key, value)
|
118 |
+
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
|
119 |
+
|
120 |
+
with sdp_kernel(**backend_map[SDPBackend.EFFICIENT_ATTENTION]):
|
121 |
+
try:
|
122 |
+
print(
|
123 |
+
f"The memory efficient implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds"
|
124 |
+
)
|
125 |
+
except RuntimeError:
|
126 |
+
print("EfficientAttention is not supported. See warnings for reasons.")
|
127 |
+
with profile(
|
128 |
+
activities=activities, record_shapes=False, profile_memory=True
|
129 |
+
) as prof:
|
130 |
+
with record_function("EfficientAttention stats"):
|
131 |
+
for _ in range(25):
|
132 |
+
o = F.scaled_dot_product_attention(query, key, value)
|
133 |
+
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
|
134 |
+
|
135 |
+
|
136 |
+
def run_model(model, x, context):
|
137 |
+
return model(x, context)
|
138 |
+
|
139 |
+
|
140 |
+
def benchmark_transformer_blocks():
|
141 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
142 |
+
import torch.utils.benchmark as benchmark
|
143 |
+
|
144 |
+
def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
|
145 |
+
t0 = benchmark.Timer(
|
146 |
+
stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
|
147 |
+
)
|
148 |
+
return t0.blocked_autorange().mean * 1e6
|
149 |
+
|
150 |
+
checkpoint = True
|
151 |
+
compile = False
|
152 |
+
|
153 |
+
batch_size = 32
|
154 |
+
h, w = 64, 64
|
155 |
+
context_len = 77
|
156 |
+
embed_dimension = 1024
|
157 |
+
context_dim = 1024
|
158 |
+
d_head = 64
|
159 |
+
|
160 |
+
transformer_depth = 4
|
161 |
+
|
162 |
+
n_heads = embed_dimension // d_head
|
163 |
+
|
164 |
+
dtype = torch.float16
|
165 |
+
|
166 |
+
model_native = SpatialTransformer(
|
167 |
+
embed_dimension,
|
168 |
+
n_heads,
|
169 |
+
d_head,
|
170 |
+
context_dim=context_dim,
|
171 |
+
use_linear=True,
|
172 |
+
use_checkpoint=checkpoint,
|
173 |
+
attn_type="softmax",
|
174 |
+
depth=transformer_depth,
|
175 |
+
sdp_backend=SDPBackend.FLASH_ATTENTION,
|
176 |
+
).to(device)
|
177 |
+
model_efficient_attn = SpatialTransformer(
|
178 |
+
embed_dimension,
|
179 |
+
n_heads,
|
180 |
+
d_head,
|
181 |
+
context_dim=context_dim,
|
182 |
+
use_linear=True,
|
183 |
+
depth=transformer_depth,
|
184 |
+
use_checkpoint=checkpoint,
|
185 |
+
attn_type="softmax-xformers",
|
186 |
+
).to(device)
|
187 |
+
if not checkpoint and compile:
|
188 |
+
print("compiling models")
|
189 |
+
model_native = torch.compile(model_native)
|
190 |
+
model_efficient_attn = torch.compile(model_efficient_attn)
|
191 |
+
|
192 |
+
x = torch.rand(batch_size, embed_dimension, h, w, device=device, dtype=dtype)
|
193 |
+
c = torch.rand(batch_size, context_len, context_dim, device=device, dtype=dtype)
|
194 |
+
|
195 |
+
from torch.profiler import ProfilerActivity, profile, record_function
|
196 |
+
|
197 |
+
activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA]
|
198 |
+
|
199 |
+
with torch.autocast("cuda"):
|
200 |
+
print(
|
201 |
+
f"The native model runs in {benchmark_torch_function_in_microseconds(model_native.forward, x, c):.3f} microseconds"
|
202 |
+
)
|
203 |
+
print(
|
204 |
+
f"The efficientattn model runs in {benchmark_torch_function_in_microseconds(model_efficient_attn.forward, x, c):.3f} microseconds"
|
205 |
+
)
|
206 |
+
|
207 |
+
print(75 * "+")
|
208 |
+
print("NATIVE")
|
209 |
+
print(75 * "+")
|
210 |
+
torch.cuda.reset_peak_memory_stats()
|
211 |
+
with profile(
|
212 |
+
activities=activities, record_shapes=False, profile_memory=True
|
213 |
+
) as prof:
|
214 |
+
with record_function("NativeAttention stats"):
|
215 |
+
for _ in range(25):
|
216 |
+
model_native(x, c)
|
217 |
+
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
|
218 |
+
print(torch.cuda.max_memory_allocated() * 1e-9, "GB used by native block")
|
219 |
+
|
220 |
+
print(75 * "+")
|
221 |
+
print("Xformers")
|
222 |
+
print(75 * "+")
|
223 |
+
torch.cuda.reset_peak_memory_stats()
|
224 |
+
with profile(
|
225 |
+
activities=activities, record_shapes=False, profile_memory=True
|
226 |
+
) as prof:
|
227 |
+
with record_function("xformers stats"):
|
228 |
+
for _ in range(25):
|
229 |
+
model_efficient_attn(x, c)
|
230 |
+
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
|
231 |
+
print(torch.cuda.max_memory_allocated() * 1e-9, "GB used by xformers block")
|
232 |
+
|
233 |
+
|
234 |
+
def test01():
|
235 |
+
# conv1x1 vs linear
|
236 |
+
from sgm.util import count_params
|
237 |
+
|
238 |
+
conv = torch.nn.Conv2d(3, 32, kernel_size=1).cuda()
|
239 |
+
print(count_params(conv))
|
240 |
+
linear = torch.nn.Linear(3, 32).cuda()
|
241 |
+
print(count_params(linear))
|
242 |
+
|
243 |
+
print(conv.weight.shape)
|
244 |
+
|
245 |
+
# use same initialization
|
246 |
+
linear.weight = torch.nn.Parameter(conv.weight.squeeze(-1).squeeze(-1))
|
247 |
+
linear.bias = torch.nn.Parameter(conv.bias)
|
248 |
+
|
249 |
+
print(linear.weight.shape)
|
250 |
+
|
251 |
+
x = torch.randn(11, 3, 64, 64).cuda()
|
252 |
+
|
253 |
+
xr = einops.rearrange(x, "b c h w -> b (h w) c").contiguous()
|
254 |
+
print(xr.shape)
|
255 |
+
out_linear = linear(xr)
|
256 |
+
print(out_linear.mean(), out_linear.shape)
|
257 |
+
|
258 |
+
out_conv = conv(x)
|
259 |
+
print(out_conv.mean(), out_conv.shape)
|
260 |
+
print("done with test01.\n")
|
261 |
+
|
262 |
+
|
263 |
+
def test02():
|
264 |
+
# try cosine flash attention
|
265 |
+
import time
|
266 |
+
|
267 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
268 |
+
torch.backends.cudnn.allow_tf32 = True
|
269 |
+
torch.backends.cudnn.benchmark = True
|
270 |
+
print("testing cosine flash attention...")
|
271 |
+
DIM = 1024
|
272 |
+
SEQLEN = 4096
|
273 |
+
BS = 16
|
274 |
+
|
275 |
+
print(" softmax (vanilla) first...")
|
276 |
+
model = BasicTransformerBlock(
|
277 |
+
dim=DIM,
|
278 |
+
n_heads=16,
|
279 |
+
d_head=64,
|
280 |
+
dropout=0.0,
|
281 |
+
context_dim=None,
|
282 |
+
attn_mode="softmax",
|
283 |
+
).cuda()
|
284 |
+
try:
|
285 |
+
x = torch.randn(BS, SEQLEN, DIM).cuda()
|
286 |
+
tic = time.time()
|
287 |
+
y = model(x)
|
288 |
+
toc = time.time()
|
289 |
+
print(y.shape, toc - tic)
|
290 |
+
except RuntimeError as e:
|
291 |
+
# likely oom
|
292 |
+
print(str(e))
|
293 |
+
|
294 |
+
print("\n now flash-cosine...")
|
295 |
+
model = BasicTransformerBlock(
|
296 |
+
dim=DIM,
|
297 |
+
n_heads=16,
|
298 |
+
d_head=64,
|
299 |
+
dropout=0.0,
|
300 |
+
context_dim=None,
|
301 |
+
attn_mode="flash-cosine",
|
302 |
+
).cuda()
|
303 |
+
x = torch.randn(BS, SEQLEN, DIM).cuda()
|
304 |
+
tic = time.time()
|
305 |
+
y = model(x)
|
306 |
+
toc = time.time()
|
307 |
+
print(y.shape, toc - tic)
|
308 |
+
print("done with test02.\n")
|
309 |
+
|
310 |
+
|
311 |
+
if __name__ == "__main__":
|
312 |
+
# test01()
|
313 |
+
# test02()
|
314 |
+
# test03()
|
315 |
+
|
316 |
+
# benchmark_attn()
|
317 |
+
benchmark_transformer_blocks()
|
318 |
+
|
319 |
+
print("done.")
|
scripts/util/__init__.py
ADDED
File without changes
|
scripts/util/detection/__init__.py
ADDED
File without changes
|
scripts/util/detection/nsfw_and_watermark_dectection.py
ADDED
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
import clip
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
import torchvision.transforms as T
|
7 |
+
from PIL import Image
|
8 |
+
|
9 |
+
RESOURCES_ROOT = "scripts/util/detection/"
|
10 |
+
|
11 |
+
|
12 |
+
def predict_proba(X, weights, biases):
|
13 |
+
logits = X @ weights.T + biases
|
14 |
+
proba = np.where(
|
15 |
+
logits >= 0, 1 / (1 + np.exp(-logits)), np.exp(logits) / (1 + np.exp(logits))
|
16 |
+
)
|
17 |
+
return proba.T
|
18 |
+
|
19 |
+
|
20 |
+
def load_model_weights(path: str):
|
21 |
+
model_weights = np.load(path)
|
22 |
+
return model_weights["weights"], model_weights["biases"]
|
23 |
+
|
24 |
+
|
25 |
+
def clip_process_images(images: torch.Tensor) -> torch.Tensor:
|
26 |
+
min_size = min(images.shape[-2:])
|
27 |
+
return T.Compose(
|
28 |
+
[
|
29 |
+
T.CenterCrop(min_size), # TODO: this might affect the watermark, check this
|
30 |
+
T.Resize(224, interpolation=T.InterpolationMode.BICUBIC, antialias=True),
|
31 |
+
T.Normalize(
|
32 |
+
(0.48145466, 0.4578275, 0.40821073),
|
33 |
+
(0.26862954, 0.26130258, 0.27577711),
|
34 |
+
),
|
35 |
+
]
|
36 |
+
)(images)
|
37 |
+
|
38 |
+
|
39 |
+
class DeepFloydDataFiltering(object):
|
40 |
+
def __init__(
|
41 |
+
self, verbose: bool = False, device: torch.device = torch.device("cpu")
|
42 |
+
):
|
43 |
+
super().__init__()
|
44 |
+
self.verbose = verbose
|
45 |
+
self._device = None
|
46 |
+
self.clip_model, _ = clip.load("ViT-L/14", device=device)
|
47 |
+
self.clip_model.eval()
|
48 |
+
|
49 |
+
self.cpu_w_weights, self.cpu_w_biases = load_model_weights(
|
50 |
+
os.path.join(RESOURCES_ROOT, "w_head_v1.npz")
|
51 |
+
)
|
52 |
+
self.cpu_p_weights, self.cpu_p_biases = load_model_weights(
|
53 |
+
os.path.join(RESOURCES_ROOT, "p_head_v1.npz")
|
54 |
+
)
|
55 |
+
self.w_threshold, self.p_threshold = 0.5, 0.5
|
56 |
+
|
57 |
+
@torch.inference_mode()
|
58 |
+
def __call__(self, images: torch.Tensor) -> torch.Tensor:
|
59 |
+
imgs = clip_process_images(images)
|
60 |
+
if self._device is None:
|
61 |
+
self._device = next(p for p in self.clip_model.parameters()).device
|
62 |
+
image_features = self.clip_model.encode_image(imgs.to(self._device))
|
63 |
+
image_features = image_features.detach().cpu().numpy().astype(np.float16)
|
64 |
+
p_pred = predict_proba(image_features, self.cpu_p_weights, self.cpu_p_biases)
|
65 |
+
w_pred = predict_proba(image_features, self.cpu_w_weights, self.cpu_w_biases)
|
66 |
+
print(f"p_pred = {p_pred}, w_pred = {w_pred}") if self.verbose else None
|
67 |
+
query = p_pred > self.p_threshold
|
68 |
+
if query.sum() > 0:
|
69 |
+
print(f"Hit for p_threshold: {p_pred}") if self.verbose else None
|
70 |
+
images[query] = T.GaussianBlur(99, sigma=(100.0, 100.0))(images[query])
|
71 |
+
query = w_pred > self.w_threshold
|
72 |
+
if query.sum() > 0:
|
73 |
+
print(f"Hit for w_threshold: {w_pred}") if self.verbose else None
|
74 |
+
images[query] = T.GaussianBlur(99, sigma=(100.0, 100.0))(images[query])
|
75 |
+
return images
|
76 |
+
|
77 |
+
|
78 |
+
def load_img(path: str) -> torch.Tensor:
|
79 |
+
image = Image.open(path)
|
80 |
+
if not image.mode == "RGB":
|
81 |
+
image = image.convert("RGB")
|
82 |
+
image_transforms = T.Compose(
|
83 |
+
[
|
84 |
+
T.ToTensor(),
|
85 |
+
]
|
86 |
+
)
|
87 |
+
return image_transforms(image)[None, ...]
|
88 |
+
|
89 |
+
|
90 |
+
def test(root):
|
91 |
+
from einops import rearrange
|
92 |
+
|
93 |
+
filter = DeepFloydDataFiltering(verbose=True)
|
94 |
+
for p in os.listdir((root)):
|
95 |
+
print(f"running on {p}...")
|
96 |
+
img = load_img(os.path.join(root, p))
|
97 |
+
filtered_img = filter(img)
|
98 |
+
filtered_img = rearrange(
|
99 |
+
255.0 * (filtered_img.numpy())[0], "c h w -> h w c"
|
100 |
+
).astype(np.uint8)
|
101 |
+
Image.fromarray(filtered_img).save(
|
102 |
+
os.path.join(root, f"{os.path.splitext(p)[0]}-filtered.jpg")
|
103 |
+
)
|
104 |
+
|
105 |
+
|
106 |
+
if __name__ == "__main__":
|
107 |
+
import fire
|
108 |
+
|
109 |
+
fire.Fire(test)
|
110 |
+
print("done.")
|
scripts/util/detection/p_head_v1.npz
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:b4653a64d5f85d8d4c5f6c5ec175f1c5c5e37db8f38d39b2ed8b5979da7fdc76
|
3 |
+
size 3588
|
scripts/util/detection/w_head_v1.npz
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:b6af23687aa347073e692025f405ccc48c14aadc5dbe775b3312041006d496d1
|
3 |
+
size 3588
|
sgm/__init__.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .models import AutoencodingEngine, DiffusionEngine
|
2 |
+
from .util import get_configs_path, instantiate_from_config
|
3 |
+
|
4 |
+
__version__ = "0.1.0"
|
sgm/data/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .dataset import StableDataModuleFromConfig
|
sgm/data/cam_utils.py
ADDED
@@ -0,0 +1,1253 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
'''
|
2 |
+
Common camera utilities
|
3 |
+
'''
|
4 |
+
|
5 |
+
import math
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
from pytorch3d.renderer import PerspectiveCameras
|
10 |
+
from pytorch3d.renderer.cameras import look_at_view_transform
|
11 |
+
from pytorch3d.renderer.implicit.raysampling import _xy_to_ray_bundle
|
12 |
+
|
13 |
+
class RelativeCameraLoader(nn.Module):
|
14 |
+
def __init__(self,
|
15 |
+
query_batch_size=1,
|
16 |
+
rand_query=True,
|
17 |
+
relative=True,
|
18 |
+
center_at_origin=False,
|
19 |
+
):
|
20 |
+
super().__init__()
|
21 |
+
|
22 |
+
self.query_batch_size = query_batch_size
|
23 |
+
self.rand_query = rand_query
|
24 |
+
self.relative = relative
|
25 |
+
self.center_at_origin = center_at_origin
|
26 |
+
|
27 |
+
def plot_cameras(self, cameras_1, cameras_2):
|
28 |
+
'''
|
29 |
+
Helper function to plot cameras
|
30 |
+
|
31 |
+
Args:
|
32 |
+
cameras_1 (PyTorch3D camera): cameras object to plot
|
33 |
+
cameras_2 (PyTorch3D camera): cameras object to plot
|
34 |
+
'''
|
35 |
+
from pytorch3d.vis.plotly_vis import AxisArgs, plot_batch_individually, plot_scene
|
36 |
+
import plotly.graph_objects as go
|
37 |
+
plotlyplot = plot_scene(
|
38 |
+
{
|
39 |
+
'scene_batch': {
|
40 |
+
'cameras': cameras_1.to('cpu'),
|
41 |
+
'rel_cameras': cameras_2.to('cpu'),
|
42 |
+
}
|
43 |
+
},
|
44 |
+
camera_scale=.5,#0.05,
|
45 |
+
pointcloud_max_points=10000,
|
46 |
+
pointcloud_marker_size=1.0,
|
47 |
+
raybundle_max_rays=100
|
48 |
+
)
|
49 |
+
plotlyplot.show()
|
50 |
+
|
51 |
+
def concat_cameras(self, camera_list):
|
52 |
+
'''
|
53 |
+
Returns a concatenation of a list of cameras
|
54 |
+
|
55 |
+
Args:
|
56 |
+
camera_list (List[PyTorch3D camera]): a list of PyTorch3D cameras
|
57 |
+
'''
|
58 |
+
R_list, T_list, f_list, c_list, size_list = [], [], [], [], []
|
59 |
+
for cameras in camera_list:
|
60 |
+
R_list.append(cameras.R)
|
61 |
+
T_list.append(cameras.T)
|
62 |
+
f_list.append(cameras.focal_length)
|
63 |
+
c_list.append(cameras.principal_point)
|
64 |
+
size_list.append(cameras.image_size)
|
65 |
+
|
66 |
+
camera_slice = PerspectiveCameras(
|
67 |
+
R = torch.cat(R_list),
|
68 |
+
T = torch.cat(T_list),
|
69 |
+
focal_length = torch.cat(f_list),
|
70 |
+
principal_point = torch.cat(c_list),
|
71 |
+
image_size = torch.cat(size_list),
|
72 |
+
device = camera_list[0].device,
|
73 |
+
)
|
74 |
+
return camera_slice
|
75 |
+
|
76 |
+
def get_camera_slice(self, scene_cameras, indices):
|
77 |
+
'''
|
78 |
+
Return a subset of cameras from a super set given indices
|
79 |
+
|
80 |
+
Args:
|
81 |
+
scene_cameras (PyTorch3D Camera): cameras object
|
82 |
+
indices (tensor or List): a flat list or tensor of indices
|
83 |
+
|
84 |
+
Returns:
|
85 |
+
camera_slice (PyTorch3D Camera) - cameras subset
|
86 |
+
'''
|
87 |
+
camera_slice = PerspectiveCameras(
|
88 |
+
R = scene_cameras.R[indices],
|
89 |
+
T = scene_cameras.T[indices],
|
90 |
+
focal_length = scene_cameras.focal_length[indices],
|
91 |
+
principal_point = scene_cameras.principal_point[indices],
|
92 |
+
image_size = scene_cameras.image_size[indices],
|
93 |
+
device = scene_cameras.device,
|
94 |
+
)
|
95 |
+
return camera_slice
|
96 |
+
|
97 |
+
|
98 |
+
def get_relative_camera(self, scene_cameras:PerspectiveCameras, query_idx, center_at_origin=False):
|
99 |
+
"""
|
100 |
+
Transform context cameras relative to a base query camera
|
101 |
+
|
102 |
+
Args:
|
103 |
+
scene_cameras (PyTorch3D Camera): cameras object
|
104 |
+
query_idx (tensor or List): a length 1 list defining query idx
|
105 |
+
|
106 |
+
Returns:
|
107 |
+
cams_relative (PyTorch3D Camera): cameras object relative to query camera
|
108 |
+
"""
|
109 |
+
|
110 |
+
query_camera = self.get_camera_slice(scene_cameras, query_idx)
|
111 |
+
query_world2view = query_camera.get_world_to_view_transform()
|
112 |
+
all_world2view = scene_cameras.get_world_to_view_transform()
|
113 |
+
|
114 |
+
if center_at_origin:
|
115 |
+
identity_cam = PerspectiveCameras(device=scene_cameras.device, R=query_camera.R, T=query_camera.T)
|
116 |
+
else:
|
117 |
+
T = torch.zeros((1, 3))
|
118 |
+
identity_cam = PerspectiveCameras(device=scene_cameras.device, R=query_camera.R, T=T)
|
119 |
+
|
120 |
+
identity_world2view = identity_cam.get_world_to_view_transform()
|
121 |
+
|
122 |
+
# compose the relative transformation as g_i^{-1} g_j
|
123 |
+
relative_world2view = identity_world2view.inverse().compose(all_world2view)
|
124 |
+
|
125 |
+
# generate a camera from the relative transform
|
126 |
+
relative_matrix = relative_world2view.get_matrix()
|
127 |
+
cams_relative = PerspectiveCameras(
|
128 |
+
R = relative_matrix[:, :3, :3],
|
129 |
+
T = relative_matrix[:, 3, :3],
|
130 |
+
focal_length = scene_cameras.focal_length,
|
131 |
+
principal_point = scene_cameras.principal_point,
|
132 |
+
image_size = scene_cameras.image_size,
|
133 |
+
device = scene_cameras.device,
|
134 |
+
)
|
135 |
+
return cams_relative
|
136 |
+
|
137 |
+
def forward(self, scene_cameras, scene_rgb=None, scene_masks=None, query_idx=None, context_size=3, context_idx=None, return_context=False):
|
138 |
+
'''
|
139 |
+
Return a sampled batch of query and context cameras (used in training)
|
140 |
+
|
141 |
+
Args:
|
142 |
+
scene_cameras (PyTorch3D Camera): a batch of PyTorch3D cameras
|
143 |
+
scene_rgb (Tensor): a batch of rgb
|
144 |
+
scene_masks (Tensor): a batch of masks (optional)
|
145 |
+
query_idx (List or Tensor): desired query idx (optional)
|
146 |
+
context_size (int): number of views for context
|
147 |
+
|
148 |
+
Returns:
|
149 |
+
query_cameras, query_rgb, query_masks: random query view
|
150 |
+
context_cameras, context_rgb, context_masks: context views
|
151 |
+
'''
|
152 |
+
|
153 |
+
if query_idx is None:
|
154 |
+
query_idx = [0]
|
155 |
+
if self.rand_query:
|
156 |
+
rand = torch.randperm(len(scene_cameras))
|
157 |
+
query_idx = rand[:1]
|
158 |
+
|
159 |
+
if context_idx is None:
|
160 |
+
rand = torch.randperm(len(scene_cameras))
|
161 |
+
context_idx = rand[:context_size]
|
162 |
+
|
163 |
+
|
164 |
+
if self.relative:
|
165 |
+
rel_cameras = self.get_relative_camera(scene_cameras, query_idx, center_at_origin=self.center_at_origin)
|
166 |
+
else:
|
167 |
+
rel_cameras = scene_cameras
|
168 |
+
|
169 |
+
query_cameras = self.get_camera_slice(rel_cameras, query_idx)
|
170 |
+
query_rgb = None
|
171 |
+
if scene_rgb is not None:
|
172 |
+
query_rgb = scene_rgb[query_idx]
|
173 |
+
query_masks = None
|
174 |
+
if scene_masks is not None:
|
175 |
+
query_masks = scene_masks[query_idx]
|
176 |
+
|
177 |
+
context_cameras = self.get_camera_slice(rel_cameras, context_idx)
|
178 |
+
context_rgb = None
|
179 |
+
if scene_rgb is not None:
|
180 |
+
context_rgb = scene_rgb[context_idx]
|
181 |
+
context_masks = None
|
182 |
+
if scene_masks is not None:
|
183 |
+
context_masks = scene_masks[context_idx]
|
184 |
+
|
185 |
+
if return_context:
|
186 |
+
return query_cameras, query_rgb, query_masks, context_cameras, context_rgb, context_masks, context_idx
|
187 |
+
return query_cameras, query_rgb, query_masks, context_cameras, context_rgb, context_masks
|
188 |
+
|
189 |
+
|
190 |
+
def get_interpolated_path(cameras: PerspectiveCameras, n=50, method='circle', theta_offset_max=0.0):
|
191 |
+
'''
|
192 |
+
Given a camera object containing a set of cameras, fit a circle and get
|
193 |
+
interpolated cameras
|
194 |
+
|
195 |
+
Args:
|
196 |
+
cameras (PyTorch3D Camera): input camera object
|
197 |
+
n (int): length of cameras in new path
|
198 |
+
method (str): 'circle'
|
199 |
+
theta_offset_max (int): max camera jitter in radians
|
200 |
+
|
201 |
+
Returns:
|
202 |
+
path_cameras (PyTorch3D Camera): interpolated cameras
|
203 |
+
'''
|
204 |
+
device = cameras.device
|
205 |
+
cameras = cameras.cpu()
|
206 |
+
|
207 |
+
if method == 'circle':
|
208 |
+
|
209 |
+
#@ https://meshlogic.github.io/posts/jupyter/curve-fitting/fitting-a-circle-to-cluster-of-3d-points/
|
210 |
+
#@ Fit plane
|
211 |
+
P = cameras.get_camera_center().cpu()
|
212 |
+
P_mean = P.mean(axis=0)
|
213 |
+
P_centered = P - P_mean
|
214 |
+
U,s,V = torch.linalg.svd(P_centered)
|
215 |
+
normal = V[2,:]
|
216 |
+
if (normal*2 - P_mean).norm() < (normal - P_mean).norm():
|
217 |
+
normal = - normal
|
218 |
+
d = -torch.dot(P_mean, normal) # d = -<p,n>
|
219 |
+
|
220 |
+
#@ Project pts to plane
|
221 |
+
P_xy = rodrigues_rot(P_centered, normal, torch.tensor([0.0,0.0,1.0]))
|
222 |
+
|
223 |
+
#@ Fit circle in 2D
|
224 |
+
xc, yc, r = fit_circle_2d(P_xy[:,0], P_xy[:,1])
|
225 |
+
t = torch.linspace(0, 2*math.pi, 100)
|
226 |
+
xx = xc + r*torch.cos(t)
|
227 |
+
yy = yc + r*torch.sin(t)
|
228 |
+
|
229 |
+
#@ Project circle to 3D
|
230 |
+
C = rodrigues_rot(torch.tensor([xc,yc,0.0]), torch.tensor([0.0,0.0,1.0]), normal) + P_mean
|
231 |
+
C = C.flatten()
|
232 |
+
|
233 |
+
#@ Get pts n 3D
|
234 |
+
t = torch.linspace(0, 2*math.pi, n)
|
235 |
+
u = P[0] - C
|
236 |
+
new_camera_centers = generate_circle_by_vectors(t, C, r, normal, u)
|
237 |
+
|
238 |
+
#@ OPTIONAL THETA OFFSET
|
239 |
+
if theta_offset_max > 0.0:
|
240 |
+
aug_theta = (torch.rand((new_camera_centers.shape[0])) * (2*theta_offset_max)) - theta_offset_max
|
241 |
+
new_camera_centers = rodrigues_rot2(new_camera_centers, normal, aug_theta)
|
242 |
+
|
243 |
+
#@ Get camera look at
|
244 |
+
new_camera_look_at = get_nearest_centroid(cameras)
|
245 |
+
|
246 |
+
#@ Get R T
|
247 |
+
up_vec = -normal
|
248 |
+
R, T = look_at_view_transform(eye=new_camera_centers, at=new_camera_look_at.unsqueeze(0), up=up_vec.unsqueeze(0), device=cameras.device)
|
249 |
+
else:
|
250 |
+
raise NotImplementedError
|
251 |
+
|
252 |
+
c = (cameras.principal_point).mean(dim=0, keepdim=True).expand(R.shape[0],-1)
|
253 |
+
f = (cameras.focal_length).mean(dim=0, keepdim=True).expand(R.shape[0],-1)
|
254 |
+
image_size = cameras.image_size[:1].expand(R.shape[0],-1)
|
255 |
+
|
256 |
+
|
257 |
+
path_cameras = PerspectiveCameras(R=R,T=T,focal_length=f,principal_point=c,image_size=image_size, device=device)
|
258 |
+
cameras = cameras.to(device)
|
259 |
+
return path_cameras
|
260 |
+
|
261 |
+
def np_normalize(vec, axis=-1):
|
262 |
+
vec = vec / (np.linalg.norm(vec, axis=axis, keepdims=True) + 1e-9)
|
263 |
+
return vec
|
264 |
+
|
265 |
+
|
266 |
+
#@ https://meshlogic.github.io/posts/jupyter/curve-fitting/fitting-a-circle-to-cluster-of-3d-points/
|
267 |
+
#-------------------------------------------------------------------------------
|
268 |
+
# Generate points on circle
|
269 |
+
# P(t) = r*cos(t)*u + r*sin(t)*(n x u) + C
|
270 |
+
#-------------------------------------------------------------------------------
|
271 |
+
def generate_circle_by_vectors(t, C, r, n, u):
|
272 |
+
n = n/torch.linalg.norm(n)
|
273 |
+
u = u/torch.linalg.norm(u)
|
274 |
+
P_circle = r*torch.cos(t)[:,None]*u + r*torch.sin(t)[:,None]*torch.cross(n,u) + C
|
275 |
+
return P_circle
|
276 |
+
|
277 |
+
#@ https://meshlogic.github.io/posts/jupyter/curve-fitting/fitting-a-circle-to-cluster-of-3d-points/
|
278 |
+
#-------------------------------------------------------------------------------
|
279 |
+
# FIT CIRCLE 2D
|
280 |
+
# - Find center [xc, yc] and radius r of circle fitting to set of 2D points
|
281 |
+
# - Optionally specify weights for points
|
282 |
+
#
|
283 |
+
# - Implicit circle function:
|
284 |
+
# (x-xc)^2 + (y-yc)^2 = r^2
|
285 |
+
# (2*xc)*x + (2*yc)*y + (r^2-xc^2-yc^2) = x^2+y^2
|
286 |
+
# c[0]*x + c[1]*y + c[2] = x^2+y^2
|
287 |
+
#
|
288 |
+
# - Solution by method of least squares:
|
289 |
+
# A*c = b, c' = argmin(||A*c - b||^2)
|
290 |
+
# A = [x y 1], b = [x^2+y^2]
|
291 |
+
#-------------------------------------------------------------------------------
|
292 |
+
def fit_circle_2d(x, y, w=[]):
|
293 |
+
|
294 |
+
A = torch.stack([x, y, torch.ones(len(x))]).T
|
295 |
+
b = x**2 + y**2
|
296 |
+
|
297 |
+
# Modify A,b for weighted least squares
|
298 |
+
if len(w) == len(x):
|
299 |
+
W = torch.diag(w)
|
300 |
+
A = torch.dot(W,A)
|
301 |
+
b = torch.dot(W,b)
|
302 |
+
|
303 |
+
# Solve by method of least squares
|
304 |
+
c = torch.linalg.lstsq(A,b,rcond=None)[0]
|
305 |
+
|
306 |
+
# Get circle parameters from solution c
|
307 |
+
xc = c[0]/2
|
308 |
+
yc = c[1]/2
|
309 |
+
r = torch.sqrt(c[2] + xc**2 + yc**2)
|
310 |
+
return xc, yc, r
|
311 |
+
|
312 |
+
#@ https://meshlogic.github.io/posts/jupyter/curve-fitting/fitting-a-circle-to-cluster-of-3d-points/
|
313 |
+
#-------------------------------------------------------------------------------
|
314 |
+
# RODRIGUES ROTATION
|
315 |
+
# - Rotate given points based on a starting and ending vector
|
316 |
+
# - Axis k and angle of rotation theta given by vectors n0,n1
|
317 |
+
# P_rot = P*cos(theta) + (k x P)*sin(theta) + k*<k,P>*(1-cos(theta))
|
318 |
+
#-------------------------------------------------------------------------------
|
319 |
+
def rodrigues_rot(P, n0, n1):
|
320 |
+
|
321 |
+
# If P is only 1d array (coords of single point), fix it to be matrix
|
322 |
+
if P.ndim == 1:
|
323 |
+
P = P[None,...]
|
324 |
+
|
325 |
+
# Get vector of rotation k and angle theta
|
326 |
+
n0 = n0/torch.linalg.norm(n0)
|
327 |
+
n1 = n1/torch.linalg.norm(n1)
|
328 |
+
k = torch.cross(n0,n1)
|
329 |
+
k = k/torch.linalg.norm(k)
|
330 |
+
theta = torch.arccos(torch.dot(n0,n1))
|
331 |
+
|
332 |
+
# Compute rotated points
|
333 |
+
P_rot = torch.zeros((len(P),3))
|
334 |
+
for i in range(len(P)):
|
335 |
+
P_rot[i] = P[i]*torch.cos(theta) + torch.cross(k,P[i])*torch.sin(theta) + k*torch.dot(k,P[i])*(1-torch.cos(theta))
|
336 |
+
|
337 |
+
return P_rot
|
338 |
+
|
339 |
+
def rodrigues_rot2(P, n1, theta):
|
340 |
+
'''
|
341 |
+
Rotate points P wrt axis k by theta radians
|
342 |
+
'''
|
343 |
+
|
344 |
+
# If P is only 1d array (coords of single point), fix it to be matrix
|
345 |
+
if P.ndim == 1:
|
346 |
+
P = P[None,...]
|
347 |
+
|
348 |
+
k = torch.cross(P, n1.unsqueeze(0))
|
349 |
+
k = k/torch.linalg.norm(k)
|
350 |
+
|
351 |
+
# Compute rotated points
|
352 |
+
P_rot = torch.zeros((len(P),3))
|
353 |
+
for i in range(len(P)):
|
354 |
+
P_rot[i] = P[i]*torch.cos(theta[i]) + torch.cross(k[i],P[i])*torch.sin(theta[i]) + k[i]*torch.dot(k[i],P[i])*(1-torch.cos(theta[i]))
|
355 |
+
|
356 |
+
return P_rot
|
357 |
+
|
358 |
+
#@ https://meshlogic.github.io/posts/jupyter/curve-fitting/fitting-a-circle-to-cluster-of-3d-points/
|
359 |
+
#-------------------------------------------------------------------------------
|
360 |
+
# ANGLE BETWEEN
|
361 |
+
# - Get angle between vectors u,v with sign based on plane with unit normal n
|
362 |
+
#-------------------------------------------------------------------------------
|
363 |
+
def angle_between(u, v, n=None):
|
364 |
+
if n is None:
|
365 |
+
return torch.arctan2(torch.linalg.norm(torch.cross(u,v)), torch.dot(u,v))
|
366 |
+
else:
|
367 |
+
return torch.arctan2(torch.dot(n,torch.cross(u,v)), torch.dot(u,v))
|
368 |
+
|
369 |
+
#@ https://www.crewes.org/Documents/ResearchReports/2010/CRR201032.pdf
|
370 |
+
def get_nearest_centroid(cameras: PerspectiveCameras):
|
371 |
+
'''
|
372 |
+
Given PyTorch3D cameras, find the nearest point along their principal ray
|
373 |
+
'''
|
374 |
+
|
375 |
+
#@ GET CAMERA CENTERS AND DIRECTIONS
|
376 |
+
camera_centers = cameras.get_camera_center()
|
377 |
+
|
378 |
+
c_mean = (cameras.principal_point).mean(dim=0)
|
379 |
+
xy_grid = c_mean.unsqueeze(0).unsqueeze(0)
|
380 |
+
ray_vis = _xy_to_ray_bundle(cameras, xy_grid.expand(len(cameras),-1,-1), 1.0, 15.0, 20, True)
|
381 |
+
camera_directions = ray_vis.directions
|
382 |
+
|
383 |
+
#@ CONSTRUCT MATRICIES
|
384 |
+
A = torch.zeros((3*len(cameras)), len(cameras)+3)
|
385 |
+
b = torch.zeros((3*len(cameras), 1))
|
386 |
+
A[:,:3] = torch.eye(3).repeat(len(cameras),1)
|
387 |
+
for ci in range(len(camera_directions)):
|
388 |
+
A[3*ci:3*ci+3, ci+3] = -camera_directions[ci]
|
389 |
+
b[3*ci:3*ci+3, 0] = camera_centers[ci]
|
390 |
+
#' A (3*N, 3*N+3) b (3*N, 1)
|
391 |
+
|
392 |
+
#@ SVD
|
393 |
+
U, s, VT = torch.linalg.svd(A)
|
394 |
+
Sinv = torch.diag(1/s)
|
395 |
+
if len(s) < 3*len(cameras):
|
396 |
+
Sinv = torch.cat((Sinv, torch.zeros((Sinv.shape[0], 3*len(cameras) - Sinv.shape[1]), device=Sinv.device)), dim=1)
|
397 |
+
x = torch.matmul(VT.T, torch.matmul(Sinv,torch.matmul(U.T, b)))
|
398 |
+
|
399 |
+
centroid = x[:3,0]
|
400 |
+
return centroid
|
401 |
+
|
402 |
+
|
403 |
+
def get_angles(target_camera: PerspectiveCameras, context_cameras: PerspectiveCameras, centroid=None):
|
404 |
+
'''
|
405 |
+
Get angles between cameras wrt a centroid
|
406 |
+
|
407 |
+
Args:
|
408 |
+
target_camera (Pytorch3D Camera): a camera object with a single camera
|
409 |
+
context_cameras (PyTorch3D Camera): a camera object
|
410 |
+
|
411 |
+
Returns:
|
412 |
+
theta_deg (Tensor): a tensor containing angles in degrees
|
413 |
+
'''
|
414 |
+
a1 = target_camera.get_camera_center()
|
415 |
+
b1 = context_cameras.get_camera_center()
|
416 |
+
|
417 |
+
a = a1 - centroid.unsqueeze(0)
|
418 |
+
a = a.expand(len(context_cameras), -1)
|
419 |
+
b = b1 - centroid.unsqueeze(0)
|
420 |
+
|
421 |
+
ab_dot = (a*b).sum(dim=-1)
|
422 |
+
theta = torch.acos((ab_dot)/(torch.linalg.norm(a, dim=-1) * torch.linalg.norm(b, dim=-1)))
|
423 |
+
theta_deg = theta * 180 / math.pi
|
424 |
+
|
425 |
+
return theta_deg
|
426 |
+
|
427 |
+
|
428 |
+
import math
|
429 |
+
from typing import List, Literal, Optional, Tuple
|
430 |
+
|
431 |
+
import numpy as np
|
432 |
+
import torch
|
433 |
+
from jaxtyping import Float
|
434 |
+
from numpy.typing import NDArray
|
435 |
+
from torch import Tensor
|
436 |
+
|
437 |
+
_EPS = np.finfo(float).eps * 4.0
|
438 |
+
|
439 |
+
|
440 |
+
def unit_vector(data: NDArray, axis: Optional[int] = None) -> np.ndarray:
|
441 |
+
"""Return ndarray normalized by length, i.e. Euclidean norm, along axis.
|
442 |
+
|
443 |
+
Args:
|
444 |
+
axis: the axis along which to normalize into unit vector
|
445 |
+
out: where to write out the data to. If None, returns a new np ndarray
|
446 |
+
"""
|
447 |
+
data = np.array(data, dtype=np.float64, copy=True)
|
448 |
+
if data.ndim == 1:
|
449 |
+
data /= math.sqrt(np.dot(data, data))
|
450 |
+
return data
|
451 |
+
length = np.atleast_1d(np.sum(data * data, axis))
|
452 |
+
np.sqrt(length, length)
|
453 |
+
if axis is not None:
|
454 |
+
length = np.expand_dims(length, axis)
|
455 |
+
data /= length
|
456 |
+
return data
|
457 |
+
|
458 |
+
|
459 |
+
def quaternion_from_matrix(matrix: NDArray, isprecise: bool = False) -> np.ndarray:
|
460 |
+
"""Return quaternion from rotation matrix.
|
461 |
+
|
462 |
+
Args:
|
463 |
+
matrix: rotation matrix to obtain quaternion
|
464 |
+
isprecise: if True, input matrix is assumed to be precise rotation matrix and a faster algorithm is used.
|
465 |
+
"""
|
466 |
+
M = np.array(matrix, dtype=np.float64, copy=False)[:4, :4]
|
467 |
+
if isprecise:
|
468 |
+
q = np.empty((4,))
|
469 |
+
t = np.trace(M)
|
470 |
+
if t > M[3, 3]:
|
471 |
+
q[0] = t
|
472 |
+
q[3] = M[1, 0] - M[0, 1]
|
473 |
+
q[2] = M[0, 2] - M[2, 0]
|
474 |
+
q[1] = M[2, 1] - M[1, 2]
|
475 |
+
else:
|
476 |
+
i, j, k = 1, 2, 3
|
477 |
+
if M[1, 1] > M[0, 0]:
|
478 |
+
i, j, k = 2, 3, 1
|
479 |
+
if M[2, 2] > M[i, i]:
|
480 |
+
i, j, k = 3, 1, 2
|
481 |
+
t = M[i, i] - (M[j, j] + M[k, k]) + M[3, 3]
|
482 |
+
q[i] = t
|
483 |
+
q[j] = M[i, j] + M[j, i]
|
484 |
+
q[k] = M[k, i] + M[i, k]
|
485 |
+
q[3] = M[k, j] - M[j, k]
|
486 |
+
q *= 0.5 / math.sqrt(t * M[3, 3])
|
487 |
+
else:
|
488 |
+
m00 = M[0, 0]
|
489 |
+
m01 = M[0, 1]
|
490 |
+
m02 = M[0, 2]
|
491 |
+
m10 = M[1, 0]
|
492 |
+
m11 = M[1, 1]
|
493 |
+
m12 = M[1, 2]
|
494 |
+
m20 = M[2, 0]
|
495 |
+
m21 = M[2, 1]
|
496 |
+
m22 = M[2, 2]
|
497 |
+
# symmetric matrix K
|
498 |
+
K = [
|
499 |
+
[m00 - m11 - m22, 0.0, 0.0, 0.0],
|
500 |
+
[m01 + m10, m11 - m00 - m22, 0.0, 0.0],
|
501 |
+
[m02 + m20, m12 + m21, m22 - m00 - m11, 0.0],
|
502 |
+
[m21 - m12, m02 - m20, m10 - m01, m00 + m11 + m22],
|
503 |
+
]
|
504 |
+
K = np.array(K)
|
505 |
+
K /= 3.0
|
506 |
+
# quaternion is eigenvector of K that corresponds to largest eigenvalue
|
507 |
+
w, V = np.linalg.eigh(K)
|
508 |
+
q = V[np.array([3, 0, 1, 2]), np.argmax(w)]
|
509 |
+
if q[0] < 0.0:
|
510 |
+
np.negative(q, q)
|
511 |
+
return q
|
512 |
+
|
513 |
+
|
514 |
+
def quaternion_slerp(
|
515 |
+
quat0: NDArray, quat1: NDArray, fraction: float, spin: int = 0, shortestpath: bool = True
|
516 |
+
) -> np.ndarray:
|
517 |
+
"""Return spherical linear interpolation between two quaternions.
|
518 |
+
Args:
|
519 |
+
quat0: first quaternion
|
520 |
+
quat1: second quaternion
|
521 |
+
fraction: how much to interpolate between quat0 vs quat1 (if 0, closer to quat0; if 1, closer to quat1)
|
522 |
+
spin: how much of an additional spin to place on the interpolation
|
523 |
+
shortestpath: whether to return the short or long path to rotation
|
524 |
+
"""
|
525 |
+
q0 = unit_vector(quat0[:4])
|
526 |
+
q1 = unit_vector(quat1[:4])
|
527 |
+
if q0 is None or q1 is None:
|
528 |
+
raise ValueError("Input quaternions invalid.")
|
529 |
+
if fraction == 0.0:
|
530 |
+
return q0
|
531 |
+
if fraction == 1.0:
|
532 |
+
return q1
|
533 |
+
d = np.dot(q0, q1)
|
534 |
+
if abs(abs(d) - 1.0) < _EPS:
|
535 |
+
return q0
|
536 |
+
if shortestpath and d < 0.0:
|
537 |
+
# invert rotation
|
538 |
+
d = -d
|
539 |
+
np.negative(q1, q1)
|
540 |
+
angle = math.acos(d) + spin * math.pi
|
541 |
+
if abs(angle) < _EPS:
|
542 |
+
return q0
|
543 |
+
isin = 1.0 / math.sin(angle)
|
544 |
+
q0 *= math.sin((1.0 - fraction) * angle) * isin
|
545 |
+
q1 *= math.sin(fraction * angle) * isin
|
546 |
+
q0 += q1
|
547 |
+
return q0
|
548 |
+
|
549 |
+
|
550 |
+
def quaternion_matrix(quaternion: NDArray) -> np.ndarray:
|
551 |
+
"""Return homogeneous rotation matrix from quaternion.
|
552 |
+
|
553 |
+
Args:
|
554 |
+
quaternion: value to convert to matrix
|
555 |
+
"""
|
556 |
+
q = np.array(quaternion, dtype=np.float64, copy=True)
|
557 |
+
n = np.dot(q, q)
|
558 |
+
if n < _EPS:
|
559 |
+
return np.identity(4)
|
560 |
+
q *= math.sqrt(2.0 / n)
|
561 |
+
q = np.outer(q, q)
|
562 |
+
return np.array(
|
563 |
+
[
|
564 |
+
[1.0 - q[2, 2] - q[3, 3], q[1, 2] - q[3, 0], q[1, 3] + q[2, 0], 0.0],
|
565 |
+
[q[1, 2] + q[3, 0], 1.0 - q[1, 1] - q[3, 3], q[2, 3] - q[1, 0], 0.0],
|
566 |
+
[q[1, 3] - q[2, 0], q[2, 3] + q[1, 0], 1.0 - q[1, 1] - q[2, 2], 0.0],
|
567 |
+
[0.0, 0.0, 0.0, 1.0],
|
568 |
+
]
|
569 |
+
)
|
570 |
+
|
571 |
+
|
572 |
+
def get_interpolated_poses(pose_a: NDArray, pose_b: NDArray, steps: int = 10) -> List[float]:
|
573 |
+
"""Return interpolation of poses with specified number of steps.
|
574 |
+
Args:
|
575 |
+
pose_a: first pose
|
576 |
+
pose_b: second pose
|
577 |
+
steps: number of steps the interpolated pose path should contain
|
578 |
+
"""
|
579 |
+
|
580 |
+
quat_a = quaternion_from_matrix(pose_a[:3, :3])
|
581 |
+
quat_b = quaternion_from_matrix(pose_b[:3, :3])
|
582 |
+
|
583 |
+
ts = np.linspace(0, 1, steps)
|
584 |
+
quats = [quaternion_slerp(quat_a, quat_b, t) for t in ts]
|
585 |
+
trans = [(1 - t) * pose_a[:3, 3] + t * pose_b[:3, 3] for t in ts]
|
586 |
+
|
587 |
+
poses_ab = []
|
588 |
+
for quat, tran in zip(quats, trans):
|
589 |
+
pose = np.identity(4)
|
590 |
+
pose[:3, :3] = quaternion_matrix(quat)[:3, :3]
|
591 |
+
pose[:3, 3] = tran
|
592 |
+
poses_ab.append(pose[:3])
|
593 |
+
return poses_ab
|
594 |
+
|
595 |
+
|
596 |
+
def get_interpolated_k(
|
597 |
+
k_a: Float[Tensor, "3 3"], k_b: Float[Tensor, "3 3"], steps: int = 10
|
598 |
+
) -> List[Float[Tensor, "3 4"]]:
|
599 |
+
"""
|
600 |
+
Returns interpolated path between two camera poses with specified number of steps.
|
601 |
+
|
602 |
+
Args:
|
603 |
+
k_a: camera matrix 1
|
604 |
+
k_b: camera matrix 2
|
605 |
+
steps: number of steps the interpolated pose path should contain
|
606 |
+
|
607 |
+
Returns:
|
608 |
+
List of interpolated camera poses
|
609 |
+
"""
|
610 |
+
Ks: List[Float[Tensor, "3 3"]] = []
|
611 |
+
ts = np.linspace(0, 1, steps)
|
612 |
+
for t in ts:
|
613 |
+
new_k = k_a * (1.0 - t) + k_b * t
|
614 |
+
Ks.append(new_k)
|
615 |
+
return Ks
|
616 |
+
|
617 |
+
|
618 |
+
def get_ordered_poses_and_k(
|
619 |
+
poses: Float[Tensor, "num_poses 3 4"],
|
620 |
+
Ks: Float[Tensor, "num_poses 3 3"],
|
621 |
+
) -> Tuple[Float[Tensor, "num_poses 3 4"], Float[Tensor, "num_poses 3 3"]]:
|
622 |
+
"""
|
623 |
+
Returns ordered poses and intrinsics by euclidian distance between poses.
|
624 |
+
|
625 |
+
Args:
|
626 |
+
poses: list of camera poses
|
627 |
+
Ks: list of camera intrinsics
|
628 |
+
|
629 |
+
Returns:
|
630 |
+
tuple of ordered poses and intrinsics
|
631 |
+
|
632 |
+
"""
|
633 |
+
|
634 |
+
poses_num = len(poses)
|
635 |
+
|
636 |
+
ordered_poses = torch.unsqueeze(poses[0], 0)
|
637 |
+
ordered_ks = torch.unsqueeze(Ks[0], 0)
|
638 |
+
|
639 |
+
# remove the first pose from poses
|
640 |
+
poses = poses[1:]
|
641 |
+
Ks = Ks[1:]
|
642 |
+
|
643 |
+
for _ in range(poses_num - 1):
|
644 |
+
distances = torch.norm(ordered_poses[-1][:, 3] - poses[:, :, 3], dim=1)
|
645 |
+
idx = torch.argmin(distances)
|
646 |
+
ordered_poses = torch.cat((ordered_poses, torch.unsqueeze(poses[idx], 0)), dim=0)
|
647 |
+
ordered_ks = torch.cat((ordered_ks, torch.unsqueeze(Ks[idx], 0)), dim=0)
|
648 |
+
poses = torch.cat((poses[0:idx], poses[idx + 1 :]), dim=0)
|
649 |
+
Ks = torch.cat((Ks[0:idx], Ks[idx + 1 :]), dim=0)
|
650 |
+
|
651 |
+
return ordered_poses, ordered_ks
|
652 |
+
|
653 |
+
|
654 |
+
def get_interpolated_poses_many(
|
655 |
+
poses: Float[Tensor, "num_poses 3 4"],
|
656 |
+
Ks: Float[Tensor, "num_poses 3 3"],
|
657 |
+
steps_per_transition: int = 10,
|
658 |
+
order_poses: bool = False,
|
659 |
+
) -> Tuple[Float[Tensor, "num_poses 3 4"], Float[Tensor, "num_poses 3 3"]]:
|
660 |
+
"""Return interpolated poses for many camera poses.
|
661 |
+
|
662 |
+
Args:
|
663 |
+
poses: list of camera poses
|
664 |
+
Ks: list of camera intrinsics
|
665 |
+
steps_per_transition: number of steps per transition
|
666 |
+
order_poses: whether to order poses by euclidian distance
|
667 |
+
|
668 |
+
Returns:
|
669 |
+
tuple of new poses and intrinsics
|
670 |
+
"""
|
671 |
+
traj = []
|
672 |
+
k_interp = []
|
673 |
+
|
674 |
+
if order_poses:
|
675 |
+
poses, Ks = get_ordered_poses_and_k(poses, Ks)
|
676 |
+
|
677 |
+
for idx in range(poses.shape[0] - 1):
|
678 |
+
pose_a = poses[idx].cpu().numpy()
|
679 |
+
pose_b = poses[idx + 1].cpu().numpy()
|
680 |
+
poses_ab = get_interpolated_poses(pose_a, pose_b, steps=steps_per_transition)
|
681 |
+
traj += poses_ab
|
682 |
+
k_interp += get_interpolated_k(Ks[idx], Ks[idx + 1], steps=steps_per_transition)
|
683 |
+
|
684 |
+
traj = np.stack(traj, axis=0)
|
685 |
+
k_interp = torch.stack(k_interp, dim=0)
|
686 |
+
|
687 |
+
return torch.tensor(traj, dtype=torch.float32), torch.tensor(k_interp, dtype=torch.float32)
|
688 |
+
|
689 |
+
|
690 |
+
def normalize(x: torch.Tensor) -> Float[Tensor, "*batch"]:
|
691 |
+
"""Returns a normalized vector."""
|
692 |
+
return x / torch.linalg.norm(x)
|
693 |
+
|
694 |
+
|
695 |
+
def normalize_with_norm(x: torch.Tensor, dim: int) -> Tuple[torch.Tensor, torch.Tensor]:
|
696 |
+
"""Normalize tensor along axis and return normalized value with norms.
|
697 |
+
|
698 |
+
Args:
|
699 |
+
x: tensor to normalize.
|
700 |
+
dim: axis along which to normalize.
|
701 |
+
|
702 |
+
Returns:
|
703 |
+
Tuple of normalized tensor and corresponding norm.
|
704 |
+
"""
|
705 |
+
|
706 |
+
norm = torch.maximum(torch.linalg.vector_norm(x, dim=dim, keepdims=True), torch.tensor([_EPS]).to(x))
|
707 |
+
return x / norm, norm
|
708 |
+
|
709 |
+
|
710 |
+
def viewmatrix(lookat: torch.Tensor, up: torch.Tensor, pos: torch.Tensor) -> Float[Tensor, "*batch"]:
|
711 |
+
"""Returns a camera transformation matrix.
|
712 |
+
|
713 |
+
Args:
|
714 |
+
lookat: The direction the camera is looking.
|
715 |
+
up: The upward direction of the camera.
|
716 |
+
pos: The position of the camera.
|
717 |
+
|
718 |
+
Returns:
|
719 |
+
A camera transformation matrix.
|
720 |
+
"""
|
721 |
+
vec2 = normalize(lookat)
|
722 |
+
vec1_avg = normalize(up)
|
723 |
+
vec0 = normalize(torch.cross(vec1_avg, vec2))
|
724 |
+
vec1 = normalize(torch.cross(vec2, vec0))
|
725 |
+
m = torch.stack([vec0, vec1, vec2, pos], 1)
|
726 |
+
return m
|
727 |
+
|
728 |
+
|
729 |
+
def get_distortion_params(
|
730 |
+
k1: float = 0.0,
|
731 |
+
k2: float = 0.0,
|
732 |
+
k3: float = 0.0,
|
733 |
+
k4: float = 0.0,
|
734 |
+
p1: float = 0.0,
|
735 |
+
p2: float = 0.0,
|
736 |
+
) -> Float[Tensor, "*batch"]:
|
737 |
+
"""Returns a distortion parameters matrix.
|
738 |
+
|
739 |
+
Args:
|
740 |
+
k1: The first radial distortion parameter.
|
741 |
+
k2: The second radial distortion parameter.
|
742 |
+
k3: The third radial distortion parameter.
|
743 |
+
k4: The fourth radial distortion parameter.
|
744 |
+
p1: The first tangential distortion parameter.
|
745 |
+
p2: The second tangential distortion parameter.
|
746 |
+
Returns:
|
747 |
+
torch.Tensor: A distortion parameters matrix.
|
748 |
+
"""
|
749 |
+
return torch.Tensor([k1, k2, k3, k4, p1, p2])
|
750 |
+
|
751 |
+
|
752 |
+
def _compute_residual_and_jacobian(
|
753 |
+
x: torch.Tensor,
|
754 |
+
y: torch.Tensor,
|
755 |
+
xd: torch.Tensor,
|
756 |
+
yd: torch.Tensor,
|
757 |
+
distortion_params: torch.Tensor,
|
758 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
759 |
+
"""Auxiliary function of radial_and_tangential_undistort() that computes residuals and jacobians.
|
760 |
+
Adapted from MultiNeRF:
|
761 |
+
https://github.com/google-research/multinerf/blob/b02228160d3179300c7d499dca28cb9ca3677f32/internal/camera_utils.py#L427-L474
|
762 |
+
|
763 |
+
Args:
|
764 |
+
x: The updated x coordinates.
|
765 |
+
y: The updated y coordinates.
|
766 |
+
xd: The distorted x coordinates.
|
767 |
+
yd: The distorted y coordinates.
|
768 |
+
distortion_params: The distortion parameters [k1, k2, k3, k4, p1, p2].
|
769 |
+
|
770 |
+
Returns:
|
771 |
+
The residuals (fx, fy) and jacobians (fx_x, fx_y, fy_x, fy_y).
|
772 |
+
"""
|
773 |
+
|
774 |
+
k1 = distortion_params[..., 0]
|
775 |
+
k2 = distortion_params[..., 1]
|
776 |
+
k3 = distortion_params[..., 2]
|
777 |
+
k4 = distortion_params[..., 3]
|
778 |
+
p1 = distortion_params[..., 4]
|
779 |
+
p2 = distortion_params[..., 5]
|
780 |
+
|
781 |
+
# let r(x, y) = x^2 + y^2;
|
782 |
+
# d(x, y) = 1 + k1 * r(x, y) + k2 * r(x, y) ^2 + k3 * r(x, y)^3 +
|
783 |
+
# k4 * r(x, y)^4;
|
784 |
+
r = x * x + y * y
|
785 |
+
d = 1.0 + r * (k1 + r * (k2 + r * (k3 + r * k4)))
|
786 |
+
|
787 |
+
# The perfect projection is:
|
788 |
+
# xd = x * d(x, y) + 2 * p1 * x * y + p2 * (r(x, y) + 2 * x^2);
|
789 |
+
# yd = y * d(x, y) + 2 * p2 * x * y + p1 * (r(x, y) + 2 * y^2);
|
790 |
+
#
|
791 |
+
# Let's define
|
792 |
+
#
|
793 |
+
# fx(x, y) = x * d(x, y) + 2 * p1 * x * y + p2 * (r(x, y) + 2 * x^2) - xd;
|
794 |
+
# fy(x, y) = y * d(x, y) + 2 * p2 * x * y + p1 * (r(x, y) + 2 * y^2) - yd;
|
795 |
+
#
|
796 |
+
# We are looking for a solution that satisfies
|
797 |
+
# fx(x, y) = fy(x, y) = 0;
|
798 |
+
fx = d * x + 2 * p1 * x * y + p2 * (r + 2 * x * x) - xd
|
799 |
+
fy = d * y + 2 * p2 * x * y + p1 * (r + 2 * y * y) - yd
|
800 |
+
|
801 |
+
# Compute derivative of d over [x, y]
|
802 |
+
d_r = k1 + r * (2.0 * k2 + r * (3.0 * k3 + r * 4.0 * k4))
|
803 |
+
d_x = 2.0 * x * d_r
|
804 |
+
d_y = 2.0 * y * d_r
|
805 |
+
|
806 |
+
# Compute derivative of fx over x and y.
|
807 |
+
fx_x = d + d_x * x + 2.0 * p1 * y + 6.0 * p2 * x
|
808 |
+
fx_y = d_y * x + 2.0 * p1 * x + 2.0 * p2 * y
|
809 |
+
|
810 |
+
# Compute derivative of fy over x and y.
|
811 |
+
fy_x = d_x * y + 2.0 * p2 * y + 2.0 * p1 * x
|
812 |
+
fy_y = d + d_y * y + 2.0 * p2 * x + 6.0 * p1 * y
|
813 |
+
|
814 |
+
return fx, fy, fx_x, fx_y, fy_x, fy_y
|
815 |
+
|
816 |
+
|
817 |
+
# @torch_compile(dynamic=True, mode="reduce-overhead", backend="eager")
|
818 |
+
def radial_and_tangential_undistort(
|
819 |
+
coords: torch.Tensor,
|
820 |
+
distortion_params: torch.Tensor,
|
821 |
+
eps: float = 1e-3,
|
822 |
+
max_iterations: int = 10,
|
823 |
+
) -> torch.Tensor:
|
824 |
+
"""Computes undistorted coords given opencv distortion parameters.
|
825 |
+
Adapted from MultiNeRF
|
826 |
+
https://github.com/google-research/multinerf/blob/b02228160d3179300c7d499dca28cb9ca3677f32/internal/camera_utils.py#L477-L509
|
827 |
+
|
828 |
+
Args:
|
829 |
+
coords: The distorted coordinates.
|
830 |
+
distortion_params: The distortion parameters [k1, k2, k3, k4, p1, p2].
|
831 |
+
eps: The epsilon for the convergence.
|
832 |
+
max_iterations: The maximum number of iterations to perform.
|
833 |
+
|
834 |
+
Returns:
|
835 |
+
The undistorted coordinates.
|
836 |
+
"""
|
837 |
+
|
838 |
+
# Initialize from the distorted point.
|
839 |
+
x = coords[..., 0]
|
840 |
+
y = coords[..., 1]
|
841 |
+
|
842 |
+
for _ in range(max_iterations):
|
843 |
+
fx, fy, fx_x, fx_y, fy_x, fy_y = _compute_residual_and_jacobian(
|
844 |
+
x=x, y=y, xd=coords[..., 0], yd=coords[..., 1], distortion_params=distortion_params
|
845 |
+
)
|
846 |
+
denominator = fy_x * fx_y - fx_x * fy_y
|
847 |
+
x_numerator = fx * fy_y - fy * fx_y
|
848 |
+
y_numerator = fy * fx_x - fx * fy_x
|
849 |
+
step_x = torch.where(torch.abs(denominator) > eps, x_numerator / denominator, torch.zeros_like(denominator))
|
850 |
+
step_y = torch.where(torch.abs(denominator) > eps, y_numerator / denominator, torch.zeros_like(denominator))
|
851 |
+
|
852 |
+
x = x + step_x
|
853 |
+
y = y + step_y
|
854 |
+
|
855 |
+
return torch.stack([x, y], dim=-1)
|
856 |
+
|
857 |
+
|
858 |
+
def rotation_matrix(a: Float[Tensor, "3"], b: Float[Tensor, "3"]) -> Float[Tensor, "3 3"]:
|
859 |
+
"""Compute the rotation matrix that rotates vector a to vector b.
|
860 |
+
|
861 |
+
Args:
|
862 |
+
a: The vector to rotate.
|
863 |
+
b: The vector to rotate to.
|
864 |
+
Returns:
|
865 |
+
The rotation matrix.
|
866 |
+
"""
|
867 |
+
a = a / torch.linalg.norm(a)
|
868 |
+
b = b / torch.linalg.norm(b)
|
869 |
+
v = torch.cross(a, b)
|
870 |
+
c = torch.dot(a, b)
|
871 |
+
# If vectors are exactly opposite, we add a little noise to one of them
|
872 |
+
if c < -1 + 1e-8:
|
873 |
+
eps = (torch.rand(3) - 0.5) * 0.01
|
874 |
+
return rotation_matrix(a + eps, b)
|
875 |
+
s = torch.linalg.norm(v)
|
876 |
+
skew_sym_mat = torch.Tensor(
|
877 |
+
[
|
878 |
+
[0, -v[2], v[1]],
|
879 |
+
[v[2], 0, -v[0]],
|
880 |
+
[-v[1], v[0], 0],
|
881 |
+
]
|
882 |
+
)
|
883 |
+
return torch.eye(3) + skew_sym_mat + skew_sym_mat @ skew_sym_mat * ((1 - c) / (s**2 + 1e-8))
|
884 |
+
|
885 |
+
|
886 |
+
def focus_of_attention(poses: Float[Tensor, "*num_poses 4 4"], initial_focus: Float[Tensor, "3"]) -> Float[Tensor, "3"]:
|
887 |
+
"""Compute the focus of attention of a set of cameras. Only cameras
|
888 |
+
that have the focus of attention in front of them are considered.
|
889 |
+
|
890 |
+
Args:
|
891 |
+
poses: The poses to orient.
|
892 |
+
initial_focus: The 3D point views to decide which cameras are initially activated.
|
893 |
+
|
894 |
+
Returns:
|
895 |
+
The 3D position of the focus of attention.
|
896 |
+
"""
|
897 |
+
# References to the same method in third-party code:
|
898 |
+
# https://github.com/google-research/multinerf/blob/1c8b1c552133cdb2de1c1f3c871b2813f6662265/internal/camera_utils.py#L145
|
899 |
+
# https://github.com/bmild/nerf/blob/18b8aebda6700ed659cb27a0c348b737a5f6ab60/load_llff.py#L197
|
900 |
+
active_directions = -poses[:, :3, 2:3]
|
901 |
+
active_origins = poses[:, :3, 3:4]
|
902 |
+
# initial value for testing if the focus_pt is in front or behind
|
903 |
+
focus_pt = initial_focus
|
904 |
+
# Prune cameras which have the current have the focus_pt behind them.
|
905 |
+
active = torch.sum(active_directions.squeeze(-1) * (focus_pt - active_origins.squeeze(-1)), dim=-1) > 0
|
906 |
+
done = False
|
907 |
+
# We need at least two active cameras, else fallback on the previous solution.
|
908 |
+
# This may be the "poses" solution if no cameras are active on first iteration, e.g.
|
909 |
+
# they are in an outward-looking configuration.
|
910 |
+
while torch.sum(active.int()) > 1 and not done:
|
911 |
+
active_directions = active_directions[active]
|
912 |
+
active_origins = active_origins[active]
|
913 |
+
# https://en.wikipedia.org/wiki/Line–line_intersection#In_more_than_two_dimensions
|
914 |
+
m = torch.eye(3) - active_directions * torch.transpose(active_directions, -2, -1)
|
915 |
+
mt_m = torch.transpose(m, -2, -1) @ m
|
916 |
+
focus_pt = torch.linalg.inv(mt_m.mean(0)) @ (mt_m @ active_origins).mean(0)[:, 0]
|
917 |
+
active = torch.sum(active_directions.squeeze(-1) * (focus_pt - active_origins.squeeze(-1)), dim=-1) > 0
|
918 |
+
if active.all():
|
919 |
+
# the set of active cameras did not change, so we're done.
|
920 |
+
done = True
|
921 |
+
return focus_pt
|
922 |
+
|
923 |
+
|
924 |
+
def auto_orient_and_center_poses(
|
925 |
+
poses: Float[Tensor, "*num_poses 4 4"],
|
926 |
+
method: Literal["pca", "up", "vertical", "none"] = "up",
|
927 |
+
center_method: Literal["poses", "focus", "none"] = "poses",
|
928 |
+
) -> Tuple[Float[Tensor, "*num_poses 3 4"], Float[Tensor, "3 4"]]:
|
929 |
+
"""Orients and centers the poses.
|
930 |
+
|
931 |
+
We provide three methods for orientation:
|
932 |
+
|
933 |
+
- pca: Orient the poses so that the principal directions of the camera centers are aligned
|
934 |
+
with the axes, Z corresponding to the smallest principal component.
|
935 |
+
This method works well when all of the cameras are in the same plane, for example when
|
936 |
+
images are taken using a mobile robot.
|
937 |
+
- up: Orient the poses so that the average up vector is aligned with the z axis.
|
938 |
+
This method works well when images are not at arbitrary angles.
|
939 |
+
- vertical: Orient the poses so that the Z 3D direction projects close to the
|
940 |
+
y axis in images. This method works better if cameras are not all
|
941 |
+
looking in the same 3D direction, which may happen in camera arrays or in LLFF.
|
942 |
+
|
943 |
+
There are two centering methods:
|
944 |
+
|
945 |
+
- poses: The poses are centered around the origin.
|
946 |
+
- focus: The origin is set to the focus of attention of all cameras (the
|
947 |
+
closest point to cameras optical axes). Recommended for inward-looking
|
948 |
+
camera configurations.
|
949 |
+
|
950 |
+
Args:
|
951 |
+
poses: The poses to orient.
|
952 |
+
method: The method to use for orientation.
|
953 |
+
center_method: The method to use to center the poses.
|
954 |
+
|
955 |
+
Returns:
|
956 |
+
Tuple of the oriented poses and the transform matrix.
|
957 |
+
"""
|
958 |
+
|
959 |
+
origins = poses[..., :3, 3]
|
960 |
+
|
961 |
+
mean_origin = torch.mean(origins, dim=0)
|
962 |
+
translation_diff = origins - mean_origin
|
963 |
+
|
964 |
+
if center_method == "poses":
|
965 |
+
translation = mean_origin
|
966 |
+
elif center_method == "focus":
|
967 |
+
translation = focus_of_attention(poses, mean_origin)
|
968 |
+
elif center_method == "none":
|
969 |
+
translation = torch.zeros_like(mean_origin)
|
970 |
+
else:
|
971 |
+
raise ValueError(f"Unknown value for center_method: {center_method}")
|
972 |
+
|
973 |
+
if method == "pca":
|
974 |
+
_, eigvec = torch.linalg.eigh(translation_diff.T @ translation_diff)
|
975 |
+
eigvec = torch.flip(eigvec, dims=(-1,))
|
976 |
+
|
977 |
+
if torch.linalg.det(eigvec) < 0:
|
978 |
+
eigvec[:, 2] = -eigvec[:, 2]
|
979 |
+
|
980 |
+
transform = torch.cat([eigvec, eigvec @ -translation[..., None]], dim=-1)
|
981 |
+
oriented_poses = transform @ poses
|
982 |
+
|
983 |
+
if oriented_poses.mean(dim=0)[2, 1] < 0:
|
984 |
+
oriented_poses[:, 1:3] = -1 * oriented_poses[:, 1:3]
|
985 |
+
elif method in ("up", "vertical"):
|
986 |
+
up = torch.mean(poses[:, :3, 1], dim=0)
|
987 |
+
up = up / torch.linalg.norm(up)
|
988 |
+
if method == "vertical":
|
989 |
+
# If cameras are not all parallel (e.g. not in an LLFF configuration),
|
990 |
+
# we can find the 3D direction that most projects vertically in all
|
991 |
+
# cameras by minimizing ||Xu|| s.t. ||u||=1. This total least squares
|
992 |
+
# problem is solved by SVD.
|
993 |
+
x_axis_matrix = poses[:, :3, 0]
|
994 |
+
_, S, Vh = torch.linalg.svd(x_axis_matrix, full_matrices=False)
|
995 |
+
# Singular values are S_i=||Xv_i|| for each right singular vector v_i.
|
996 |
+
# ||S|| = sqrt(n) because lines of X are all unit vectors and the v_i
|
997 |
+
# are an orthonormal basis.
|
998 |
+
# ||Xv_i|| = sqrt(sum(dot(x_axis_j,v_i)^2)), thus S_i/sqrt(n) is the
|
999 |
+
# RMS of cosines between x axes and v_i. If the second smallest singular
|
1000 |
+
# value corresponds to an angle error less than 10° (cos(80°)=0.17),
|
1001 |
+
# this is probably a degenerate camera configuration (typical values
|
1002 |
+
# are around 5° average error for the true vertical). In this case,
|
1003 |
+
# rather than taking the vector corresponding to the smallest singular
|
1004 |
+
# value, we project the "up" vector on the plane spanned by the two
|
1005 |
+
# best singular vectors. We could also just fallback to the "up"
|
1006 |
+
# solution.
|
1007 |
+
if S[1] > 0.17 * math.sqrt(poses.shape[0]):
|
1008 |
+
# regular non-degenerate configuration
|
1009 |
+
up_vertical = Vh[2, :]
|
1010 |
+
# It may be pointing up or down. Use "up" to disambiguate the sign.
|
1011 |
+
up = up_vertical if torch.dot(up_vertical, up) > 0 else -up_vertical
|
1012 |
+
else:
|
1013 |
+
# Degenerate configuration: project "up" on the plane spanned by
|
1014 |
+
# the last two right singular vectors (which are orthogonal to the
|
1015 |
+
# first). v_0 is a unit vector, no need to divide by its norm when
|
1016 |
+
# projecting.
|
1017 |
+
up = up - Vh[0, :] * torch.dot(up, Vh[0, :])
|
1018 |
+
# re-normalize
|
1019 |
+
up = up / torch.linalg.norm(up)
|
1020 |
+
|
1021 |
+
rotation = rotation_matrix(up, torch.Tensor([0, 0, 1]))
|
1022 |
+
transform = torch.cat([rotation, rotation @ -translation[..., None]], dim=-1)
|
1023 |
+
oriented_poses = transform @ poses
|
1024 |
+
elif method == "none":
|
1025 |
+
transform = torch.eye(4)
|
1026 |
+
transform[:3, 3] = -translation
|
1027 |
+
transform = transform[:3, :]
|
1028 |
+
oriented_poses = transform @ poses
|
1029 |
+
else:
|
1030 |
+
raise ValueError(f"Unknown value for method: {method}")
|
1031 |
+
|
1032 |
+
return oriented_poses, transform
|
1033 |
+
|
1034 |
+
|
1035 |
+
@torch.jit.script
|
1036 |
+
def fisheye624_project(xyz, params):
|
1037 |
+
"""
|
1038 |
+
Batched implementation of the FisheyeRadTanThinPrism (aka Fisheye624) camera
|
1039 |
+
model project() function.
|
1040 |
+
Inputs:
|
1041 |
+
xyz: BxNx3 tensor of 3D points to be projected
|
1042 |
+
params: Bx16 tensor of Fisheye624 parameters formatted like this:
|
1043 |
+
[f_u f_v c_u c_v {k_0 ... k_5} {p_0 p_1} {s_0 s_1 s_2 s_3}]
|
1044 |
+
or Bx15 tensor of Fisheye624 parameters formatted like this:
|
1045 |
+
[f c_u c_v {k_0 ... k_5} {p_0 p_1} {s_0 s_1 s_2 s_3}]
|
1046 |
+
Outputs:
|
1047 |
+
uv: BxNx2 tensor of 2D projections of xyz in image plane
|
1048 |
+
Model for fisheye cameras with radial, tangential, and thin-prism distortion.
|
1049 |
+
This model allows fu != fv.
|
1050 |
+
Specifically, the model is:
|
1051 |
+
uvDistorted = [x_r] + tangentialDistortion + thinPrismDistortion
|
1052 |
+
[y_r]
|
1053 |
+
proj = diag(fu,fv) * uvDistorted + [cu;cv];
|
1054 |
+
where:
|
1055 |
+
a = x/z, b = y/z, r = (a^2+b^2)^(1/2)
|
1056 |
+
th = atan(r)
|
1057 |
+
cosPhi = a/r, sinPhi = b/r
|
1058 |
+
[x_r] = (th+ k0 * th^3 + k1* th^5 + ...) [cosPhi]
|
1059 |
+
[y_r] [sinPhi]
|
1060 |
+
the number of terms in the series is determined by the template parameter numK.
|
1061 |
+
tangentialDistortion = [(2 x_r^2 + rd^2)*p_0 + 2*x_r*y_r*p_1]
|
1062 |
+
[(2 y_r^2 + rd^2)*p_1 + 2*x_r*y_r*p_0]
|
1063 |
+
where rd^2 = x_r^2 + y_r^2
|
1064 |
+
thinPrismDistortion = [s0 * rd^2 + s1 rd^4]
|
1065 |
+
[s2 * rd^2 + s3 rd^4]
|
1066 |
+
Author: Daniel DeTone ([email protected])
|
1067 |
+
"""
|
1068 |
+
|
1069 |
+
assert xyz.ndim == 3
|
1070 |
+
assert params.ndim == 2
|
1071 |
+
assert params.shape[-1] == 16 or params.shape[-1] == 15, "This model allows fx != fy"
|
1072 |
+
eps = 1e-9
|
1073 |
+
B, N = xyz.shape[0], xyz.shape[1]
|
1074 |
+
|
1075 |
+
# Radial correction.
|
1076 |
+
z = xyz[:, :, 2].reshape(B, N, 1)
|
1077 |
+
z = torch.where(torch.abs(z) < eps, eps * torch.sign(z), z)
|
1078 |
+
ab = xyz[:, :, :2] / z
|
1079 |
+
r = torch.norm(ab, dim=-1, p=2, keepdim=True)
|
1080 |
+
th = torch.atan(r)
|
1081 |
+
th_divr = torch.where(r < eps, torch.ones_like(ab), ab / r)
|
1082 |
+
th_k = th.reshape(B, N, 1).clone()
|
1083 |
+
for i in range(6):
|
1084 |
+
th_k = th_k + params[:, -12 + i].reshape(B, 1, 1) * torch.pow(th, 3 + i * 2)
|
1085 |
+
xr_yr = th_k * th_divr
|
1086 |
+
uv_dist = xr_yr
|
1087 |
+
|
1088 |
+
# Tangential correction.
|
1089 |
+
p0 = params[:, -6].reshape(B, 1)
|
1090 |
+
p1 = params[:, -5].reshape(B, 1)
|
1091 |
+
xr = xr_yr[:, :, 0].reshape(B, N)
|
1092 |
+
yr = xr_yr[:, :, 1].reshape(B, N)
|
1093 |
+
xr_yr_sq = torch.square(xr_yr)
|
1094 |
+
xr_sq = xr_yr_sq[:, :, 0].reshape(B, N)
|
1095 |
+
yr_sq = xr_yr_sq[:, :, 1].reshape(B, N)
|
1096 |
+
rd_sq = xr_sq + yr_sq
|
1097 |
+
uv_dist_tu = uv_dist[:, :, 0] + ((2.0 * xr_sq + rd_sq) * p0 + 2.0 * xr * yr * p1)
|
1098 |
+
uv_dist_tv = uv_dist[:, :, 1] + ((2.0 * yr_sq + rd_sq) * p1 + 2.0 * xr * yr * p0)
|
1099 |
+
uv_dist = torch.stack([uv_dist_tu, uv_dist_tv], dim=-1) # Avoids in-place complaint.
|
1100 |
+
|
1101 |
+
# Thin Prism correction.
|
1102 |
+
s0 = params[:, -4].reshape(B, 1)
|
1103 |
+
s1 = params[:, -3].reshape(B, 1)
|
1104 |
+
s2 = params[:, -2].reshape(B, 1)
|
1105 |
+
s3 = params[:, -1].reshape(B, 1)
|
1106 |
+
rd_4 = torch.square(rd_sq)
|
1107 |
+
uv_dist[:, :, 0] = uv_dist[:, :, 0] + (s0 * rd_sq + s1 * rd_4)
|
1108 |
+
uv_dist[:, :, 1] = uv_dist[:, :, 1] + (s2 * rd_sq + s3 * rd_4)
|
1109 |
+
|
1110 |
+
# Finally, apply standard terms: focal length and camera centers.
|
1111 |
+
if params.shape[-1] == 15:
|
1112 |
+
fx_fy = params[:, 0].reshape(B, 1, 1)
|
1113 |
+
cx_cy = params[:, 1:3].reshape(B, 1, 2)
|
1114 |
+
else:
|
1115 |
+
fx_fy = params[:, 0:2].reshape(B, 1, 2)
|
1116 |
+
cx_cy = params[:, 2:4].reshape(B, 1, 2)
|
1117 |
+
result = uv_dist * fx_fy + cx_cy
|
1118 |
+
|
1119 |
+
return result
|
1120 |
+
|
1121 |
+
|
1122 |
+
# Core implementation of fisheye 624 unprojection. More details are documented here:
|
1123 |
+
# https://facebookresearch.github.io/projectaria_tools/docs/tech_insights/camera_intrinsic_models#the-fisheye62-model
|
1124 |
+
@torch.jit.script
|
1125 |
+
def fisheye624_unproject_helper(uv, params, max_iters: int = 5):
|
1126 |
+
"""
|
1127 |
+
Batched implementation of the FisheyeRadTanThinPrism (aka Fisheye624) camera
|
1128 |
+
model. There is no analytical solution for the inverse of the project()
|
1129 |
+
function so this solves an optimization problem using Newton's method to get
|
1130 |
+
the inverse.
|
1131 |
+
Inputs:
|
1132 |
+
uv: BxNx2 tensor of 2D pixels to be unprojected
|
1133 |
+
params: Bx16 tensor of Fisheye624 parameters formatted like this:
|
1134 |
+
[f_u f_v c_u c_v {k_0 ... k_5} {p_0 p_1} {s_0 s_1 s_2 s_3}]
|
1135 |
+
or Bx15 tensor of Fisheye624 parameters formatted like this:
|
1136 |
+
[f c_u c_v {k_0 ... k_5} {p_0 p_1} {s_0 s_1 s_2 s_3}]
|
1137 |
+
Outputs:
|
1138 |
+
xyz: BxNx3 tensor of 3D rays of uv points with z = 1.
|
1139 |
+
Model for fisheye cameras with radial, tangential, and thin-prism distortion.
|
1140 |
+
This model assumes fu=fv. This unproject function holds that:
|
1141 |
+
X = unproject(project(X)) [for X=(x,y,z) in R^3, z>0]
|
1142 |
+
and
|
1143 |
+
x = project(unproject(s*x)) [for s!=0 and x=(u,v) in R^2]
|
1144 |
+
Author: Daniel DeTone ([email protected])
|
1145 |
+
"""
|
1146 |
+
|
1147 |
+
assert uv.ndim == 3, "Expected batched input shaped BxNx3"
|
1148 |
+
assert params.ndim == 2
|
1149 |
+
assert params.shape[-1] == 16 or params.shape[-1] == 15, "This model allows fx != fy"
|
1150 |
+
eps = 1e-6
|
1151 |
+
B, N = uv.shape[0], uv.shape[1]
|
1152 |
+
|
1153 |
+
if params.shape[-1] == 15:
|
1154 |
+
fx_fy = params[:, 0].reshape(B, 1, 1)
|
1155 |
+
cx_cy = params[:, 1:3].reshape(B, 1, 2)
|
1156 |
+
else:
|
1157 |
+
fx_fy = params[:, 0:2].reshape(B, 1, 2)
|
1158 |
+
cx_cy = params[:, 2:4].reshape(B, 1, 2)
|
1159 |
+
|
1160 |
+
uv_dist = (uv - cx_cy) / fx_fy
|
1161 |
+
|
1162 |
+
# Compute xr_yr using Newton's method.
|
1163 |
+
xr_yr = uv_dist.clone() # Initial guess.
|
1164 |
+
for _ in range(max_iters):
|
1165 |
+
uv_dist_est = xr_yr.clone()
|
1166 |
+
# Tangential terms.
|
1167 |
+
p0 = params[:, -6].reshape(B, 1)
|
1168 |
+
p1 = params[:, -5].reshape(B, 1)
|
1169 |
+
xr = xr_yr[:, :, 0].reshape(B, N)
|
1170 |
+
yr = xr_yr[:, :, 1].reshape(B, N)
|
1171 |
+
xr_yr_sq = torch.square(xr_yr)
|
1172 |
+
xr_sq = xr_yr_sq[:, :, 0].reshape(B, N)
|
1173 |
+
yr_sq = xr_yr_sq[:, :, 1].reshape(B, N)
|
1174 |
+
rd_sq = xr_sq + yr_sq
|
1175 |
+
uv_dist_est[:, :, 0] = uv_dist_est[:, :, 0] + ((2.0 * xr_sq + rd_sq) * p0 + 2.0 * xr * yr * p1)
|
1176 |
+
uv_dist_est[:, :, 1] = uv_dist_est[:, :, 1] + ((2.0 * yr_sq + rd_sq) * p1 + 2.0 * xr * yr * p0)
|
1177 |
+
# Thin Prism terms.
|
1178 |
+
s0 = params[:, -4].reshape(B, 1)
|
1179 |
+
s1 = params[:, -3].reshape(B, 1)
|
1180 |
+
s2 = params[:, -2].reshape(B, 1)
|
1181 |
+
s3 = params[:, -1].reshape(B, 1)
|
1182 |
+
rd_4 = torch.square(rd_sq)
|
1183 |
+
uv_dist_est[:, :, 0] = uv_dist_est[:, :, 0] + (s0 * rd_sq + s1 * rd_4)
|
1184 |
+
uv_dist_est[:, :, 1] = uv_dist_est[:, :, 1] + (s2 * rd_sq + s3 * rd_4)
|
1185 |
+
# Compute the derivative of uv_dist w.r.t. xr_yr.
|
1186 |
+
duv_dist_dxr_yr = uv.new_ones(B, N, 2, 2)
|
1187 |
+
duv_dist_dxr_yr[:, :, 0, 0] = 1.0 + 6.0 * xr_yr[:, :, 0] * p0 + 2.0 * xr_yr[:, :, 1] * p1
|
1188 |
+
offdiag = 2.0 * (xr_yr[:, :, 0] * p1 + xr_yr[:, :, 1] * p0)
|
1189 |
+
duv_dist_dxr_yr[:, :, 0, 1] = offdiag
|
1190 |
+
duv_dist_dxr_yr[:, :, 1, 0] = offdiag
|
1191 |
+
duv_dist_dxr_yr[:, :, 1, 1] = 1.0 + 6.0 * xr_yr[:, :, 1] * p1 + 2.0 * xr_yr[:, :, 0] * p0
|
1192 |
+
xr_yr_sq_norm = xr_yr_sq[:, :, 0] + xr_yr_sq[:, :, 1]
|
1193 |
+
temp1 = 2.0 * (s0 + 2.0 * s1 * xr_yr_sq_norm)
|
1194 |
+
duv_dist_dxr_yr[:, :, 0, 0] = duv_dist_dxr_yr[:, :, 0, 0] + (xr_yr[:, :, 0] * temp1)
|
1195 |
+
duv_dist_dxr_yr[:, :, 0, 1] = duv_dist_dxr_yr[:, :, 0, 1] + (xr_yr[:, :, 1] * temp1)
|
1196 |
+
temp2 = 2.0 * (s2 + 2.0 * s3 * xr_yr_sq_norm)
|
1197 |
+
duv_dist_dxr_yr[:, :, 1, 0] = duv_dist_dxr_yr[:, :, 1, 0] + (xr_yr[:, :, 0] * temp2)
|
1198 |
+
duv_dist_dxr_yr[:, :, 1, 1] = duv_dist_dxr_yr[:, :, 1, 1] + (xr_yr[:, :, 1] * temp2)
|
1199 |
+
# Compute 2x2 inverse manually here since torch.inverse() is very slow.
|
1200 |
+
# Because this is slow: inv = duv_dist_dxr_yr.inverse()
|
1201 |
+
# About a 10x reduction in speed with above line.
|
1202 |
+
mat = duv_dist_dxr_yr.reshape(-1, 2, 2)
|
1203 |
+
a = mat[:, 0, 0].reshape(-1, 1, 1)
|
1204 |
+
b = mat[:, 0, 1].reshape(-1, 1, 1)
|
1205 |
+
c = mat[:, 1, 0].reshape(-1, 1, 1)
|
1206 |
+
d = mat[:, 1, 1].reshape(-1, 1, 1)
|
1207 |
+
det = 1.0 / ((a * d) - (b * c))
|
1208 |
+
top = torch.cat([d, -b], dim=2)
|
1209 |
+
bot = torch.cat([-c, a], dim=2)
|
1210 |
+
inv = det * torch.cat([top, bot], dim=1)
|
1211 |
+
inv = inv.reshape(B, N, 2, 2)
|
1212 |
+
# Manually compute 2x2 @ 2x1 matrix multiply.
|
1213 |
+
# Because this is slow: step = (inv @ (uv_dist - uv_dist_est)[..., None])[..., 0]
|
1214 |
+
diff = uv_dist - uv_dist_est
|
1215 |
+
a = inv[:, :, 0, 0]
|
1216 |
+
b = inv[:, :, 0, 1]
|
1217 |
+
c = inv[:, :, 1, 0]
|
1218 |
+
d = inv[:, :, 1, 1]
|
1219 |
+
e = diff[:, :, 0]
|
1220 |
+
f = diff[:, :, 1]
|
1221 |
+
step = torch.stack([a * e + b * f, c * e + d * f], dim=-1)
|
1222 |
+
# Newton step.
|
1223 |
+
xr_yr = xr_yr + step
|
1224 |
+
|
1225 |
+
# Compute theta using Newton's method.
|
1226 |
+
xr_yr_norm = xr_yr.norm(p=2, dim=2).reshape(B, N, 1)
|
1227 |
+
th = xr_yr_norm.clone()
|
1228 |
+
for _ in range(max_iters):
|
1229 |
+
th_radial = uv.new_ones(B, N, 1)
|
1230 |
+
dthd_th = uv.new_ones(B, N, 1)
|
1231 |
+
for k in range(6):
|
1232 |
+
r_k = params[:, -12 + k].reshape(B, 1, 1)
|
1233 |
+
th_radial = th_radial + (r_k * torch.pow(th, 2 + k * 2))
|
1234 |
+
dthd_th = dthd_th + ((3.0 + 2.0 * k) * r_k * torch.pow(th, 2 + k * 2))
|
1235 |
+
th_radial = th_radial * th
|
1236 |
+
step = (xr_yr_norm - th_radial) / dthd_th
|
1237 |
+
# handle dthd_th close to 0.
|
1238 |
+
step = torch.where(dthd_th.abs() > eps, step, torch.sign(step) * eps * 10.0)
|
1239 |
+
th = th + step
|
1240 |
+
# Compute the ray direction using theta and xr_yr.
|
1241 |
+
close_to_zero = torch.logical_and(th.abs() < eps, xr_yr_norm.abs() < eps)
|
1242 |
+
ray_dir = torch.where(close_to_zero, xr_yr, torch.tan(th) / xr_yr_norm * xr_yr)
|
1243 |
+
ray = torch.cat([ray_dir, uv.new_ones(B, N, 1)], dim=2)
|
1244 |
+
return ray
|
1245 |
+
|
1246 |
+
|
1247 |
+
# unproject 2D point to 3D with fisheye624 model
|
1248 |
+
def fisheye624_unproject(coords: torch.Tensor, distortion_params: torch.Tensor) -> torch.Tensor:
|
1249 |
+
dirs = fisheye624_unproject_helper(coords.unsqueeze(0), distortion_params[0].unsqueeze(0))
|
1250 |
+
# correct for camera space differences:
|
1251 |
+
dirs[..., 1] = -dirs[..., 1]
|
1252 |
+
dirs[..., 2] = -dirs[..., 2]
|
1253 |
+
return dirs
|
sgm/data/cifar10.py
ADDED
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pytorch_lightning as pl
|
2 |
+
import torchvision
|
3 |
+
from torch.utils.data import DataLoader, Dataset
|
4 |
+
from torchvision import transforms
|
5 |
+
|
6 |
+
|
7 |
+
class CIFAR10DataDictWrapper(Dataset):
|
8 |
+
def __init__(self, dset):
|
9 |
+
super().__init__()
|
10 |
+
self.dset = dset
|
11 |
+
|
12 |
+
def __getitem__(self, i):
|
13 |
+
x, y = self.dset[i]
|
14 |
+
return {"jpg": x, "cls": y}
|
15 |
+
|
16 |
+
def __len__(self):
|
17 |
+
return len(self.dset)
|
18 |
+
|
19 |
+
|
20 |
+
class CIFAR10Loader(pl.LightningDataModule):
|
21 |
+
def __init__(self, batch_size, num_workers=0, shuffle=True):
|
22 |
+
super().__init__()
|
23 |
+
|
24 |
+
transform = transforms.Compose(
|
25 |
+
[transforms.ToTensor(), transforms.Lambda(lambda x: x * 2.0 - 1.0)]
|
26 |
+
)
|
27 |
+
|
28 |
+
self.batch_size = batch_size
|
29 |
+
self.num_workers = num_workers
|
30 |
+
self.shuffle = shuffle
|
31 |
+
self.train_dataset = CIFAR10DataDictWrapper(
|
32 |
+
torchvision.datasets.CIFAR10(
|
33 |
+
root=".data/", train=True, download=True, transform=transform
|
34 |
+
)
|
35 |
+
)
|
36 |
+
self.test_dataset = CIFAR10DataDictWrapper(
|
37 |
+
torchvision.datasets.CIFAR10(
|
38 |
+
root=".data/", train=False, download=True, transform=transform
|
39 |
+
)
|
40 |
+
)
|
41 |
+
|
42 |
+
def prepare_data(self):
|
43 |
+
pass
|
44 |
+
|
45 |
+
def train_dataloader(self):
|
46 |
+
return DataLoader(
|
47 |
+
self.train_dataset,
|
48 |
+
batch_size=self.batch_size,
|
49 |
+
shuffle=self.shuffle,
|
50 |
+
num_workers=self.num_workers,
|
51 |
+
)
|
52 |
+
|
53 |
+
def test_dataloader(self):
|
54 |
+
return DataLoader(
|
55 |
+
self.test_dataset,
|
56 |
+
batch_size=self.batch_size,
|
57 |
+
shuffle=self.shuffle,
|
58 |
+
num_workers=self.num_workers,
|
59 |
+
)
|
60 |
+
|
61 |
+
def val_dataloader(self):
|
62 |
+
return DataLoader(
|
63 |
+
self.test_dataset,
|
64 |
+
batch_size=self.batch_size,
|
65 |
+
shuffle=self.shuffle,
|
66 |
+
num_workers=self.num_workers,
|
67 |
+
)
|
sgm/data/co3d.py
ADDED
@@ -0,0 +1,1367 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
adopted from SparseFusion
|
3 |
+
Wrapper for the full CO3Dv2 dataset
|
4 |
+
#@ Modified from https://github.com/facebookresearch/pytorch3d
|
5 |
+
"""
|
6 |
+
|
7 |
+
import json
|
8 |
+
import logging
|
9 |
+
import math
|
10 |
+
import os
|
11 |
+
import random
|
12 |
+
import time
|
13 |
+
import warnings
|
14 |
+
from collections import defaultdict
|
15 |
+
from itertools import islice
|
16 |
+
from typing import (
|
17 |
+
Any,
|
18 |
+
ClassVar,
|
19 |
+
List,
|
20 |
+
Mapping,
|
21 |
+
Optional,
|
22 |
+
Sequence,
|
23 |
+
Tuple,
|
24 |
+
Type,
|
25 |
+
TypedDict,
|
26 |
+
Union,
|
27 |
+
)
|
28 |
+
from einops import rearrange, repeat
|
29 |
+
|
30 |
+
import numpy as np
|
31 |
+
import torch
|
32 |
+
import torch.nn.functional as F
|
33 |
+
import torchvision.transforms.functional as TF
|
34 |
+
from pytorch3d.utils import opencv_from_cameras_projection
|
35 |
+
from pytorch3d.implicitron.dataset import types
|
36 |
+
from pytorch3d.implicitron.dataset.dataset_base import DatasetBase
|
37 |
+
from sgm.data.json_index_dataset import (
|
38 |
+
FrameAnnotsEntry,
|
39 |
+
_bbox_xywh_to_xyxy,
|
40 |
+
_bbox_xyxy_to_xywh,
|
41 |
+
_clamp_box_to_image_bounds_and_round,
|
42 |
+
_crop_around_box,
|
43 |
+
_get_1d_bounds,
|
44 |
+
_get_bbox_from_mask,
|
45 |
+
_get_clamp_bbox,
|
46 |
+
_load_1bit_png_mask,
|
47 |
+
_load_16big_png_depth,
|
48 |
+
_load_depth,
|
49 |
+
_load_depth_mask,
|
50 |
+
_load_image,
|
51 |
+
_load_mask,
|
52 |
+
_load_pointcloud,
|
53 |
+
_rescale_bbox,
|
54 |
+
_safe_as_tensor,
|
55 |
+
_seq_name_to_seed,
|
56 |
+
)
|
57 |
+
from sgm.data.objaverse import video_collate_fn
|
58 |
+
from pytorch3d.implicitron.dataset.json_index_dataset_map_provider_v2 import (
|
59 |
+
get_available_subset_names,
|
60 |
+
)
|
61 |
+
from pytorch3d.renderer.cameras import PerspectiveCameras
|
62 |
+
|
63 |
+
logger = logging.getLogger(__name__)
|
64 |
+
|
65 |
+
|
66 |
+
from dataclasses import dataclass, field, fields
|
67 |
+
|
68 |
+
from pytorch3d.renderer.camera_utils import join_cameras_as_batch
|
69 |
+
from pytorch3d.renderer.cameras import CamerasBase, PerspectiveCameras
|
70 |
+
from pytorch3d.structures.pointclouds import Pointclouds, join_pointclouds_as_batch
|
71 |
+
from pytorch_lightning import LightningDataModule
|
72 |
+
from torch.utils.data import DataLoader
|
73 |
+
|
74 |
+
CO3D_ALL_CATEGORIES = list(
|
75 |
+
reversed(
|
76 |
+
[
|
77 |
+
"baseballbat",
|
78 |
+
"banana",
|
79 |
+
"bicycle",
|
80 |
+
"microwave",
|
81 |
+
"tv",
|
82 |
+
"cellphone",
|
83 |
+
"toilet",
|
84 |
+
"hairdryer",
|
85 |
+
"couch",
|
86 |
+
"kite",
|
87 |
+
"pizza",
|
88 |
+
"umbrella",
|
89 |
+
"wineglass",
|
90 |
+
"laptop",
|
91 |
+
"hotdog",
|
92 |
+
"stopsign",
|
93 |
+
"frisbee",
|
94 |
+
"baseballglove",
|
95 |
+
"cup",
|
96 |
+
"parkingmeter",
|
97 |
+
"backpack",
|
98 |
+
"toyplane",
|
99 |
+
"toybus",
|
100 |
+
"handbag",
|
101 |
+
"chair",
|
102 |
+
"keyboard",
|
103 |
+
"car",
|
104 |
+
"motorcycle",
|
105 |
+
"carrot",
|
106 |
+
"bottle",
|
107 |
+
"sandwich",
|
108 |
+
"remote",
|
109 |
+
"bowl",
|
110 |
+
"skateboard",
|
111 |
+
"toaster",
|
112 |
+
"mouse",
|
113 |
+
"toytrain",
|
114 |
+
"book",
|
115 |
+
"toytruck",
|
116 |
+
"orange",
|
117 |
+
"broccoli",
|
118 |
+
"plant",
|
119 |
+
"teddybear",
|
120 |
+
"suitcase",
|
121 |
+
"bench",
|
122 |
+
"ball",
|
123 |
+
"cake",
|
124 |
+
"vase",
|
125 |
+
"hydrant",
|
126 |
+
"apple",
|
127 |
+
"donut",
|
128 |
+
]
|
129 |
+
)
|
130 |
+
)
|
131 |
+
|
132 |
+
CO3D_ALL_TEN = [
|
133 |
+
"donut",
|
134 |
+
"apple",
|
135 |
+
"hydrant",
|
136 |
+
"vase",
|
137 |
+
"cake",
|
138 |
+
"ball",
|
139 |
+
"bench",
|
140 |
+
"suitcase",
|
141 |
+
"teddybear",
|
142 |
+
"plant",
|
143 |
+
]
|
144 |
+
|
145 |
+
|
146 |
+
# @ FROM https://github.com/facebookresearch/pytorch3d
|
147 |
+
@dataclass
|
148 |
+
class FrameData(Mapping[str, Any]):
|
149 |
+
"""
|
150 |
+
A type of the elements returned by indexing the dataset object.
|
151 |
+
It can represent both individual frames and batches of thereof;
|
152 |
+
in this documentation, the sizes of tensors refer to single frames;
|
153 |
+
add the first batch dimension for the collation result.
|
154 |
+
Args:
|
155 |
+
frame_number: The number of the frame within its sequence.
|
156 |
+
0-based continuous integers.
|
157 |
+
sequence_name: The unique name of the frame's sequence.
|
158 |
+
sequence_category: The object category of the sequence.
|
159 |
+
frame_timestamp: The time elapsed since the start of a sequence in sec.
|
160 |
+
image_size_hw: The size of the image in pixels; (height, width) tensor
|
161 |
+
of shape (2,).
|
162 |
+
image_path: The qualified path to the loaded image (with dataset_root).
|
163 |
+
image_rgb: A Tensor of shape `(3, H, W)` holding the RGB image
|
164 |
+
of the frame; elements are floats in [0, 1].
|
165 |
+
mask_crop: A binary mask of shape `(1, H, W)` denoting the valid image
|
166 |
+
regions. Regions can be invalid (mask_crop[i,j]=0) in case they
|
167 |
+
are a result of zero-padding of the image after cropping around
|
168 |
+
the object bounding box; elements are floats in {0.0, 1.0}.
|
169 |
+
depth_path: The qualified path to the frame's depth map.
|
170 |
+
depth_map: A float Tensor of shape `(1, H, W)` holding the depth map
|
171 |
+
of the frame; values correspond to distances from the camera;
|
172 |
+
use `depth_mask` and `mask_crop` to filter for valid pixels.
|
173 |
+
depth_mask: A binary mask of shape `(1, H, W)` denoting pixels of the
|
174 |
+
depth map that are valid for evaluation, they have been checked for
|
175 |
+
consistency across views; elements are floats in {0.0, 1.0}.
|
176 |
+
mask_path: A qualified path to the foreground probability mask.
|
177 |
+
fg_probability: A Tensor of `(1, H, W)` denoting the probability of the
|
178 |
+
pixels belonging to the captured object; elements are floats
|
179 |
+
in [0, 1].
|
180 |
+
bbox_xywh: The bounding box tightly enclosing the foreground object in the
|
181 |
+
format (x0, y0, width, height). The convention assumes that
|
182 |
+
`x0+width` and `y0+height` includes the boundary of the box.
|
183 |
+
I.e., to slice out the corresponding crop from an image tensor `I`
|
184 |
+
we execute `crop = I[..., y0:y0+height, x0:x0+width]`
|
185 |
+
crop_bbox_xywh: The bounding box denoting the boundaries of `image_rgb`
|
186 |
+
in the original image coordinates in the format (x0, y0, width, height).
|
187 |
+
The convention is the same as for `bbox_xywh`. `crop_bbox_xywh` differs
|
188 |
+
from `bbox_xywh` due to padding (which can happen e.g. due to
|
189 |
+
setting `JsonIndexDataset.box_crop_context > 0`)
|
190 |
+
camera: A PyTorch3D camera object corresponding the frame's viewpoint,
|
191 |
+
corrected for cropping if it happened.
|
192 |
+
camera_quality_score: The score proportional to the confidence of the
|
193 |
+
frame's camera estimation (the higher the more accurate).
|
194 |
+
point_cloud_quality_score: The score proportional to the accuracy of the
|
195 |
+
frame's sequence point cloud (the higher the more accurate).
|
196 |
+
sequence_point_cloud_path: The path to the sequence's point cloud.
|
197 |
+
sequence_point_cloud: A PyTorch3D Pointclouds object holding the
|
198 |
+
point cloud corresponding to the frame's sequence. When the object
|
199 |
+
represents a batch of frames, point clouds may be deduplicated;
|
200 |
+
see `sequence_point_cloud_idx`.
|
201 |
+
sequence_point_cloud_idx: Integer indices mapping frame indices to the
|
202 |
+
corresponding point clouds in `sequence_point_cloud`; to get the
|
203 |
+
corresponding point cloud to `image_rgb[i]`, use
|
204 |
+
`sequence_point_cloud[sequence_point_cloud_idx[i]]`.
|
205 |
+
frame_type: The type of the loaded frame specified in
|
206 |
+
`subset_lists_file`, if provided.
|
207 |
+
meta: A dict for storing additional frame information.
|
208 |
+
"""
|
209 |
+
|
210 |
+
frame_number: Optional[torch.LongTensor]
|
211 |
+
sequence_name: Union[str, List[str]]
|
212 |
+
sequence_category: Union[str, List[str]]
|
213 |
+
frame_timestamp: Optional[torch.Tensor] = None
|
214 |
+
image_size_hw: Optional[torch.Tensor] = None
|
215 |
+
image_path: Union[str, List[str], None] = None
|
216 |
+
image_rgb: Optional[torch.Tensor] = None
|
217 |
+
# masks out padding added due to cropping the square bit
|
218 |
+
mask_crop: Optional[torch.Tensor] = None
|
219 |
+
depth_path: Union[str, List[str], None] = ""
|
220 |
+
depth_map: Optional[torch.Tensor] = torch.zeros(1)
|
221 |
+
depth_mask: Optional[torch.Tensor] = torch.zeros(1)
|
222 |
+
mask_path: Union[str, List[str], None] = None
|
223 |
+
fg_probability: Optional[torch.Tensor] = None
|
224 |
+
bbox_xywh: Optional[torch.Tensor] = None
|
225 |
+
crop_bbox_xywh: Optional[torch.Tensor] = None
|
226 |
+
camera: Optional[PerspectiveCameras] = None
|
227 |
+
camera_quality_score: Optional[torch.Tensor] = None
|
228 |
+
point_cloud_quality_score: Optional[torch.Tensor] = None
|
229 |
+
sequence_point_cloud_path: Union[str, List[str], None] = ""
|
230 |
+
sequence_point_cloud: Optional[Pointclouds] = torch.zeros(1)
|
231 |
+
sequence_point_cloud_idx: Optional[torch.Tensor] = torch.zeros(1)
|
232 |
+
frame_type: Union[str, List[str], None] = "" # known | unseen
|
233 |
+
meta: dict = field(default_factory=lambda: {})
|
234 |
+
valid_region: Optional[torch.Tensor] = None
|
235 |
+
category_one_hot: Optional[torch.Tensor] = None
|
236 |
+
|
237 |
+
def to(self, *args, **kwargs):
|
238 |
+
new_params = {}
|
239 |
+
for f in fields(self):
|
240 |
+
value = getattr(self, f.name)
|
241 |
+
if isinstance(value, (torch.Tensor, Pointclouds, CamerasBase)):
|
242 |
+
new_params[f.name] = value.to(*args, **kwargs)
|
243 |
+
else:
|
244 |
+
new_params[f.name] = value
|
245 |
+
return type(self)(**new_params)
|
246 |
+
|
247 |
+
def cpu(self):
|
248 |
+
return self.to(device=torch.device("cpu"))
|
249 |
+
|
250 |
+
def cuda(self):
|
251 |
+
return self.to(device=torch.device("cuda"))
|
252 |
+
|
253 |
+
# the following functions make sure **frame_data can be passed to functions
|
254 |
+
def __iter__(self):
|
255 |
+
for f in fields(self):
|
256 |
+
yield f.name
|
257 |
+
|
258 |
+
def __getitem__(self, key):
|
259 |
+
return getattr(self, key)
|
260 |
+
|
261 |
+
def __len__(self):
|
262 |
+
return len(fields(self))
|
263 |
+
|
264 |
+
@classmethod
|
265 |
+
def collate(cls, batch):
|
266 |
+
"""
|
267 |
+
Given a list objects `batch` of class `cls`, collates them into a batched
|
268 |
+
representation suitable for processing with deep networks.
|
269 |
+
"""
|
270 |
+
|
271 |
+
elem = batch[0]
|
272 |
+
|
273 |
+
if isinstance(elem, cls):
|
274 |
+
pointcloud_ids = [id(el.sequence_point_cloud) for el in batch]
|
275 |
+
id_to_idx = defaultdict(list)
|
276 |
+
for i, pc_id in enumerate(pointcloud_ids):
|
277 |
+
id_to_idx[pc_id].append(i)
|
278 |
+
|
279 |
+
sequence_point_cloud = []
|
280 |
+
sequence_point_cloud_idx = -np.ones((len(batch),))
|
281 |
+
for i, ind in enumerate(id_to_idx.values()):
|
282 |
+
sequence_point_cloud_idx[ind] = i
|
283 |
+
sequence_point_cloud.append(batch[ind[0]].sequence_point_cloud)
|
284 |
+
assert (sequence_point_cloud_idx >= 0).all()
|
285 |
+
|
286 |
+
override_fields = {
|
287 |
+
"sequence_point_cloud": sequence_point_cloud,
|
288 |
+
"sequence_point_cloud_idx": sequence_point_cloud_idx.tolist(),
|
289 |
+
}
|
290 |
+
# note that the pre-collate value of sequence_point_cloud_idx is unused
|
291 |
+
|
292 |
+
collated = {}
|
293 |
+
for f in fields(elem):
|
294 |
+
list_values = override_fields.get(
|
295 |
+
f.name, [getattr(d, f.name) for d in batch]
|
296 |
+
)
|
297 |
+
collated[f.name] = (
|
298 |
+
cls.collate(list_values)
|
299 |
+
if all(list_value is not None for list_value in list_values)
|
300 |
+
else None
|
301 |
+
)
|
302 |
+
return cls(**collated)
|
303 |
+
|
304 |
+
elif isinstance(elem, Pointclouds):
|
305 |
+
return join_pointclouds_as_batch(batch)
|
306 |
+
|
307 |
+
elif isinstance(elem, CamerasBase):
|
308 |
+
# TODO: don't store K; enforce working in NDC space
|
309 |
+
return join_cameras_as_batch(batch)
|
310 |
+
else:
|
311 |
+
return torch.utils.data._utils.collate.default_collate(batch)
|
312 |
+
|
313 |
+
|
314 |
+
# @ MODIFIED FROM https://github.com/facebookresearch/pytorch3d
|
315 |
+
class CO3Dv2Wrapper(torch.utils.data.Dataset):
|
316 |
+
def __init__(
|
317 |
+
self,
|
318 |
+
root_dir="/drive/datasets/co3d/",
|
319 |
+
category="hydrant",
|
320 |
+
subset="fewview_train",
|
321 |
+
stage="train",
|
322 |
+
sample_batch_size=20,
|
323 |
+
image_size=256,
|
324 |
+
masked=False,
|
325 |
+
deprecated_val_region=False,
|
326 |
+
return_frame_data_list=False,
|
327 |
+
reso: int = 256,
|
328 |
+
mask_type: str = "random",
|
329 |
+
cond_aug_mean=-3.0,
|
330 |
+
cond_aug_std=0.5,
|
331 |
+
condition_on_elevation=False,
|
332 |
+
fps_id=0.0,
|
333 |
+
motion_bucket_id=300.0,
|
334 |
+
num_frames: int = 20,
|
335 |
+
use_mask: bool = True,
|
336 |
+
load_pixelnerf: bool = True,
|
337 |
+
scale_pose: bool = True,
|
338 |
+
max_n_cond: int = 5,
|
339 |
+
min_n_cond: int = 2,
|
340 |
+
cond_on_multi: bool = False,
|
341 |
+
):
|
342 |
+
root = root_dir
|
343 |
+
from typing import List
|
344 |
+
|
345 |
+
from co3d.dataset.data_types import (
|
346 |
+
FrameAnnotation,
|
347 |
+
SequenceAnnotation,
|
348 |
+
load_dataclass_jgzip,
|
349 |
+
)
|
350 |
+
|
351 |
+
self.dataset_root = root
|
352 |
+
self.path_manager = None
|
353 |
+
self.subset = subset
|
354 |
+
self.stage = stage
|
355 |
+
self.subset_lists_file: List[str] = [
|
356 |
+
f"{self.dataset_root}/{category}/set_lists/set_lists_{subset}.json"
|
357 |
+
]
|
358 |
+
self.subsets: Optional[List[str]] = [subset]
|
359 |
+
self.sample_batch_size = sample_batch_size
|
360 |
+
self.limit_to: int = 0
|
361 |
+
self.limit_sequences_to: int = 0
|
362 |
+
self.pick_sequence: Tuple[str, ...] = ()
|
363 |
+
self.exclude_sequence: Tuple[str, ...] = ()
|
364 |
+
self.limit_category_to: Tuple[int, ...] = ()
|
365 |
+
self.load_images: bool = True
|
366 |
+
self.load_depths: bool = False
|
367 |
+
self.load_depth_masks: bool = False
|
368 |
+
self.load_masks: bool = True
|
369 |
+
self.load_point_clouds: bool = False
|
370 |
+
self.max_points: int = 0
|
371 |
+
self.mask_images: bool = False
|
372 |
+
self.mask_depths: bool = False
|
373 |
+
self.image_height: Optional[int] = image_size
|
374 |
+
self.image_width: Optional[int] = image_size
|
375 |
+
self.box_crop: bool = True
|
376 |
+
self.box_crop_mask_thr: float = 0.4
|
377 |
+
self.box_crop_context: float = 0.3
|
378 |
+
self.remove_empty_masks: bool = True
|
379 |
+
self.n_frames_per_sequence: int = -1
|
380 |
+
self.seed: int = 0
|
381 |
+
self.sort_frames: bool = False
|
382 |
+
self.eval_batches: Any = None
|
383 |
+
|
384 |
+
self.img_h = self.image_height
|
385 |
+
self.img_w = self.image_width
|
386 |
+
self.masked = masked
|
387 |
+
self.deprecated_val_region = deprecated_val_region
|
388 |
+
self.return_frame_data_list = return_frame_data_list
|
389 |
+
|
390 |
+
self.reso = reso
|
391 |
+
self.num_frames = num_frames
|
392 |
+
self.cond_aug_mean = cond_aug_mean
|
393 |
+
self.cond_aug_std = cond_aug_std
|
394 |
+
self.condition_on_elevation = condition_on_elevation
|
395 |
+
self.fps_id = fps_id
|
396 |
+
self.motion_bucket_id = motion_bucket_id
|
397 |
+
self.mask_type = mask_type
|
398 |
+
self.use_mask = use_mask
|
399 |
+
self.load_pixelnerf = load_pixelnerf
|
400 |
+
self.scale_pose = scale_pose
|
401 |
+
self.max_n_cond = max_n_cond
|
402 |
+
self.min_n_cond = min_n_cond
|
403 |
+
self.cond_on_multi = cond_on_multi
|
404 |
+
|
405 |
+
if self.cond_on_multi:
|
406 |
+
assert self.min_n_cond == self.max_n_cond
|
407 |
+
|
408 |
+
start_time = time.time()
|
409 |
+
if "all_" in category or category == "all":
|
410 |
+
self.category_frame_annotations = []
|
411 |
+
self.category_sequence_annotations = []
|
412 |
+
self.subset_lists_file = []
|
413 |
+
|
414 |
+
if category == "all":
|
415 |
+
cats = CO3D_ALL_CATEGORIES
|
416 |
+
elif category == "all_four":
|
417 |
+
cats = ["hydrant", "teddybear", "motorcycle", "bench"]
|
418 |
+
elif category == "all_ten":
|
419 |
+
cats = [
|
420 |
+
"donut",
|
421 |
+
"apple",
|
422 |
+
"hydrant",
|
423 |
+
"vase",
|
424 |
+
"cake",
|
425 |
+
"ball",
|
426 |
+
"bench",
|
427 |
+
"suitcase",
|
428 |
+
"teddybear",
|
429 |
+
"plant",
|
430 |
+
]
|
431 |
+
elif category == "all_15":
|
432 |
+
cats = [
|
433 |
+
"hydrant",
|
434 |
+
"teddybear",
|
435 |
+
"motorcycle",
|
436 |
+
"bench",
|
437 |
+
"hotdog",
|
438 |
+
"remote",
|
439 |
+
"suitcase",
|
440 |
+
"donut",
|
441 |
+
"plant",
|
442 |
+
"toaster",
|
443 |
+
"keyboard",
|
444 |
+
"handbag",
|
445 |
+
"toyplane",
|
446 |
+
"tv",
|
447 |
+
"orange",
|
448 |
+
]
|
449 |
+
else:
|
450 |
+
print("UNSPECIFIED CATEGORY SUBSET")
|
451 |
+
cats = ["hydrant", "teddybear"]
|
452 |
+
print("loading", cats)
|
453 |
+
for cat in cats:
|
454 |
+
self.category_frame_annotations.extend(
|
455 |
+
load_dataclass_jgzip(
|
456 |
+
f"{self.dataset_root}/{cat}/frame_annotations.jgz",
|
457 |
+
List[FrameAnnotation],
|
458 |
+
)
|
459 |
+
)
|
460 |
+
self.category_sequence_annotations.extend(
|
461 |
+
load_dataclass_jgzip(
|
462 |
+
f"{self.dataset_root}/{cat}/sequence_annotations.jgz",
|
463 |
+
List[SequenceAnnotation],
|
464 |
+
)
|
465 |
+
)
|
466 |
+
self.subset_lists_file.append(
|
467 |
+
f"{self.dataset_root}/{cat}/set_lists/set_lists_{subset}.json"
|
468 |
+
)
|
469 |
+
|
470 |
+
else:
|
471 |
+
self.category_frame_annotations = load_dataclass_jgzip(
|
472 |
+
f"{self.dataset_root}/{category}/frame_annotations.jgz",
|
473 |
+
List[FrameAnnotation],
|
474 |
+
)
|
475 |
+
self.category_sequence_annotations = load_dataclass_jgzip(
|
476 |
+
f"{self.dataset_root}/{category}/sequence_annotations.jgz",
|
477 |
+
List[SequenceAnnotation],
|
478 |
+
)
|
479 |
+
|
480 |
+
self.subset_to_image_path = None
|
481 |
+
self._load_frames()
|
482 |
+
self._load_sequences()
|
483 |
+
self._sort_frames()
|
484 |
+
self._load_subset_lists()
|
485 |
+
self._filter_db() # also computes sequence indices
|
486 |
+
# self._extract_and_set_eval_batches()
|
487 |
+
# print(self.eval_batches)
|
488 |
+
logger.info(str(self))
|
489 |
+
|
490 |
+
self.seq_to_frames = {}
|
491 |
+
for fi, item in enumerate(self.frame_annots):
|
492 |
+
if item["frame_annotation"].sequence_name in self.seq_to_frames:
|
493 |
+
self.seq_to_frames[item["frame_annotation"].sequence_name].append(fi)
|
494 |
+
else:
|
495 |
+
self.seq_to_frames[item["frame_annotation"].sequence_name] = [fi]
|
496 |
+
|
497 |
+
if self.stage != "test" or self.subset != "fewview_test":
|
498 |
+
count = 0
|
499 |
+
new_seq_to_frames = {}
|
500 |
+
for item in self.seq_to_frames:
|
501 |
+
if len(self.seq_to_frames[item]) > 10:
|
502 |
+
count += 1
|
503 |
+
new_seq_to_frames[item] = self.seq_to_frames[item]
|
504 |
+
self.seq_to_frames = new_seq_to_frames
|
505 |
+
|
506 |
+
self.seq_list = list(self.seq_to_frames.keys())
|
507 |
+
|
508 |
+
# @ REMOVE A FEW TRAINING SEQ THAT CAUSES BUG
|
509 |
+
remove_list = ["411_55952_107659", "376_42884_85882"]
|
510 |
+
for remove_idx in remove_list:
|
511 |
+
if remove_idx in self.seq_to_frames:
|
512 |
+
self.seq_list.remove(remove_idx)
|
513 |
+
print("removing", remove_idx)
|
514 |
+
|
515 |
+
print("total training seq", len(self.seq_to_frames))
|
516 |
+
print("data loading took", time.time() - start_time, "seconds")
|
517 |
+
|
518 |
+
self.all_category_list = list(CO3D_ALL_CATEGORIES)
|
519 |
+
self.all_category_list.sort()
|
520 |
+
self.cat_to_idx = {}
|
521 |
+
for ci, cname in enumerate(self.all_category_list):
|
522 |
+
self.cat_to_idx[cname] = ci
|
523 |
+
|
524 |
+
def __len__(self):
|
525 |
+
return len(self.seq_list)
|
526 |
+
|
527 |
+
def __getitem__(self, index):
|
528 |
+
seq_index = self.seq_list[index]
|
529 |
+
|
530 |
+
if self.subset == "fewview_test" and self.stage == "test":
|
531 |
+
batch_idx = torch.arange(len(self.seq_to_frames[seq_index]))
|
532 |
+
|
533 |
+
elif self.stage == "test":
|
534 |
+
batch_idx = (
|
535 |
+
torch.linspace(
|
536 |
+
0, len(self.seq_to_frames[seq_index]) - 1, self.sample_batch_size
|
537 |
+
)
|
538 |
+
.long()
|
539 |
+
.tolist()
|
540 |
+
)
|
541 |
+
else:
|
542 |
+
rand = torch.randperm(len(self.seq_to_frames[seq_index]))
|
543 |
+
batch_idx = rand[: min(len(rand), self.sample_batch_size)]
|
544 |
+
|
545 |
+
frame_data_list = []
|
546 |
+
idx_list = []
|
547 |
+
timestamp_list = []
|
548 |
+
for idx in batch_idx:
|
549 |
+
idx_list.append(self.seq_to_frames[seq_index][idx])
|
550 |
+
timestamp_list.append(
|
551 |
+
self.frame_annots[self.seq_to_frames[seq_index][idx]][
|
552 |
+
"frame_annotation"
|
553 |
+
].frame_timestamp
|
554 |
+
)
|
555 |
+
frame_data_list.append(
|
556 |
+
self._get_frame(int(self.seq_to_frames[seq_index][idx]))
|
557 |
+
)
|
558 |
+
|
559 |
+
time_order = torch.argsort(torch.tensor(timestamp_list))
|
560 |
+
frame_data_list = [frame_data_list[i] for i in time_order]
|
561 |
+
|
562 |
+
frame_data = FrameData.collate(frame_data_list)
|
563 |
+
image_size = torch.Tensor([self.image_height]).repeat(
|
564 |
+
frame_data.camera.R.shape[0], 2
|
565 |
+
)
|
566 |
+
frame_dict = {
|
567 |
+
"R": frame_data.camera.R,
|
568 |
+
"T": frame_data.camera.T,
|
569 |
+
"f": frame_data.camera.focal_length,
|
570 |
+
"c": frame_data.camera.principal_point,
|
571 |
+
"images": frame_data.image_rgb * frame_data.fg_probability
|
572 |
+
+ (1 - frame_data.fg_probability),
|
573 |
+
"valid_region": frame_data.mask_crop,
|
574 |
+
"bbox": frame_data.valid_region,
|
575 |
+
"image_size": image_size,
|
576 |
+
"frame_type": frame_data.frame_type,
|
577 |
+
"idx": seq_index,
|
578 |
+
"category": frame_data.category_one_hot,
|
579 |
+
}
|
580 |
+
if not self.masked:
|
581 |
+
frame_dict["images_full"] = frame_data.image_rgb
|
582 |
+
frame_dict["masks"] = frame_data.fg_probability
|
583 |
+
frame_dict["mask_crop"] = frame_data.mask_crop
|
584 |
+
|
585 |
+
cond_aug = np.exp(
|
586 |
+
np.random.randn(1)[0] * self.cond_aug_std + self.cond_aug_mean
|
587 |
+
)
|
588 |
+
|
589 |
+
def _pad(input):
|
590 |
+
return torch.cat([input, torch.flip(input, dims=[0])], dim=0)[
|
591 |
+
: self.num_frames
|
592 |
+
]
|
593 |
+
|
594 |
+
if len(frame_dict["images"]) < self.num_frames:
|
595 |
+
for k in frame_dict:
|
596 |
+
if isinstance(frame_dict[k], torch.Tensor):
|
597 |
+
frame_dict[k] = _pad(frame_dict[k])
|
598 |
+
|
599 |
+
data = dict()
|
600 |
+
if "images_full" in frame_dict:
|
601 |
+
frames = frame_dict["images_full"] * 2 - 1
|
602 |
+
else:
|
603 |
+
frames = frame_dict["images"] * 2 - 1
|
604 |
+
data["frames"] = frames
|
605 |
+
cond = frames[0]
|
606 |
+
data["cond_frames_without_noise"] = cond
|
607 |
+
data["cond_aug"] = torch.as_tensor([cond_aug] * self.num_frames)
|
608 |
+
data["cond_frames"] = cond + cond_aug * torch.randn_like(cond)
|
609 |
+
data["fps_id"] = torch.as_tensor([self.fps_id] * self.num_frames)
|
610 |
+
data["motion_bucket_id"] = torch.as_tensor(
|
611 |
+
[self.motion_bucket_id] * self.num_frames
|
612 |
+
)
|
613 |
+
data["num_video_frames"] = self.num_frames
|
614 |
+
data["image_only_indicator"] = torch.as_tensor([0.0] * self.num_frames)
|
615 |
+
|
616 |
+
if self.load_pixelnerf:
|
617 |
+
data["pixelnerf_input"] = dict()
|
618 |
+
# Rs = frame_dict["R"].transpose(-1, -2)
|
619 |
+
# Ts = frame_dict["T"]
|
620 |
+
# Rs[:, :, 2] *= -1
|
621 |
+
# Rs[:, :, 0] *= -1
|
622 |
+
# Ts[:, 2] *= -1
|
623 |
+
# Ts[:, 0] *= -1
|
624 |
+
# c2ws = torch.zeros(Rs.shape[0], 4, 4)
|
625 |
+
# c2ws[:, :3, :3] = Rs
|
626 |
+
# c2ws[:, :3, 3] = Ts
|
627 |
+
# c2ws[:, 3, 3] = 1
|
628 |
+
# c2ws = c2ws.inverse()
|
629 |
+
# # c2ws[..., 0] *= -1
|
630 |
+
# # c2ws[..., 2] *= -1
|
631 |
+
# cx = frame_dict["c"][:, 0]
|
632 |
+
# cy = frame_dict["c"][:, 1]
|
633 |
+
# fx = frame_dict["f"][:, 0]
|
634 |
+
# fy = frame_dict["f"][:, 1]
|
635 |
+
# intrinsics = torch.zeros(cx.shape[0], 3, 3)
|
636 |
+
# intrinsics[:, 2, 2] = 1
|
637 |
+
# intrinsics[:, 0, 0] = fx
|
638 |
+
# intrinsics[:, 1, 1] = fy
|
639 |
+
# intrinsics[:, 0, 2] = cx
|
640 |
+
# intrinsics[:, 1, 2] = cy
|
641 |
+
|
642 |
+
scene_cameras = PerspectiveCameras(
|
643 |
+
R=frame_dict["R"],
|
644 |
+
T=frame_dict["T"],
|
645 |
+
focal_length=frame_dict["f"],
|
646 |
+
principal_point=frame_dict["c"],
|
647 |
+
image_size=frame_dict["image_size"],
|
648 |
+
)
|
649 |
+
R, T, intrinsics = opencv_from_cameras_projection(
|
650 |
+
scene_cameras, frame_dict["image_size"]
|
651 |
+
)
|
652 |
+
c2ws = torch.zeros(R.shape[0], 4, 4)
|
653 |
+
c2ws[:, :3, :3] = R
|
654 |
+
c2ws[:, :3, 3] = T
|
655 |
+
c2ws[:, 3, 3] = 1.0
|
656 |
+
c2ws = c2ws.inverse()
|
657 |
+
c2ws[..., 1:3] *= -1
|
658 |
+
intrinsics[:, :2] /= 256
|
659 |
+
|
660 |
+
cameras = torch.zeros(c2ws.shape[0], 25)
|
661 |
+
cameras[..., :16] = c2ws.reshape(-1, 16)
|
662 |
+
cameras[..., 16:] = intrinsics.reshape(-1, 9)
|
663 |
+
if self.scale_pose:
|
664 |
+
c2ws = cameras[..., :16].reshape(-1, 4, 4)
|
665 |
+
center = c2ws[:, :3, 3].mean(0)
|
666 |
+
radius = (c2ws[:, :3, 3] - center).norm(dim=-1).max()
|
667 |
+
scale = 1.5 / radius
|
668 |
+
c2ws[..., :3, 3] = (c2ws[..., :3, 3] - center) * scale
|
669 |
+
cameras[..., :16] = c2ws.reshape(-1, 16)
|
670 |
+
|
671 |
+
data["pixelnerf_input"]["frames"] = frames
|
672 |
+
data["pixelnerf_input"]["cameras"] = cameras
|
673 |
+
data["pixelnerf_input"]["rgb"] = (
|
674 |
+
F.interpolate(
|
675 |
+
frames,
|
676 |
+
(self.image_width // 8, self.image_height // 8),
|
677 |
+
mode="bilinear",
|
678 |
+
align_corners=False,
|
679 |
+
)
|
680 |
+
+ 1
|
681 |
+
) * 0.5
|
682 |
+
|
683 |
+
return data
|
684 |
+
# if self.return_frame_data_list:
|
685 |
+
# return (frame_dict, frame_data_list)
|
686 |
+
# return frame_dict
|
687 |
+
|
688 |
+
def collate_fn(self, batch):
|
689 |
+
# a hack to add source index and keep consistent within a batch
|
690 |
+
if self.max_n_cond > 1:
|
691 |
+
# TODO implement this
|
692 |
+
n_cond = np.random.randint(self.min_n_cond, self.max_n_cond + 1)
|
693 |
+
# debug
|
694 |
+
# source_index = [0]
|
695 |
+
if n_cond > 1:
|
696 |
+
for b in batch:
|
697 |
+
source_index = [0] + np.random.choice(
|
698 |
+
np.arange(1, self.num_frames),
|
699 |
+
self.max_n_cond - 1,
|
700 |
+
replace=False,
|
701 |
+
).tolist()
|
702 |
+
b["pixelnerf_input"]["source_index"] = torch.as_tensor(source_index)
|
703 |
+
b["pixelnerf_input"]["n_cond"] = n_cond
|
704 |
+
b["pixelnerf_input"]["source_images"] = b["frames"][source_index]
|
705 |
+
b["pixelnerf_input"]["source_cameras"] = b["pixelnerf_input"][
|
706 |
+
"cameras"
|
707 |
+
][source_index]
|
708 |
+
|
709 |
+
if self.cond_on_multi:
|
710 |
+
b["cond_frames_without_noise"] = b["frames"][source_index]
|
711 |
+
|
712 |
+
ret = video_collate_fn(batch)
|
713 |
+
|
714 |
+
if self.cond_on_multi:
|
715 |
+
ret["cond_frames_without_noise"] = rearrange(
|
716 |
+
ret["cond_frames_without_noise"], "b t ... -> (b t) ..."
|
717 |
+
)
|
718 |
+
|
719 |
+
return ret
|
720 |
+
|
721 |
+
def _get_frame(self, index):
|
722 |
+
# if index >= len(self.frame_annots):
|
723 |
+
# raise IndexError(f"index {index} out of range {len(self.frame_annots)}")
|
724 |
+
|
725 |
+
entry = self.frame_annots[index]["frame_annotation"]
|
726 |
+
# pyre-ignore[16]
|
727 |
+
point_cloud = self.seq_annots[entry.sequence_name].point_cloud
|
728 |
+
frame_data = FrameData(
|
729 |
+
frame_number=_safe_as_tensor(entry.frame_number, torch.long),
|
730 |
+
frame_timestamp=_safe_as_tensor(entry.frame_timestamp, torch.float),
|
731 |
+
sequence_name=entry.sequence_name,
|
732 |
+
sequence_category=self.seq_annots[entry.sequence_name].category,
|
733 |
+
camera_quality_score=_safe_as_tensor(
|
734 |
+
self.seq_annots[entry.sequence_name].viewpoint_quality_score,
|
735 |
+
torch.float,
|
736 |
+
),
|
737 |
+
point_cloud_quality_score=_safe_as_tensor(
|
738 |
+
point_cloud.quality_score, torch.float
|
739 |
+
)
|
740 |
+
if point_cloud is not None
|
741 |
+
else None,
|
742 |
+
)
|
743 |
+
|
744 |
+
# The rest of the fields are optional
|
745 |
+
frame_data.frame_type = self._get_frame_type(self.frame_annots[index])
|
746 |
+
|
747 |
+
(
|
748 |
+
frame_data.fg_probability,
|
749 |
+
frame_data.mask_path,
|
750 |
+
frame_data.bbox_xywh,
|
751 |
+
clamp_bbox_xyxy,
|
752 |
+
frame_data.crop_bbox_xywh,
|
753 |
+
) = self._load_crop_fg_probability(entry)
|
754 |
+
|
755 |
+
scale = 1.0
|
756 |
+
if self.load_images and entry.image is not None:
|
757 |
+
# original image size
|
758 |
+
frame_data.image_size_hw = _safe_as_tensor(entry.image.size, torch.long)
|
759 |
+
|
760 |
+
(
|
761 |
+
frame_data.image_rgb,
|
762 |
+
frame_data.image_path,
|
763 |
+
frame_data.mask_crop,
|
764 |
+
scale,
|
765 |
+
) = self._load_crop_images(
|
766 |
+
entry, frame_data.fg_probability, clamp_bbox_xyxy
|
767 |
+
)
|
768 |
+
# print(frame_data.fg_probability.sum())
|
769 |
+
# print('scale', scale)
|
770 |
+
|
771 |
+
#! INSERT
|
772 |
+
if self.deprecated_val_region:
|
773 |
+
# print(frame_data.crop_bbox_xywh)
|
774 |
+
valid_bbox = _bbox_xywh_to_xyxy(frame_data.crop_bbox_xywh).float()
|
775 |
+
# print(valid_bbox, frame_data.image_size_hw)
|
776 |
+
valid_bbox[0] = torch.clip(
|
777 |
+
(
|
778 |
+
valid_bbox[0]
|
779 |
+
- torch.div(frame_data.image_size_hw[1], 2, rounding_mode="floor")
|
780 |
+
)
|
781 |
+
/ torch.div(frame_data.image_size_hw[1], 2, rounding_mode="floor"),
|
782 |
+
-1.0,
|
783 |
+
1.0,
|
784 |
+
)
|
785 |
+
valid_bbox[1] = torch.clip(
|
786 |
+
(
|
787 |
+
valid_bbox[1]
|
788 |
+
- torch.div(frame_data.image_size_hw[0], 2, rounding_mode="floor")
|
789 |
+
)
|
790 |
+
/ torch.div(frame_data.image_size_hw[0], 2, rounding_mode="floor"),
|
791 |
+
-1.0,
|
792 |
+
1.0,
|
793 |
+
)
|
794 |
+
valid_bbox[2] = torch.clip(
|
795 |
+
(
|
796 |
+
valid_bbox[2]
|
797 |
+
- torch.div(frame_data.image_size_hw[1], 2, rounding_mode="floor")
|
798 |
+
)
|
799 |
+
/ torch.div(frame_data.image_size_hw[1], 2, rounding_mode="floor"),
|
800 |
+
-1.0,
|
801 |
+
1.0,
|
802 |
+
)
|
803 |
+
valid_bbox[3] = torch.clip(
|
804 |
+
(
|
805 |
+
valid_bbox[3]
|
806 |
+
- torch.div(frame_data.image_size_hw[0], 2, rounding_mode="floor")
|
807 |
+
)
|
808 |
+
/ torch.div(frame_data.image_size_hw[0], 2, rounding_mode="floor"),
|
809 |
+
-1.0,
|
810 |
+
1.0,
|
811 |
+
)
|
812 |
+
# print(valid_bbox)
|
813 |
+
frame_data.valid_region = valid_bbox
|
814 |
+
else:
|
815 |
+
#! UPDATED VALID BBOX
|
816 |
+
if self.stage == "train":
|
817 |
+
assert self.image_height == 256 and self.image_width == 256
|
818 |
+
valid = torch.nonzero(frame_data.mask_crop[0])
|
819 |
+
min_y = valid[:, 0].min()
|
820 |
+
min_x = valid[:, 1].min()
|
821 |
+
max_y = valid[:, 0].max()
|
822 |
+
max_x = valid[:, 1].max()
|
823 |
+
valid_bbox = torch.tensor(
|
824 |
+
[min_y, min_x, max_y, max_x], device=frame_data.image_rgb.device
|
825 |
+
).unsqueeze(0)
|
826 |
+
valid_bbox = torch.clip(
|
827 |
+
(valid_bbox - (256 // 2)) / (256 // 2), -1.0, 1.0
|
828 |
+
)
|
829 |
+
frame_data.valid_region = valid_bbox[0]
|
830 |
+
else:
|
831 |
+
valid = torch.nonzero(frame_data.mask_crop[0])
|
832 |
+
min_y = valid[:, 0].min()
|
833 |
+
min_x = valid[:, 1].min()
|
834 |
+
max_y = valid[:, 0].max()
|
835 |
+
max_x = valid[:, 1].max()
|
836 |
+
valid_bbox = torch.tensor(
|
837 |
+
[min_y, min_x, max_y, max_x], device=frame_data.image_rgb.device
|
838 |
+
).unsqueeze(0)
|
839 |
+
valid_bbox = torch.clip(
|
840 |
+
(valid_bbox - (self.image_height // 2)) / (self.image_height // 2),
|
841 |
+
-1.0,
|
842 |
+
1.0,
|
843 |
+
)
|
844 |
+
frame_data.valid_region = valid_bbox[0]
|
845 |
+
|
846 |
+
#! SET CLASS ONEHOT
|
847 |
+
frame_data.category_one_hot = torch.zeros(
|
848 |
+
(len(self.all_category_list)), device=frame_data.image_rgb.device
|
849 |
+
)
|
850 |
+
frame_data.category_one_hot[self.cat_to_idx[frame_data.sequence_category]] = 1
|
851 |
+
|
852 |
+
if self.load_depths and entry.depth is not None:
|
853 |
+
(
|
854 |
+
frame_data.depth_map,
|
855 |
+
frame_data.depth_path,
|
856 |
+
frame_data.depth_mask,
|
857 |
+
) = self._load_mask_depth(entry, clamp_bbox_xyxy, frame_data.fg_probability)
|
858 |
+
|
859 |
+
if entry.viewpoint is not None:
|
860 |
+
frame_data.camera = self._get_pytorch3d_camera(
|
861 |
+
entry,
|
862 |
+
scale,
|
863 |
+
clamp_bbox_xyxy,
|
864 |
+
)
|
865 |
+
|
866 |
+
if self.load_point_clouds and point_cloud is not None:
|
867 |
+
frame_data.sequence_point_cloud_path = pcl_path = os.path.join(
|
868 |
+
self.dataset_root, point_cloud.path
|
869 |
+
)
|
870 |
+
frame_data.sequence_point_cloud = _load_pointcloud(
|
871 |
+
self._local_path(pcl_path), max_points=self.max_points
|
872 |
+
)
|
873 |
+
|
874 |
+
# for key in frame_data:
|
875 |
+
# if frame_data[key] == None:
|
876 |
+
# print(key)
|
877 |
+
return frame_data
|
878 |
+
|
879 |
+
def _extract_and_set_eval_batches(self):
|
880 |
+
"""
|
881 |
+
Sets eval_batches based on input eval_batch_index.
|
882 |
+
"""
|
883 |
+
if self.eval_batch_index is not None:
|
884 |
+
if self.eval_batches is not None:
|
885 |
+
raise ValueError(
|
886 |
+
"Cannot define both eval_batch_index and eval_batches."
|
887 |
+
)
|
888 |
+
self.eval_batches = self.seq_frame_index_to_dataset_index(
|
889 |
+
self.eval_batch_index
|
890 |
+
)
|
891 |
+
|
892 |
+
def _load_crop_fg_probability(
|
893 |
+
self, entry: types.FrameAnnotation
|
894 |
+
) -> Tuple[
|
895 |
+
Optional[torch.Tensor],
|
896 |
+
Optional[str],
|
897 |
+
Optional[torch.Tensor],
|
898 |
+
Optional[torch.Tensor],
|
899 |
+
Optional[torch.Tensor],
|
900 |
+
]:
|
901 |
+
fg_probability = None
|
902 |
+
full_path = None
|
903 |
+
bbox_xywh = None
|
904 |
+
clamp_bbox_xyxy = None
|
905 |
+
crop_box_xywh = None
|
906 |
+
|
907 |
+
if (self.load_masks or self.box_crop) and entry.mask is not None:
|
908 |
+
full_path = os.path.join(self.dataset_root, entry.mask.path)
|
909 |
+
mask = _load_mask(self._local_path(full_path))
|
910 |
+
|
911 |
+
if mask.shape[-2:] != entry.image.size:
|
912 |
+
raise ValueError(
|
913 |
+
f"bad mask size: {mask.shape[-2:]} vs {entry.image.size}!"
|
914 |
+
)
|
915 |
+
|
916 |
+
bbox_xywh = torch.tensor(_get_bbox_from_mask(mask, self.box_crop_mask_thr))
|
917 |
+
|
918 |
+
if self.box_crop:
|
919 |
+
clamp_bbox_xyxy = _clamp_box_to_image_bounds_and_round(
|
920 |
+
_get_clamp_bbox(
|
921 |
+
bbox_xywh,
|
922 |
+
image_path=entry.image.path,
|
923 |
+
box_crop_context=self.box_crop_context,
|
924 |
+
),
|
925 |
+
image_size_hw=tuple(mask.shape[-2:]),
|
926 |
+
)
|
927 |
+
crop_box_xywh = _bbox_xyxy_to_xywh(clamp_bbox_xyxy)
|
928 |
+
|
929 |
+
mask = _crop_around_box(mask, clamp_bbox_xyxy, full_path)
|
930 |
+
|
931 |
+
fg_probability, _, _ = self._resize_image(mask, mode="nearest")
|
932 |
+
|
933 |
+
return fg_probability, full_path, bbox_xywh, clamp_bbox_xyxy, crop_box_xywh
|
934 |
+
|
935 |
+
def _load_crop_images(
|
936 |
+
self,
|
937 |
+
entry: types.FrameAnnotation,
|
938 |
+
fg_probability: Optional[torch.Tensor],
|
939 |
+
clamp_bbox_xyxy: Optional[torch.Tensor],
|
940 |
+
) -> Tuple[torch.Tensor, str, torch.Tensor, float]:
|
941 |
+
assert self.dataset_root is not None and entry.image is not None
|
942 |
+
path = os.path.join(self.dataset_root, entry.image.path)
|
943 |
+
image_rgb = _load_image(self._local_path(path))
|
944 |
+
|
945 |
+
if image_rgb.shape[-2:] != entry.image.size:
|
946 |
+
raise ValueError(
|
947 |
+
f"bad image size: {image_rgb.shape[-2:]} vs {entry.image.size}!"
|
948 |
+
)
|
949 |
+
|
950 |
+
if self.box_crop:
|
951 |
+
assert clamp_bbox_xyxy is not None
|
952 |
+
image_rgb = _crop_around_box(image_rgb, clamp_bbox_xyxy, path)
|
953 |
+
|
954 |
+
image_rgb, scale, mask_crop = self._resize_image(image_rgb)
|
955 |
+
|
956 |
+
if self.mask_images:
|
957 |
+
assert fg_probability is not None
|
958 |
+
image_rgb *= fg_probability
|
959 |
+
|
960 |
+
return image_rgb, path, mask_crop, scale
|
961 |
+
|
962 |
+
def _load_mask_depth(
|
963 |
+
self,
|
964 |
+
entry: types.FrameAnnotation,
|
965 |
+
clamp_bbox_xyxy: Optional[torch.Tensor],
|
966 |
+
fg_probability: Optional[torch.Tensor],
|
967 |
+
) -> Tuple[torch.Tensor, str, torch.Tensor]:
|
968 |
+
entry_depth = entry.depth
|
969 |
+
assert entry_depth is not None
|
970 |
+
path = os.path.join(self.dataset_root, entry_depth.path)
|
971 |
+
depth_map = _load_depth(self._local_path(path), entry_depth.scale_adjustment)
|
972 |
+
|
973 |
+
if self.box_crop:
|
974 |
+
assert clamp_bbox_xyxy is not None
|
975 |
+
depth_bbox_xyxy = _rescale_bbox(
|
976 |
+
clamp_bbox_xyxy, entry.image.size, depth_map.shape[-2:]
|
977 |
+
)
|
978 |
+
depth_map = _crop_around_box(depth_map, depth_bbox_xyxy, path)
|
979 |
+
|
980 |
+
depth_map, _, _ = self._resize_image(depth_map, mode="nearest")
|
981 |
+
|
982 |
+
if self.mask_depths:
|
983 |
+
assert fg_probability is not None
|
984 |
+
depth_map *= fg_probability
|
985 |
+
|
986 |
+
if self.load_depth_masks:
|
987 |
+
assert entry_depth.mask_path is not None
|
988 |
+
mask_path = os.path.join(self.dataset_root, entry_depth.mask_path)
|
989 |
+
depth_mask = _load_depth_mask(self._local_path(mask_path))
|
990 |
+
|
991 |
+
if self.box_crop:
|
992 |
+
assert clamp_bbox_xyxy is not None
|
993 |
+
depth_mask_bbox_xyxy = _rescale_bbox(
|
994 |
+
clamp_bbox_xyxy, entry.image.size, depth_mask.shape[-2:]
|
995 |
+
)
|
996 |
+
depth_mask = _crop_around_box(
|
997 |
+
depth_mask, depth_mask_bbox_xyxy, mask_path
|
998 |
+
)
|
999 |
+
|
1000 |
+
depth_mask, _, _ = self._resize_image(depth_mask, mode="nearest")
|
1001 |
+
else:
|
1002 |
+
depth_mask = torch.ones_like(depth_map)
|
1003 |
+
|
1004 |
+
return depth_map, path, depth_mask
|
1005 |
+
|
1006 |
+
def _get_pytorch3d_camera(
|
1007 |
+
self,
|
1008 |
+
entry: types.FrameAnnotation,
|
1009 |
+
scale: float,
|
1010 |
+
clamp_bbox_xyxy: Optional[torch.Tensor],
|
1011 |
+
) -> PerspectiveCameras:
|
1012 |
+
entry_viewpoint = entry.viewpoint
|
1013 |
+
assert entry_viewpoint is not None
|
1014 |
+
# principal point and focal length
|
1015 |
+
principal_point = torch.tensor(
|
1016 |
+
entry_viewpoint.principal_point, dtype=torch.float
|
1017 |
+
)
|
1018 |
+
focal_length = torch.tensor(entry_viewpoint.focal_length, dtype=torch.float)
|
1019 |
+
|
1020 |
+
half_image_size_wh_orig = (
|
1021 |
+
torch.tensor(list(reversed(entry.image.size)), dtype=torch.float) / 2.0
|
1022 |
+
)
|
1023 |
+
|
1024 |
+
# first, we convert from the dataset's NDC convention to pixels
|
1025 |
+
format = entry_viewpoint.intrinsics_format
|
1026 |
+
if format.lower() == "ndc_norm_image_bounds":
|
1027 |
+
# this is e.g. currently used in CO3D for storing intrinsics
|
1028 |
+
rescale = half_image_size_wh_orig
|
1029 |
+
elif format.lower() == "ndc_isotropic":
|
1030 |
+
rescale = half_image_size_wh_orig.min()
|
1031 |
+
else:
|
1032 |
+
raise ValueError(f"Unknown intrinsics format: {format}")
|
1033 |
+
|
1034 |
+
# principal point and focal length in pixels
|
1035 |
+
principal_point_px = half_image_size_wh_orig - principal_point * rescale
|
1036 |
+
focal_length_px = focal_length * rescale
|
1037 |
+
if self.box_crop:
|
1038 |
+
assert clamp_bbox_xyxy is not None
|
1039 |
+
principal_point_px -= clamp_bbox_xyxy[:2]
|
1040 |
+
|
1041 |
+
# now, convert from pixels to PyTorch3D v0.5+ NDC convention
|
1042 |
+
if self.image_height is None or self.image_width is None:
|
1043 |
+
out_size = list(reversed(entry.image.size))
|
1044 |
+
else:
|
1045 |
+
out_size = [self.image_width, self.image_height]
|
1046 |
+
|
1047 |
+
half_image_size_output = torch.tensor(out_size, dtype=torch.float) / 2.0
|
1048 |
+
half_min_image_size_output = half_image_size_output.min()
|
1049 |
+
|
1050 |
+
# rescaled principal point and focal length in ndc
|
1051 |
+
principal_point = (
|
1052 |
+
half_image_size_output - principal_point_px * scale
|
1053 |
+
) / half_min_image_size_output
|
1054 |
+
focal_length = focal_length_px * scale / half_min_image_size_output
|
1055 |
+
|
1056 |
+
return PerspectiveCameras(
|
1057 |
+
focal_length=focal_length[None],
|
1058 |
+
principal_point=principal_point[None],
|
1059 |
+
R=torch.tensor(entry_viewpoint.R, dtype=torch.float)[None],
|
1060 |
+
T=torch.tensor(entry_viewpoint.T, dtype=torch.float)[None],
|
1061 |
+
)
|
1062 |
+
|
1063 |
+
def _load_frames(self) -> None:
|
1064 |
+
self.frame_annots = [
|
1065 |
+
FrameAnnotsEntry(frame_annotation=a, subset=None)
|
1066 |
+
for a in self.category_frame_annotations
|
1067 |
+
]
|
1068 |
+
|
1069 |
+
def _load_sequences(self) -> None:
|
1070 |
+
self.seq_annots = {
|
1071 |
+
entry.sequence_name: entry for entry in self.category_sequence_annotations
|
1072 |
+
}
|
1073 |
+
|
1074 |
+
def _load_subset_lists(self) -> None:
|
1075 |
+
logger.info(f"Loading Co3D subset lists from {self.subset_lists_file}.")
|
1076 |
+
if not self.subset_lists_file:
|
1077 |
+
return
|
1078 |
+
|
1079 |
+
frame_path_to_subset = {}
|
1080 |
+
|
1081 |
+
for subset_list_file in self.subset_lists_file:
|
1082 |
+
with open(self._local_path(subset_list_file), "r") as f:
|
1083 |
+
subset_to_seq_frame = json.load(f)
|
1084 |
+
|
1085 |
+
#! PRINT SUBSET_LIST STATS
|
1086 |
+
# if len(self.subset_lists_file) == 1:
|
1087 |
+
# print('train frames', len(subset_to_seq_frame['train']))
|
1088 |
+
# print('val frames', len(subset_to_seq_frame['val']))
|
1089 |
+
# print('test frames', len(subset_to_seq_frame['test']))
|
1090 |
+
|
1091 |
+
for set_ in subset_to_seq_frame:
|
1092 |
+
for _, _, path in subset_to_seq_frame[set_]:
|
1093 |
+
if path in frame_path_to_subset:
|
1094 |
+
frame_path_to_subset[path].add(set_)
|
1095 |
+
else:
|
1096 |
+
frame_path_to_subset[path] = {set_}
|
1097 |
+
|
1098 |
+
# pyre-ignore[16]
|
1099 |
+
for frame in self.frame_annots:
|
1100 |
+
frame["subset"] = frame_path_to_subset.get(
|
1101 |
+
frame["frame_annotation"].image.path, None
|
1102 |
+
)
|
1103 |
+
|
1104 |
+
if frame["subset"] is None:
|
1105 |
+
continue
|
1106 |
+
warnings.warn(
|
1107 |
+
"Subset lists are given but don't include "
|
1108 |
+
+ frame["frame_annotation"].image.path
|
1109 |
+
)
|
1110 |
+
|
1111 |
+
def _sort_frames(self) -> None:
|
1112 |
+
# Sort frames to have them grouped by sequence, ordered by timestamp
|
1113 |
+
# pyre-ignore[16]
|
1114 |
+
self.frame_annots = sorted(
|
1115 |
+
self.frame_annots,
|
1116 |
+
key=lambda f: (
|
1117 |
+
f["frame_annotation"].sequence_name,
|
1118 |
+
f["frame_annotation"].frame_timestamp or 0,
|
1119 |
+
),
|
1120 |
+
)
|
1121 |
+
|
1122 |
+
def _filter_db(self) -> None:
|
1123 |
+
if self.remove_empty_masks:
|
1124 |
+
logger.info("Removing images with empty masks.")
|
1125 |
+
# pyre-ignore[16]
|
1126 |
+
old_len = len(self.frame_annots)
|
1127 |
+
|
1128 |
+
msg = "remove_empty_masks needs every MaskAnnotation.mass to be set."
|
1129 |
+
|
1130 |
+
def positive_mass(frame_annot: types.FrameAnnotation) -> bool:
|
1131 |
+
mask = frame_annot.mask
|
1132 |
+
if mask is None:
|
1133 |
+
return False
|
1134 |
+
if mask.mass is None:
|
1135 |
+
raise ValueError(msg)
|
1136 |
+
return mask.mass > 1
|
1137 |
+
|
1138 |
+
self.frame_annots = [
|
1139 |
+
frame
|
1140 |
+
for frame in self.frame_annots
|
1141 |
+
if positive_mass(frame["frame_annotation"])
|
1142 |
+
]
|
1143 |
+
logger.info("... filtered %d -> %d" % (old_len, len(self.frame_annots)))
|
1144 |
+
|
1145 |
+
# this has to be called after joining with categories!!
|
1146 |
+
subsets = self.subsets
|
1147 |
+
if subsets:
|
1148 |
+
if not self.subset_lists_file:
|
1149 |
+
raise ValueError(
|
1150 |
+
"Subset filter is on but subset_lists_file was not given"
|
1151 |
+
)
|
1152 |
+
|
1153 |
+
logger.info(f"Limiting Co3D dataset to the '{subsets}' subsets.")
|
1154 |
+
|
1155 |
+
# truncate the list of subsets to the valid one
|
1156 |
+
self.frame_annots = [
|
1157 |
+
entry
|
1158 |
+
for entry in self.frame_annots
|
1159 |
+
if (entry["subset"] is not None and self.stage in entry["subset"])
|
1160 |
+
]
|
1161 |
+
|
1162 |
+
if len(self.frame_annots) == 0:
|
1163 |
+
raise ValueError(f"There are no frames in the '{subsets}' subsets!")
|
1164 |
+
|
1165 |
+
self._invalidate_indexes(filter_seq_annots=True)
|
1166 |
+
|
1167 |
+
if len(self.limit_category_to) > 0:
|
1168 |
+
logger.info(f"Limiting dataset to categories: {self.limit_category_to}")
|
1169 |
+
# pyre-ignore[16]
|
1170 |
+
self.seq_annots = {
|
1171 |
+
name: entry
|
1172 |
+
for name, entry in self.seq_annots.items()
|
1173 |
+
if entry.category in self.limit_category_to
|
1174 |
+
}
|
1175 |
+
|
1176 |
+
# sequence filters
|
1177 |
+
for prefix in ("pick", "exclude"):
|
1178 |
+
orig_len = len(self.seq_annots)
|
1179 |
+
attr = f"{prefix}_sequence"
|
1180 |
+
arr = getattr(self, attr)
|
1181 |
+
if len(arr) > 0:
|
1182 |
+
logger.info(f"{attr}: {str(arr)}")
|
1183 |
+
self.seq_annots = {
|
1184 |
+
name: entry
|
1185 |
+
for name, entry in self.seq_annots.items()
|
1186 |
+
if (name in arr) == (prefix == "pick")
|
1187 |
+
}
|
1188 |
+
logger.info("... filtered %d -> %d" % (orig_len, len(self.seq_annots)))
|
1189 |
+
|
1190 |
+
if self.limit_sequences_to > 0:
|
1191 |
+
self.seq_annots = dict(
|
1192 |
+
islice(self.seq_annots.items(), self.limit_sequences_to)
|
1193 |
+
)
|
1194 |
+
|
1195 |
+
# retain only frames from retained sequences
|
1196 |
+
self.frame_annots = [
|
1197 |
+
f
|
1198 |
+
for f in self.frame_annots
|
1199 |
+
if f["frame_annotation"].sequence_name in self.seq_annots
|
1200 |
+
]
|
1201 |
+
|
1202 |
+
self._invalidate_indexes()
|
1203 |
+
|
1204 |
+
if self.n_frames_per_sequence > 0:
|
1205 |
+
logger.info(f"Taking max {self.n_frames_per_sequence} per sequence.")
|
1206 |
+
keep_idx = []
|
1207 |
+
# pyre-ignore[16]
|
1208 |
+
for seq, seq_indices in self._seq_to_idx.items():
|
1209 |
+
# infer the seed from the sequence name, this is reproducible
|
1210 |
+
# and makes the selection differ for different sequences
|
1211 |
+
seed = _seq_name_to_seed(seq) + self.seed
|
1212 |
+
seq_idx_shuffled = random.Random(seed).sample(
|
1213 |
+
sorted(seq_indices), len(seq_indices)
|
1214 |
+
)
|
1215 |
+
keep_idx.extend(seq_idx_shuffled[: self.n_frames_per_sequence])
|
1216 |
+
|
1217 |
+
logger.info(
|
1218 |
+
"... filtered %d -> %d" % (len(self.frame_annots), len(keep_idx))
|
1219 |
+
)
|
1220 |
+
self.frame_annots = [self.frame_annots[i] for i in keep_idx]
|
1221 |
+
self._invalidate_indexes(filter_seq_annots=False)
|
1222 |
+
# sequences are not decimated, so self.seq_annots is valid
|
1223 |
+
|
1224 |
+
if self.limit_to > 0 and self.limit_to < len(self.frame_annots):
|
1225 |
+
logger.info(
|
1226 |
+
"limit_to: filtered %d -> %d" % (len(self.frame_annots), self.limit_to)
|
1227 |
+
)
|
1228 |
+
self.frame_annots = self.frame_annots[: self.limit_to]
|
1229 |
+
self._invalidate_indexes(filter_seq_annots=True)
|
1230 |
+
|
1231 |
+
def _invalidate_indexes(self, filter_seq_annots: bool = False) -> None:
|
1232 |
+
# update _seq_to_idx and filter seq_meta according to frame_annots change
|
1233 |
+
# if filter_seq_annots, also uldates seq_annots based on the changed _seq_to_idx
|
1234 |
+
self._invalidate_seq_to_idx()
|
1235 |
+
|
1236 |
+
if filter_seq_annots:
|
1237 |
+
# pyre-ignore[16]
|
1238 |
+
self.seq_annots = {
|
1239 |
+
k: v
|
1240 |
+
for k, v in self.seq_annots.items()
|
1241 |
+
# pyre-ignore[16]
|
1242 |
+
if k in self._seq_to_idx
|
1243 |
+
}
|
1244 |
+
|
1245 |
+
def _invalidate_seq_to_idx(self) -> None:
|
1246 |
+
seq_to_idx = defaultdict(list)
|
1247 |
+
# pyre-ignore[16]
|
1248 |
+
for idx, entry in enumerate(self.frame_annots):
|
1249 |
+
seq_to_idx[entry["frame_annotation"].sequence_name].append(idx)
|
1250 |
+
# pyre-ignore[16]
|
1251 |
+
self._seq_to_idx = seq_to_idx
|
1252 |
+
|
1253 |
+
def _resize_image(
|
1254 |
+
self, image, mode="bilinear"
|
1255 |
+
) -> Tuple[torch.Tensor, float, torch.Tensor]:
|
1256 |
+
image_height, image_width = self.image_height, self.image_width
|
1257 |
+
if image_height is None or image_width is None:
|
1258 |
+
# skip the resizing
|
1259 |
+
imre_ = torch.from_numpy(image)
|
1260 |
+
return imre_, 1.0, torch.ones_like(imre_[:1])
|
1261 |
+
# takes numpy array, returns pytorch tensor
|
1262 |
+
minscale = min(
|
1263 |
+
image_height / image.shape[-2],
|
1264 |
+
image_width / image.shape[-1],
|
1265 |
+
)
|
1266 |
+
imre = torch.nn.functional.interpolate(
|
1267 |
+
torch.from_numpy(image)[None],
|
1268 |
+
scale_factor=minscale,
|
1269 |
+
mode=mode,
|
1270 |
+
align_corners=False if mode == "bilinear" else None,
|
1271 |
+
recompute_scale_factor=True,
|
1272 |
+
)[0]
|
1273 |
+
# pyre-fixme[19]: Expected 1 positional argument.
|
1274 |
+
imre_ = torch.zeros(image.shape[0], self.image_height, self.image_width)
|
1275 |
+
imre_[:, 0 : imre.shape[1], 0 : imre.shape[2]] = imre
|
1276 |
+
# pyre-fixme[6]: For 2nd param expected `int` but got `Optional[int]`.
|
1277 |
+
# pyre-fixme[6]: For 3rd param expected `int` but got `Optional[int]`.
|
1278 |
+
mask = torch.zeros(1, self.image_height, self.image_width)
|
1279 |
+
mask[:, 0 : imre.shape[1], 0 : imre.shape[2]] = 1.0
|
1280 |
+
return imre_, minscale, mask
|
1281 |
+
|
1282 |
+
def _local_path(self, path: str) -> str:
|
1283 |
+
if self.path_manager is None:
|
1284 |
+
return path
|
1285 |
+
return self.path_manager.get_local_path(path)
|
1286 |
+
|
1287 |
+
def get_frame_numbers_and_timestamps(
|
1288 |
+
self, idxs: Sequence[int]
|
1289 |
+
) -> List[Tuple[int, float]]:
|
1290 |
+
out: List[Tuple[int, float]] = []
|
1291 |
+
for idx in idxs:
|
1292 |
+
# pyre-ignore[16]
|
1293 |
+
frame_annotation = self.frame_annots[idx]["frame_annotation"]
|
1294 |
+
out.append(
|
1295 |
+
(frame_annotation.frame_number, frame_annotation.frame_timestamp)
|
1296 |
+
)
|
1297 |
+
return out
|
1298 |
+
|
1299 |
+
def get_eval_batches(self) -> Optional[List[List[int]]]:
|
1300 |
+
return self.eval_batches
|
1301 |
+
|
1302 |
+
def _get_frame_type(self, entry: FrameAnnotsEntry) -> Optional[str]:
|
1303 |
+
return entry["frame_annotation"].meta["frame_type"]
|
1304 |
+
|
1305 |
+
|
1306 |
+
class CO3DDataset(LightningDataModule):
|
1307 |
+
def __init__(
|
1308 |
+
self,
|
1309 |
+
root_dir,
|
1310 |
+
batch_size=2,
|
1311 |
+
shuffle=True,
|
1312 |
+
num_workers=10,
|
1313 |
+
prefetch_factor=2,
|
1314 |
+
category="hydrant",
|
1315 |
+
**kwargs,
|
1316 |
+
):
|
1317 |
+
super().__init__()
|
1318 |
+
|
1319 |
+
self.batch_size = batch_size
|
1320 |
+
self.num_workers = num_workers
|
1321 |
+
self.prefetch_factor = prefetch_factor
|
1322 |
+
self.shuffle = shuffle
|
1323 |
+
|
1324 |
+
self.train_dataset = CO3Dv2Wrapper(
|
1325 |
+
root_dir=root_dir,
|
1326 |
+
stage="train",
|
1327 |
+
category=category,
|
1328 |
+
**kwargs,
|
1329 |
+
)
|
1330 |
+
|
1331 |
+
self.test_dataset = CO3Dv2Wrapper(
|
1332 |
+
root_dir=root_dir,
|
1333 |
+
stage="test",
|
1334 |
+
subset="fewview_dev",
|
1335 |
+
category=category,
|
1336 |
+
**kwargs,
|
1337 |
+
)
|
1338 |
+
|
1339 |
+
def train_dataloader(self):
|
1340 |
+
return DataLoader(
|
1341 |
+
self.train_dataset,
|
1342 |
+
batch_size=self.batch_size,
|
1343 |
+
shuffle=self.shuffle,
|
1344 |
+
num_workers=self.num_workers,
|
1345 |
+
prefetch_factor=self.prefetch_factor,
|
1346 |
+
collate_fn=self.train_dataset.collate_fn,
|
1347 |
+
)
|
1348 |
+
|
1349 |
+
def test_dataloader(self):
|
1350 |
+
return DataLoader(
|
1351 |
+
self.test_dataset,
|
1352 |
+
batch_size=self.batch_size,
|
1353 |
+
shuffle=self.shuffle,
|
1354 |
+
num_workers=self.num_workers,
|
1355 |
+
prefetch_factor=self.prefetch_factor,
|
1356 |
+
collate_fn=self.test_dataset.collate_fn,
|
1357 |
+
)
|
1358 |
+
|
1359 |
+
def val_dataloader(self):
|
1360 |
+
return DataLoader(
|
1361 |
+
self.test_dataset,
|
1362 |
+
batch_size=self.batch_size,
|
1363 |
+
shuffle=self.shuffle,
|
1364 |
+
num_workers=self.num_workers,
|
1365 |
+
prefetch_factor=self.prefetch_factor,
|
1366 |
+
collate_fn=video_collate_fn,
|
1367 |
+
)
|
sgm/data/colmap.py
ADDED
@@ -0,0 +1,605 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023, ETH Zurich and UNC Chapel Hill.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# Redistribution and use in source and binary forms, with or without
|
5 |
+
# modification, are permitted provided that the following conditions are met:
|
6 |
+
#
|
7 |
+
# * Redistributions of source code must retain the above copyright
|
8 |
+
# notice, this list of conditions and the following disclaimer.
|
9 |
+
#
|
10 |
+
# * Redistributions in binary form must reproduce the above copyright
|
11 |
+
# notice, this list of conditions and the following disclaimer in the
|
12 |
+
# documentation and/or other materials provided with the distribution.
|
13 |
+
#
|
14 |
+
# * Neither the name of ETH Zurich and UNC Chapel Hill nor the names of
|
15 |
+
# its contributors may be used to endorse or promote products derived
|
16 |
+
# from this software without specific prior written permission.
|
17 |
+
#
|
18 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
19 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
20 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
21 |
+
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE
|
22 |
+
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
|
23 |
+
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
|
24 |
+
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
25 |
+
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
|
26 |
+
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
|
27 |
+
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
|
28 |
+
# POSSIBILITY OF SUCH DAMAGE.
|
29 |
+
|
30 |
+
|
31 |
+
import os
|
32 |
+
import collections
|
33 |
+
import numpy as np
|
34 |
+
import struct
|
35 |
+
import argparse
|
36 |
+
|
37 |
+
|
38 |
+
CameraModel = collections.namedtuple(
|
39 |
+
"CameraModel", ["model_id", "model_name", "num_params"]
|
40 |
+
)
|
41 |
+
Camera = collections.namedtuple(
|
42 |
+
"Camera", ["id", "model", "width", "height", "params"]
|
43 |
+
)
|
44 |
+
BaseImage = collections.namedtuple(
|
45 |
+
"Image", ["id", "qvec", "tvec", "camera_id", "name", "xys", "point3D_ids"]
|
46 |
+
)
|
47 |
+
Point3D = collections.namedtuple(
|
48 |
+
"Point3D", ["id", "xyz", "rgb", "error", "image_ids", "point2D_idxs"]
|
49 |
+
)
|
50 |
+
|
51 |
+
|
52 |
+
class Image(BaseImage):
|
53 |
+
def qvec2rotmat(self):
|
54 |
+
return qvec2rotmat(self.qvec)
|
55 |
+
|
56 |
+
|
57 |
+
CAMERA_MODELS = {
|
58 |
+
CameraModel(model_id=0, model_name="SIMPLE_PINHOLE", num_params=3),
|
59 |
+
CameraModel(model_id=1, model_name="PINHOLE", num_params=4),
|
60 |
+
CameraModel(model_id=2, model_name="SIMPLE_RADIAL", num_params=4),
|
61 |
+
CameraModel(model_id=3, model_name="RADIAL", num_params=5),
|
62 |
+
CameraModel(model_id=4, model_name="OPENCV", num_params=8),
|
63 |
+
CameraModel(model_id=5, model_name="OPENCV_FISHEYE", num_params=8),
|
64 |
+
CameraModel(model_id=6, model_name="FULL_OPENCV", num_params=12),
|
65 |
+
CameraModel(model_id=7, model_name="FOV", num_params=5),
|
66 |
+
CameraModel(model_id=8, model_name="SIMPLE_RADIAL_FISHEYE", num_params=4),
|
67 |
+
CameraModel(model_id=9, model_name="RADIAL_FISHEYE", num_params=5),
|
68 |
+
CameraModel(model_id=10, model_name="THIN_PRISM_FISHEYE", num_params=12),
|
69 |
+
}
|
70 |
+
CAMERA_MODEL_IDS = dict(
|
71 |
+
[(camera_model.model_id, camera_model) for camera_model in CAMERA_MODELS]
|
72 |
+
)
|
73 |
+
CAMERA_MODEL_NAMES = dict(
|
74 |
+
[(camera_model.model_name, camera_model) for camera_model in CAMERA_MODELS]
|
75 |
+
)
|
76 |
+
|
77 |
+
|
78 |
+
def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"):
|
79 |
+
"""Read and unpack the next bytes from a binary file.
|
80 |
+
:param fid:
|
81 |
+
:param num_bytes: Sum of combination of {2, 4, 8}, e.g. 2, 6, 16, 30, etc.
|
82 |
+
:param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}.
|
83 |
+
:param endian_character: Any of {@, =, <, >, !}
|
84 |
+
:return: Tuple of read and unpacked values.
|
85 |
+
"""
|
86 |
+
data = fid.read(num_bytes)
|
87 |
+
return struct.unpack(endian_character + format_char_sequence, data)
|
88 |
+
|
89 |
+
|
90 |
+
def write_next_bytes(fid, data, format_char_sequence, endian_character="<"):
|
91 |
+
"""pack and write to a binary file.
|
92 |
+
:param fid:
|
93 |
+
:param data: data to send, if multiple elements are sent at the same time,
|
94 |
+
they should be encapsuled either in a list or a tuple
|
95 |
+
:param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}.
|
96 |
+
should be the same length as the data list or tuple
|
97 |
+
:param endian_character: Any of {@, =, <, >, !}
|
98 |
+
"""
|
99 |
+
if isinstance(data, (list, tuple)):
|
100 |
+
bytes = struct.pack(endian_character + format_char_sequence, *data)
|
101 |
+
else:
|
102 |
+
bytes = struct.pack(endian_character + format_char_sequence, data)
|
103 |
+
fid.write(bytes)
|
104 |
+
|
105 |
+
|
106 |
+
def read_cameras_text(path):
|
107 |
+
"""
|
108 |
+
see: src/colmap/scene/reconstruction.cc
|
109 |
+
void Reconstruction::WriteCamerasText(const std::string& path)
|
110 |
+
void Reconstruction::ReadCamerasText(const std::string& path)
|
111 |
+
"""
|
112 |
+
cameras = {}
|
113 |
+
with open(path, "r") as fid:
|
114 |
+
while True:
|
115 |
+
line = fid.readline()
|
116 |
+
if not line:
|
117 |
+
break
|
118 |
+
line = line.strip()
|
119 |
+
if len(line) > 0 and line[0] != "#":
|
120 |
+
elems = line.split()
|
121 |
+
camera_id = int(elems[0])
|
122 |
+
model = elems[1]
|
123 |
+
width = int(elems[2])
|
124 |
+
height = int(elems[3])
|
125 |
+
params = np.array(tuple(map(float, elems[4:])))
|
126 |
+
cameras[camera_id] = Camera(
|
127 |
+
id=camera_id,
|
128 |
+
model=model,
|
129 |
+
width=width,
|
130 |
+
height=height,
|
131 |
+
params=params,
|
132 |
+
)
|
133 |
+
return cameras
|
134 |
+
|
135 |
+
|
136 |
+
def read_cameras_binary(path_to_model_file):
|
137 |
+
"""
|
138 |
+
see: src/colmap/scene/reconstruction.cc
|
139 |
+
void Reconstruction::WriteCamerasBinary(const std::string& path)
|
140 |
+
void Reconstruction::ReadCamerasBinary(const std::string& path)
|
141 |
+
"""
|
142 |
+
cameras = {}
|
143 |
+
with open(path_to_model_file, "rb") as fid:
|
144 |
+
num_cameras = read_next_bytes(fid, 8, "Q")[0]
|
145 |
+
for _ in range(num_cameras):
|
146 |
+
camera_properties = read_next_bytes(
|
147 |
+
fid, num_bytes=24, format_char_sequence="iiQQ"
|
148 |
+
)
|
149 |
+
camera_id = camera_properties[0]
|
150 |
+
model_id = camera_properties[1]
|
151 |
+
model_name = CAMERA_MODEL_IDS[camera_properties[1]].model_name
|
152 |
+
width = camera_properties[2]
|
153 |
+
height = camera_properties[3]
|
154 |
+
num_params = CAMERA_MODEL_IDS[model_id].num_params
|
155 |
+
params = read_next_bytes(
|
156 |
+
fid,
|
157 |
+
num_bytes=8 * num_params,
|
158 |
+
format_char_sequence="d" * num_params,
|
159 |
+
)
|
160 |
+
cameras[camera_id] = Camera(
|
161 |
+
id=camera_id,
|
162 |
+
model=model_name,
|
163 |
+
width=width,
|
164 |
+
height=height,
|
165 |
+
params=np.array(params),
|
166 |
+
)
|
167 |
+
assert len(cameras) == num_cameras
|
168 |
+
return cameras
|
169 |
+
|
170 |
+
|
171 |
+
def write_cameras_text(cameras, path):
|
172 |
+
"""
|
173 |
+
see: src/colmap/scene/reconstruction.cc
|
174 |
+
void Reconstruction::WriteCamerasText(const std::string& path)
|
175 |
+
void Reconstruction::ReadCamerasText(const std::string& path)
|
176 |
+
"""
|
177 |
+
HEADER = (
|
178 |
+
"# Camera list with one line of data per camera:\n"
|
179 |
+
+ "# CAMERA_ID, MODEL, WIDTH, HEIGHT, PARAMS[]\n"
|
180 |
+
+ "# Number of cameras: {}\n".format(len(cameras))
|
181 |
+
)
|
182 |
+
with open(path, "w") as fid:
|
183 |
+
fid.write(HEADER)
|
184 |
+
for _, cam in cameras.items():
|
185 |
+
to_write = [cam.id, cam.model, cam.width, cam.height, *cam.params]
|
186 |
+
line = " ".join([str(elem) for elem in to_write])
|
187 |
+
fid.write(line + "\n")
|
188 |
+
|
189 |
+
|
190 |
+
def write_cameras_binary(cameras, path_to_model_file):
|
191 |
+
"""
|
192 |
+
see: src/colmap/scene/reconstruction.cc
|
193 |
+
void Reconstruction::WriteCamerasBinary(const std::string& path)
|
194 |
+
void Reconstruction::ReadCamerasBinary(const std::string& path)
|
195 |
+
"""
|
196 |
+
with open(path_to_model_file, "wb") as fid:
|
197 |
+
write_next_bytes(fid, len(cameras), "Q")
|
198 |
+
for _, cam in cameras.items():
|
199 |
+
model_id = CAMERA_MODEL_NAMES[cam.model].model_id
|
200 |
+
camera_properties = [cam.id, model_id, cam.width, cam.height]
|
201 |
+
write_next_bytes(fid, camera_properties, "iiQQ")
|
202 |
+
for p in cam.params:
|
203 |
+
write_next_bytes(fid, float(p), "d")
|
204 |
+
return cameras
|
205 |
+
|
206 |
+
|
207 |
+
def read_images_text(path):
|
208 |
+
"""
|
209 |
+
see: src/colmap/scene/reconstruction.cc
|
210 |
+
void Reconstruction::ReadImagesText(const std::string& path)
|
211 |
+
void Reconstruction::WriteImagesText(const std::string& path)
|
212 |
+
"""
|
213 |
+
images = {}
|
214 |
+
with open(path, "r") as fid:
|
215 |
+
while True:
|
216 |
+
line = fid.readline()
|
217 |
+
if not line:
|
218 |
+
break
|
219 |
+
line = line.strip()
|
220 |
+
if len(line) > 0 and line[0] != "#":
|
221 |
+
elems = line.split()
|
222 |
+
image_id = int(elems[0])
|
223 |
+
qvec = np.array(tuple(map(float, elems[1:5])))
|
224 |
+
tvec = np.array(tuple(map(float, elems[5:8])))
|
225 |
+
camera_id = int(elems[8])
|
226 |
+
image_name = elems[9]
|
227 |
+
elems = fid.readline().split()
|
228 |
+
xys = np.column_stack(
|
229 |
+
[
|
230 |
+
tuple(map(float, elems[0::3])),
|
231 |
+
tuple(map(float, elems[1::3])),
|
232 |
+
]
|
233 |
+
)
|
234 |
+
point3D_ids = np.array(tuple(map(int, elems[2::3])))
|
235 |
+
images[image_id] = Image(
|
236 |
+
id=image_id,
|
237 |
+
qvec=qvec,
|
238 |
+
tvec=tvec,
|
239 |
+
camera_id=camera_id,
|
240 |
+
name=image_name,
|
241 |
+
xys=xys,
|
242 |
+
point3D_ids=point3D_ids,
|
243 |
+
)
|
244 |
+
return images
|
245 |
+
|
246 |
+
|
247 |
+
def read_images_binary(path_to_model_file):
|
248 |
+
"""
|
249 |
+
see: src/colmap/scene/reconstruction.cc
|
250 |
+
void Reconstruction::ReadImagesBinary(const std::string& path)
|
251 |
+
void Reconstruction::WriteImagesBinary(const std::string& path)
|
252 |
+
"""
|
253 |
+
images = {}
|
254 |
+
with open(path_to_model_file, "rb") as fid:
|
255 |
+
num_reg_images = read_next_bytes(fid, 8, "Q")[0]
|
256 |
+
for _ in range(num_reg_images):
|
257 |
+
binary_image_properties = read_next_bytes(
|
258 |
+
fid, num_bytes=64, format_char_sequence="idddddddi"
|
259 |
+
)
|
260 |
+
image_id = binary_image_properties[0]
|
261 |
+
qvec = np.array(binary_image_properties[1:5])
|
262 |
+
tvec = np.array(binary_image_properties[5:8])
|
263 |
+
camera_id = binary_image_properties[8]
|
264 |
+
binary_image_name = b""
|
265 |
+
current_char = read_next_bytes(fid, 1, "c")[0]
|
266 |
+
while current_char != b"\x00": # look for the ASCII 0 entry
|
267 |
+
binary_image_name += current_char
|
268 |
+
current_char = read_next_bytes(fid, 1, "c")[0]
|
269 |
+
image_name = binary_image_name.decode("utf-8")
|
270 |
+
num_points2D = read_next_bytes(
|
271 |
+
fid, num_bytes=8, format_char_sequence="Q"
|
272 |
+
)[0]
|
273 |
+
x_y_id_s = read_next_bytes(
|
274 |
+
fid,
|
275 |
+
num_bytes=24 * num_points2D,
|
276 |
+
format_char_sequence="ddq" * num_points2D,
|
277 |
+
)
|
278 |
+
xys = np.column_stack(
|
279 |
+
[
|
280 |
+
tuple(map(float, x_y_id_s[0::3])),
|
281 |
+
tuple(map(float, x_y_id_s[1::3])),
|
282 |
+
]
|
283 |
+
)
|
284 |
+
point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3])))
|
285 |
+
images[image_id] = Image(
|
286 |
+
id=image_id,
|
287 |
+
qvec=qvec,
|
288 |
+
tvec=tvec,
|
289 |
+
camera_id=camera_id,
|
290 |
+
name=image_name,
|
291 |
+
xys=xys,
|
292 |
+
point3D_ids=point3D_ids,
|
293 |
+
)
|
294 |
+
return images
|
295 |
+
|
296 |
+
|
297 |
+
def write_images_text(images, path):
|
298 |
+
"""
|
299 |
+
see: src/colmap/scene/reconstruction.cc
|
300 |
+
void Reconstruction::ReadImagesText(const std::string& path)
|
301 |
+
void Reconstruction::WriteImagesText(const std::string& path)
|
302 |
+
"""
|
303 |
+
if len(images) == 0:
|
304 |
+
mean_observations = 0
|
305 |
+
else:
|
306 |
+
mean_observations = sum(
|
307 |
+
(len(img.point3D_ids) for _, img in images.items())
|
308 |
+
) / len(images)
|
309 |
+
HEADER = (
|
310 |
+
"# Image list with two lines of data per image:\n"
|
311 |
+
+ "# IMAGE_ID, QW, QX, QY, QZ, TX, TY, TZ, CAMERA_ID, NAME\n"
|
312 |
+
+ "# POINTS2D[] as (X, Y, POINT3D_ID)\n"
|
313 |
+
+ "# Number of images: {}, mean observations per image: {}\n".format(
|
314 |
+
len(images), mean_observations
|
315 |
+
)
|
316 |
+
)
|
317 |
+
|
318 |
+
with open(path, "w") as fid:
|
319 |
+
fid.write(HEADER)
|
320 |
+
for _, img in images.items():
|
321 |
+
image_header = [
|
322 |
+
img.id,
|
323 |
+
*img.qvec,
|
324 |
+
*img.tvec,
|
325 |
+
img.camera_id,
|
326 |
+
img.name,
|
327 |
+
]
|
328 |
+
first_line = " ".join(map(str, image_header))
|
329 |
+
fid.write(first_line + "\n")
|
330 |
+
|
331 |
+
points_strings = []
|
332 |
+
for xy, point3D_id in zip(img.xys, img.point3D_ids):
|
333 |
+
points_strings.append(" ".join(map(str, [*xy, point3D_id])))
|
334 |
+
fid.write(" ".join(points_strings) + "\n")
|
335 |
+
|
336 |
+
|
337 |
+
def write_images_binary(images, path_to_model_file):
|
338 |
+
"""
|
339 |
+
see: src/colmap/scene/reconstruction.cc
|
340 |
+
void Reconstruction::ReadImagesBinary(const std::string& path)
|
341 |
+
void Reconstruction::WriteImagesBinary(const std::string& path)
|
342 |
+
"""
|
343 |
+
with open(path_to_model_file, "wb") as fid:
|
344 |
+
write_next_bytes(fid, len(images), "Q")
|
345 |
+
for _, img in images.items():
|
346 |
+
write_next_bytes(fid, img.id, "i")
|
347 |
+
write_next_bytes(fid, img.qvec.tolist(), "dddd")
|
348 |
+
write_next_bytes(fid, img.tvec.tolist(), "ddd")
|
349 |
+
write_next_bytes(fid, img.camera_id, "i")
|
350 |
+
for char in img.name:
|
351 |
+
write_next_bytes(fid, char.encode("utf-8"), "c")
|
352 |
+
write_next_bytes(fid, b"\x00", "c")
|
353 |
+
write_next_bytes(fid, len(img.point3D_ids), "Q")
|
354 |
+
for xy, p3d_id in zip(img.xys, img.point3D_ids):
|
355 |
+
write_next_bytes(fid, [*xy, p3d_id], "ddq")
|
356 |
+
|
357 |
+
|
358 |
+
def read_points3D_text(path):
|
359 |
+
"""
|
360 |
+
see: src/colmap/scene/reconstruction.cc
|
361 |
+
void Reconstruction::ReadPoints3DText(const std::string& path)
|
362 |
+
void Reconstruction::WritePoints3DText(const std::string& path)
|
363 |
+
"""
|
364 |
+
points3D = {}
|
365 |
+
with open(path, "r") as fid:
|
366 |
+
while True:
|
367 |
+
line = fid.readline()
|
368 |
+
if not line:
|
369 |
+
break
|
370 |
+
line = line.strip()
|
371 |
+
if len(line) > 0 and line[0] != "#":
|
372 |
+
elems = line.split()
|
373 |
+
point3D_id = int(elems[0])
|
374 |
+
xyz = np.array(tuple(map(float, elems[1:4])))
|
375 |
+
rgb = np.array(tuple(map(int, elems[4:7])))
|
376 |
+
error = float(elems[7])
|
377 |
+
image_ids = np.array(tuple(map(int, elems[8::2])))
|
378 |
+
point2D_idxs = np.array(tuple(map(int, elems[9::2])))
|
379 |
+
points3D[point3D_id] = Point3D(
|
380 |
+
id=point3D_id,
|
381 |
+
xyz=xyz,
|
382 |
+
rgb=rgb,
|
383 |
+
error=error,
|
384 |
+
image_ids=image_ids,
|
385 |
+
point2D_idxs=point2D_idxs,
|
386 |
+
)
|
387 |
+
return points3D
|
388 |
+
|
389 |
+
|
390 |
+
def read_points3D_binary(path_to_model_file):
|
391 |
+
"""
|
392 |
+
see: src/colmap/scene/reconstruction.cc
|
393 |
+
void Reconstruction::ReadPoints3DBinary(const std::string& path)
|
394 |
+
void Reconstruction::WritePoints3DBinary(const std::string& path)
|
395 |
+
"""
|
396 |
+
points3D = {}
|
397 |
+
with open(path_to_model_file, "rb") as fid:
|
398 |
+
num_points = read_next_bytes(fid, 8, "Q")[0]
|
399 |
+
for _ in range(num_points):
|
400 |
+
binary_point_line_properties = read_next_bytes(
|
401 |
+
fid, num_bytes=43, format_char_sequence="QdddBBBd"
|
402 |
+
)
|
403 |
+
point3D_id = binary_point_line_properties[0]
|
404 |
+
xyz = np.array(binary_point_line_properties[1:4])
|
405 |
+
rgb = np.array(binary_point_line_properties[4:7])
|
406 |
+
error = np.array(binary_point_line_properties[7])
|
407 |
+
track_length = read_next_bytes(
|
408 |
+
fid, num_bytes=8, format_char_sequence="Q"
|
409 |
+
)[0]
|
410 |
+
track_elems = read_next_bytes(
|
411 |
+
fid,
|
412 |
+
num_bytes=8 * track_length,
|
413 |
+
format_char_sequence="ii" * track_length,
|
414 |
+
)
|
415 |
+
image_ids = np.array(tuple(map(int, track_elems[0::2])))
|
416 |
+
point2D_idxs = np.array(tuple(map(int, track_elems[1::2])))
|
417 |
+
points3D[point3D_id] = Point3D(
|
418 |
+
id=point3D_id,
|
419 |
+
xyz=xyz,
|
420 |
+
rgb=rgb,
|
421 |
+
error=error,
|
422 |
+
image_ids=image_ids,
|
423 |
+
point2D_idxs=point2D_idxs,
|
424 |
+
)
|
425 |
+
return points3D
|
426 |
+
|
427 |
+
|
428 |
+
def write_points3D_text(points3D, path):
|
429 |
+
"""
|
430 |
+
see: src/colmap/scene/reconstruction.cc
|
431 |
+
void Reconstruction::ReadPoints3DText(const std::string& path)
|
432 |
+
void Reconstruction::WritePoints3DText(const std::string& path)
|
433 |
+
"""
|
434 |
+
if len(points3D) == 0:
|
435 |
+
mean_track_length = 0
|
436 |
+
else:
|
437 |
+
mean_track_length = sum(
|
438 |
+
(len(pt.image_ids) for _, pt in points3D.items())
|
439 |
+
) / len(points3D)
|
440 |
+
HEADER = (
|
441 |
+
"# 3D point list with one line of data per point:\n"
|
442 |
+
+ "# POINT3D_ID, X, Y, Z, R, G, B, ERROR, TRACK[] as (IMAGE_ID, POINT2D_IDX)\n"
|
443 |
+
+ "# Number of points: {}, mean track length: {}\n".format(
|
444 |
+
len(points3D), mean_track_length
|
445 |
+
)
|
446 |
+
)
|
447 |
+
|
448 |
+
with open(path, "w") as fid:
|
449 |
+
fid.write(HEADER)
|
450 |
+
for _, pt in points3D.items():
|
451 |
+
point_header = [pt.id, *pt.xyz, *pt.rgb, pt.error]
|
452 |
+
fid.write(" ".join(map(str, point_header)) + " ")
|
453 |
+
track_strings = []
|
454 |
+
for image_id, point2D in zip(pt.image_ids, pt.point2D_idxs):
|
455 |
+
track_strings.append(" ".join(map(str, [image_id, point2D])))
|
456 |
+
fid.write(" ".join(track_strings) + "\n")
|
457 |
+
|
458 |
+
|
459 |
+
def write_points3D_binary(points3D, path_to_model_file):
|
460 |
+
"""
|
461 |
+
see: src/colmap/scene/reconstruction.cc
|
462 |
+
void Reconstruction::ReadPoints3DBinary(const std::string& path)
|
463 |
+
void Reconstruction::WritePoints3DBinary(const std::string& path)
|
464 |
+
"""
|
465 |
+
with open(path_to_model_file, "wb") as fid:
|
466 |
+
write_next_bytes(fid, len(points3D), "Q")
|
467 |
+
for _, pt in points3D.items():
|
468 |
+
write_next_bytes(fid, pt.id, "Q")
|
469 |
+
write_next_bytes(fid, pt.xyz.tolist(), "ddd")
|
470 |
+
write_next_bytes(fid, pt.rgb.tolist(), "BBB")
|
471 |
+
write_next_bytes(fid, pt.error, "d")
|
472 |
+
track_length = pt.image_ids.shape[0]
|
473 |
+
write_next_bytes(fid, track_length, "Q")
|
474 |
+
for image_id, point2D_id in zip(pt.image_ids, pt.point2D_idxs):
|
475 |
+
write_next_bytes(fid, [image_id, point2D_id], "ii")
|
476 |
+
|
477 |
+
|
478 |
+
def detect_model_format(path, ext):
|
479 |
+
if (
|
480 |
+
os.path.isfile(os.path.join(path, "cameras" + ext))
|
481 |
+
and os.path.isfile(os.path.join(path, "images" + ext))
|
482 |
+
and os.path.isfile(os.path.join(path, "points3D" + ext))
|
483 |
+
):
|
484 |
+
print("Detected model format: '" + ext + "'")
|
485 |
+
return True
|
486 |
+
|
487 |
+
return False
|
488 |
+
|
489 |
+
|
490 |
+
def read_model(path, ext=""):
|
491 |
+
# try to detect the extension automatically
|
492 |
+
if ext == "":
|
493 |
+
if detect_model_format(path, ".bin"):
|
494 |
+
ext = ".bin"
|
495 |
+
elif detect_model_format(path, ".txt"):
|
496 |
+
ext = ".txt"
|
497 |
+
else:
|
498 |
+
print("Provide model format: '.bin' or '.txt'")
|
499 |
+
return
|
500 |
+
|
501 |
+
if ext == ".txt":
|
502 |
+
cameras = read_cameras_text(os.path.join(path, "cameras" + ext))
|
503 |
+
images = read_images_text(os.path.join(path, "images" + ext))
|
504 |
+
points3D = read_points3D_text(os.path.join(path, "points3D") + ext)
|
505 |
+
else:
|
506 |
+
cameras = read_cameras_binary(os.path.join(path, "cameras" + ext))
|
507 |
+
images = read_images_binary(os.path.join(path, "images" + ext))
|
508 |
+
points3D = read_points3D_binary(os.path.join(path, "points3D") + ext)
|
509 |
+
return cameras, images, points3D
|
510 |
+
|
511 |
+
|
512 |
+
def write_model(cameras, images, points3D, path, ext=".bin"):
|
513 |
+
if ext == ".txt":
|
514 |
+
write_cameras_text(cameras, os.path.join(path, "cameras" + ext))
|
515 |
+
write_images_text(images, os.path.join(path, "images" + ext))
|
516 |
+
write_points3D_text(points3D, os.path.join(path, "points3D") + ext)
|
517 |
+
else:
|
518 |
+
write_cameras_binary(cameras, os.path.join(path, "cameras" + ext))
|
519 |
+
write_images_binary(images, os.path.join(path, "images" + ext))
|
520 |
+
write_points3D_binary(points3D, os.path.join(path, "points3D") + ext)
|
521 |
+
return cameras, images, points3D
|
522 |
+
|
523 |
+
|
524 |
+
def qvec2rotmat(qvec):
|
525 |
+
return np.array(
|
526 |
+
[
|
527 |
+
[
|
528 |
+
1 - 2 * qvec[2] ** 2 - 2 * qvec[3] ** 2,
|
529 |
+
2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3],
|
530 |
+
2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2],
|
531 |
+
],
|
532 |
+
[
|
533 |
+
2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3],
|
534 |
+
1 - 2 * qvec[1] ** 2 - 2 * qvec[3] ** 2,
|
535 |
+
2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1],
|
536 |
+
],
|
537 |
+
[
|
538 |
+
2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2],
|
539 |
+
2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1],
|
540 |
+
1 - 2 * qvec[1] ** 2 - 2 * qvec[2] ** 2,
|
541 |
+
],
|
542 |
+
]
|
543 |
+
)
|
544 |
+
|
545 |
+
|
546 |
+
def rotmat2qvec(R):
|
547 |
+
Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat
|
548 |
+
K = (
|
549 |
+
np.array(
|
550 |
+
[
|
551 |
+
[Rxx - Ryy - Rzz, 0, 0, 0],
|
552 |
+
[Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0],
|
553 |
+
[Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0],
|
554 |
+
[Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz],
|
555 |
+
]
|
556 |
+
)
|
557 |
+
/ 3.0
|
558 |
+
)
|
559 |
+
eigvals, eigvecs = np.linalg.eigh(K)
|
560 |
+
qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)]
|
561 |
+
if qvec[0] < 0:
|
562 |
+
qvec *= -1
|
563 |
+
return qvec
|
564 |
+
|
565 |
+
|
566 |
+
def main():
|
567 |
+
parser = argparse.ArgumentParser(
|
568 |
+
description="Read and write COLMAP binary and text models"
|
569 |
+
)
|
570 |
+
parser.add_argument("--input_model", help="path to input model folder")
|
571 |
+
parser.add_argument(
|
572 |
+
"--input_format",
|
573 |
+
choices=[".bin", ".txt"],
|
574 |
+
help="input model format",
|
575 |
+
default="",
|
576 |
+
)
|
577 |
+
parser.add_argument("--output_model", help="path to output model folder")
|
578 |
+
parser.add_argument(
|
579 |
+
"--output_format",
|
580 |
+
choices=[".bin", ".txt"],
|
581 |
+
help="outut model format",
|
582 |
+
default=".txt",
|
583 |
+
)
|
584 |
+
args = parser.parse_args()
|
585 |
+
|
586 |
+
cameras, images, points3D = read_model(
|
587 |
+
path=args.input_model, ext=args.input_format
|
588 |
+
)
|
589 |
+
|
590 |
+
print("num_cameras:", len(cameras))
|
591 |
+
print("num_images:", len(images))
|
592 |
+
print("num_points3D:", len(points3D))
|
593 |
+
|
594 |
+
if args.output_model is not None:
|
595 |
+
write_model(
|
596 |
+
cameras,
|
597 |
+
images,
|
598 |
+
points3D,
|
599 |
+
path=args.output_model,
|
600 |
+
ext=args.output_format,
|
601 |
+
)
|
602 |
+
|
603 |
+
|
604 |
+
if __name__ == "__main__":
|
605 |
+
main()
|
sgm/data/dataset.py
ADDED
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional
|
2 |
+
|
3 |
+
import torchdata.datapipes.iter
|
4 |
+
import webdataset as wds
|
5 |
+
from omegaconf import DictConfig
|
6 |
+
from pytorch_lightning import LightningDataModule
|
7 |
+
|
8 |
+
try:
|
9 |
+
from sdata import create_dataset, create_dummy_dataset, create_loader
|
10 |
+
except ImportError as e:
|
11 |
+
print("#" * 100)
|
12 |
+
print("Datasets not yet available")
|
13 |
+
print("to enable, we need to add stable-datasets as a submodule")
|
14 |
+
print("please use ``git submodule update --init --recursive``")
|
15 |
+
print("and do ``pip install -e stable-datasets/`` from the root of this repo")
|
16 |
+
print("#" * 100)
|
17 |
+
exit(1)
|
18 |
+
|
19 |
+
|
20 |
+
class StableDataModuleFromConfig(LightningDataModule):
|
21 |
+
def __init__(
|
22 |
+
self,
|
23 |
+
train: DictConfig,
|
24 |
+
validation: Optional[DictConfig] = None,
|
25 |
+
test: Optional[DictConfig] = None,
|
26 |
+
skip_val_loader: bool = False,
|
27 |
+
dummy: bool = False,
|
28 |
+
):
|
29 |
+
super().__init__()
|
30 |
+
self.train_config = train
|
31 |
+
assert (
|
32 |
+
"datapipeline" in self.train_config and "loader" in self.train_config
|
33 |
+
), "train config requires the fields `datapipeline` and `loader`"
|
34 |
+
|
35 |
+
self.val_config = validation
|
36 |
+
if not skip_val_loader:
|
37 |
+
if self.val_config is not None:
|
38 |
+
assert (
|
39 |
+
"datapipeline" in self.val_config and "loader" in self.val_config
|
40 |
+
), "validation config requires the fields `datapipeline` and `loader`"
|
41 |
+
else:
|
42 |
+
print(
|
43 |
+
"Warning: No Validation datapipeline defined, using that one from training"
|
44 |
+
)
|
45 |
+
self.val_config = train
|
46 |
+
|
47 |
+
self.test_config = test
|
48 |
+
if self.test_config is not None:
|
49 |
+
assert (
|
50 |
+
"datapipeline" in self.test_config and "loader" in self.test_config
|
51 |
+
), "test config requires the fields `datapipeline` and `loader`"
|
52 |
+
|
53 |
+
self.dummy = dummy
|
54 |
+
if self.dummy:
|
55 |
+
print("#" * 100)
|
56 |
+
print("USING DUMMY DATASET: HOPE YOU'RE DEBUGGING ;)")
|
57 |
+
print("#" * 100)
|
58 |
+
|
59 |
+
def setup(self, stage: str) -> None:
|
60 |
+
print("Preparing datasets")
|
61 |
+
if self.dummy:
|
62 |
+
data_fn = create_dummy_dataset
|
63 |
+
else:
|
64 |
+
data_fn = create_dataset
|
65 |
+
|
66 |
+
self.train_datapipeline = data_fn(**self.train_config.datapipeline)
|
67 |
+
if self.val_config:
|
68 |
+
self.val_datapipeline = data_fn(**self.val_config.datapipeline)
|
69 |
+
if self.test_config:
|
70 |
+
self.test_datapipeline = data_fn(**self.test_config.datapipeline)
|
71 |
+
|
72 |
+
def train_dataloader(self) -> torchdata.datapipes.iter.IterDataPipe:
|
73 |
+
loader = create_loader(self.train_datapipeline, **self.train_config.loader)
|
74 |
+
return loader
|
75 |
+
|
76 |
+
def val_dataloader(self) -> wds.DataPipeline:
|
77 |
+
return create_loader(self.val_datapipeline, **self.val_config.loader)
|
78 |
+
|
79 |
+
def test_dataloader(self) -> wds.DataPipeline:
|
80 |
+
return create_loader(self.test_datapipeline, **self.test_config.loader)
|
sgm/data/joint3d.py
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch.utils.data import Dataset
|
3 |
+
|
4 |
+
default_sub_data_config = {}
|
5 |
+
|
6 |
+
|
7 |
+
class Joint3D(Dataset):
|
8 |
+
def __init__(self, sub_data_config: dict) -> None:
|
9 |
+
super().__init__()
|
10 |
+
self.sub_data_config = sub_data_config
|
sgm/data/json_index_dataset.py
ADDED
@@ -0,0 +1,1080 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the BSD-style license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import copy
|
8 |
+
import functools
|
9 |
+
import gzip
|
10 |
+
import hashlib
|
11 |
+
import json
|
12 |
+
import logging
|
13 |
+
import os
|
14 |
+
import random
|
15 |
+
import warnings
|
16 |
+
from collections import defaultdict
|
17 |
+
from itertools import islice
|
18 |
+
from pathlib import Path
|
19 |
+
from typing import (
|
20 |
+
Any,
|
21 |
+
ClassVar,
|
22 |
+
Dict,
|
23 |
+
Iterable,
|
24 |
+
List,
|
25 |
+
Optional,
|
26 |
+
Sequence,
|
27 |
+
Tuple,
|
28 |
+
Type,
|
29 |
+
TYPE_CHECKING,
|
30 |
+
Union,
|
31 |
+
)
|
32 |
+
|
33 |
+
import numpy as np
|
34 |
+
import torch
|
35 |
+
from PIL import Image
|
36 |
+
from pytorch3d.implicitron.tools.config import registry, ReplaceableBase
|
37 |
+
from pytorch3d.io import IO
|
38 |
+
from pytorch3d.renderer.camera_utils import join_cameras_as_batch
|
39 |
+
from pytorch3d.renderer.cameras import CamerasBase, PerspectiveCameras
|
40 |
+
from pytorch3d.structures.pointclouds import Pointclouds
|
41 |
+
from tqdm import tqdm
|
42 |
+
|
43 |
+
from pytorch3d.implicitron.dataset import types
|
44 |
+
from pytorch3d.implicitron.dataset.dataset_base import DatasetBase, FrameData
|
45 |
+
from pytorch3d.implicitron.dataset.utils import is_known_frame_scalar
|
46 |
+
|
47 |
+
|
48 |
+
logger = logging.getLogger(__name__)
|
49 |
+
|
50 |
+
|
51 |
+
if TYPE_CHECKING:
|
52 |
+
from typing import TypedDict
|
53 |
+
|
54 |
+
class FrameAnnotsEntry(TypedDict):
|
55 |
+
subset: Optional[str]
|
56 |
+
frame_annotation: types.FrameAnnotation
|
57 |
+
|
58 |
+
else:
|
59 |
+
FrameAnnotsEntry = dict
|
60 |
+
|
61 |
+
|
62 |
+
@registry.register
|
63 |
+
class JsonIndexDataset(DatasetBase, ReplaceableBase):
|
64 |
+
"""
|
65 |
+
A dataset with annotations in json files like the Common Objects in 3D
|
66 |
+
(CO3D) dataset.
|
67 |
+
|
68 |
+
Args:
|
69 |
+
frame_annotations_file: A zipped json file containing metadata of the
|
70 |
+
frames in the dataset, serialized List[types.FrameAnnotation].
|
71 |
+
sequence_annotations_file: A zipped json file containing metadata of the
|
72 |
+
sequences in the dataset, serialized List[types.SequenceAnnotation].
|
73 |
+
subset_lists_file: A json file containing the lists of frames corresponding
|
74 |
+
corresponding to different subsets (e.g. train/val/test) of the dataset;
|
75 |
+
format: {subset: (sequence_name, frame_id, file_path)}.
|
76 |
+
subsets: Restrict frames/sequences only to the given list of subsets
|
77 |
+
as defined in subset_lists_file (see above).
|
78 |
+
limit_to: Limit the dataset to the first #limit_to frames (after other
|
79 |
+
filters have been applied).
|
80 |
+
limit_sequences_to: Limit the dataset to the first
|
81 |
+
#limit_sequences_to sequences (after other sequence filters have been
|
82 |
+
applied but before frame-based filters).
|
83 |
+
pick_sequence: A list of sequence names to restrict the dataset to.
|
84 |
+
exclude_sequence: A list of the names of the sequences to exclude.
|
85 |
+
limit_category_to: Restrict the dataset to the given list of categories.
|
86 |
+
dataset_root: The root folder of the dataset; all the paths in jsons are
|
87 |
+
specified relative to this root (but not json paths themselves).
|
88 |
+
load_images: Enable loading the frame RGB data.
|
89 |
+
load_depths: Enable loading the frame depth maps.
|
90 |
+
load_depth_masks: Enable loading the frame depth map masks denoting the
|
91 |
+
depth values used for evaluation (the points consistent across views).
|
92 |
+
load_masks: Enable loading frame foreground masks.
|
93 |
+
load_point_clouds: Enable loading sequence-level point clouds.
|
94 |
+
max_points: Cap on the number of loaded points in the point cloud;
|
95 |
+
if reached, they are randomly sampled without replacement.
|
96 |
+
mask_images: Whether to mask the images with the loaded foreground masks;
|
97 |
+
0 value is used for background.
|
98 |
+
mask_depths: Whether to mask the depth maps with the loaded foreground
|
99 |
+
masks; 0 value is used for background.
|
100 |
+
image_height: The height of the returned images, masks, and depth maps;
|
101 |
+
aspect ratio is preserved during cropping/resizing.
|
102 |
+
image_width: The width of the returned images, masks, and depth maps;
|
103 |
+
aspect ratio is preserved during cropping/resizing.
|
104 |
+
box_crop: Enable cropping of the image around the bounding box inferred
|
105 |
+
from the foreground region of the loaded segmentation mask; masks
|
106 |
+
and depth maps are cropped accordingly; cameras are corrected.
|
107 |
+
box_crop_mask_thr: The threshold used to separate pixels into foreground
|
108 |
+
and background based on the foreground_probability mask; if no value
|
109 |
+
is greater than this threshold, the loader lowers it and repeats.
|
110 |
+
box_crop_context: The amount of additional padding added to each
|
111 |
+
dimension of the cropping bounding box, relative to box size.
|
112 |
+
remove_empty_masks: Removes the frames with no active foreground pixels
|
113 |
+
in the segmentation mask after thresholding (see box_crop_mask_thr).
|
114 |
+
n_frames_per_sequence: If > 0, randomly samples #n_frames_per_sequence
|
115 |
+
frames in each sequences uniformly without replacement if it has
|
116 |
+
more frames than that; applied before other frame-level filters.
|
117 |
+
seed: The seed of the random generator sampling #n_frames_per_sequence
|
118 |
+
random frames per sequence.
|
119 |
+
sort_frames: Enable frame annotations sorting to group frames from the
|
120 |
+
same sequences together and order them by timestamps
|
121 |
+
eval_batches: A list of batches that form the evaluation set;
|
122 |
+
list of batch-sized lists of indices corresponding to __getitem__
|
123 |
+
of this class, thus it can be used directly as a batch sampler.
|
124 |
+
eval_batch_index:
|
125 |
+
( Optional[List[List[Union[Tuple[str, int, str], Tuple[str, int]]]] )
|
126 |
+
A list of batches of frames described as (sequence_name, frame_idx)
|
127 |
+
that can form the evaluation set, `eval_batches` will be set from this.
|
128 |
+
|
129 |
+
"""
|
130 |
+
|
131 |
+
frame_annotations_type: ClassVar[
|
132 |
+
Type[types.FrameAnnotation]
|
133 |
+
] = types.FrameAnnotation
|
134 |
+
|
135 |
+
path_manager: Any = None
|
136 |
+
frame_annotations_file: str = ""
|
137 |
+
sequence_annotations_file: str = ""
|
138 |
+
subset_lists_file: str = ""
|
139 |
+
subsets: Optional[List[str]] = None
|
140 |
+
limit_to: int = 0
|
141 |
+
limit_sequences_to: int = 0
|
142 |
+
pick_sequence: Tuple[str, ...] = ()
|
143 |
+
exclude_sequence: Tuple[str, ...] = ()
|
144 |
+
limit_category_to: Tuple[int, ...] = ()
|
145 |
+
dataset_root: str = ""
|
146 |
+
load_images: bool = True
|
147 |
+
load_depths: bool = True
|
148 |
+
load_depth_masks: bool = True
|
149 |
+
load_masks: bool = True
|
150 |
+
load_point_clouds: bool = False
|
151 |
+
max_points: int = 0
|
152 |
+
mask_images: bool = False
|
153 |
+
mask_depths: bool = False
|
154 |
+
image_height: Optional[int] = 800
|
155 |
+
image_width: Optional[int] = 800
|
156 |
+
box_crop: bool = True
|
157 |
+
box_crop_mask_thr: float = 0.4
|
158 |
+
box_crop_context: float = 0.3
|
159 |
+
remove_empty_masks: bool = True
|
160 |
+
n_frames_per_sequence: int = -1
|
161 |
+
seed: int = 0
|
162 |
+
sort_frames: bool = False
|
163 |
+
eval_batches: Any = None
|
164 |
+
eval_batch_index: Any = None
|
165 |
+
# frame_annots: List[FrameAnnotsEntry] = field(init=False)
|
166 |
+
# seq_annots: Dict[str, types.SequenceAnnotation] = field(init=False)
|
167 |
+
|
168 |
+
def __post_init__(self) -> None:
|
169 |
+
# pyre-fixme[16]: `JsonIndexDataset` has no attribute `subset_to_image_path`.
|
170 |
+
self.subset_to_image_path = None
|
171 |
+
self._load_frames()
|
172 |
+
self._load_sequences()
|
173 |
+
if self.sort_frames:
|
174 |
+
self._sort_frames()
|
175 |
+
self._load_subset_lists()
|
176 |
+
self._filter_db() # also computes sequence indices
|
177 |
+
self._extract_and_set_eval_batches()
|
178 |
+
logger.info(str(self))
|
179 |
+
|
180 |
+
def _extract_and_set_eval_batches(self):
|
181 |
+
"""
|
182 |
+
Sets eval_batches based on input eval_batch_index.
|
183 |
+
"""
|
184 |
+
if self.eval_batch_index is not None:
|
185 |
+
if self.eval_batches is not None:
|
186 |
+
raise ValueError(
|
187 |
+
"Cannot define both eval_batch_index and eval_batches."
|
188 |
+
)
|
189 |
+
self.eval_batches = self.seq_frame_index_to_dataset_index(
|
190 |
+
self.eval_batch_index
|
191 |
+
)
|
192 |
+
|
193 |
+
def join(self, other_datasets: Iterable[DatasetBase]) -> None:
|
194 |
+
"""
|
195 |
+
Join the dataset with other JsonIndexDataset objects.
|
196 |
+
|
197 |
+
Args:
|
198 |
+
other_datasets: A list of JsonIndexDataset objects to be joined
|
199 |
+
into the current dataset.
|
200 |
+
"""
|
201 |
+
if not all(isinstance(d, JsonIndexDataset) for d in other_datasets):
|
202 |
+
raise ValueError("This function can only join a list of JsonIndexDataset")
|
203 |
+
# pyre-ignore[16]
|
204 |
+
self.frame_annots.extend([fa for d in other_datasets for fa in d.frame_annots])
|
205 |
+
# pyre-ignore[16]
|
206 |
+
self.seq_annots.update(
|
207 |
+
# https://gist.github.com/treyhunner/f35292e676efa0be1728
|
208 |
+
functools.reduce(
|
209 |
+
lambda a, b: {**a, **b},
|
210 |
+
[d.seq_annots for d in other_datasets], # pyre-ignore[16]
|
211 |
+
)
|
212 |
+
)
|
213 |
+
all_eval_batches = [
|
214 |
+
self.eval_batches,
|
215 |
+
# pyre-ignore
|
216 |
+
*[d.eval_batches for d in other_datasets],
|
217 |
+
]
|
218 |
+
if not (
|
219 |
+
all(ba is None for ba in all_eval_batches)
|
220 |
+
or all(ba is not None for ba in all_eval_batches)
|
221 |
+
):
|
222 |
+
raise ValueError(
|
223 |
+
"When joining datasets, either all joined datasets have to have their"
|
224 |
+
" eval_batches defined, or all should have their eval batches undefined."
|
225 |
+
)
|
226 |
+
if self.eval_batches is not None:
|
227 |
+
self.eval_batches = sum(all_eval_batches, [])
|
228 |
+
self._invalidate_indexes(filter_seq_annots=True)
|
229 |
+
|
230 |
+
def is_filtered(self) -> bool:
|
231 |
+
"""
|
232 |
+
Returns `True` in case the dataset has been filtered and thus some frame annotations
|
233 |
+
stored on the disk might be missing in the dataset object.
|
234 |
+
|
235 |
+
Returns:
|
236 |
+
is_filtered: `True` if the dataset has been filtered, else `False`.
|
237 |
+
"""
|
238 |
+
return (
|
239 |
+
self.remove_empty_masks
|
240 |
+
or self.limit_to > 0
|
241 |
+
or self.limit_sequences_to > 0
|
242 |
+
or len(self.pick_sequence) > 0
|
243 |
+
or len(self.exclude_sequence) > 0
|
244 |
+
or len(self.limit_category_to) > 0
|
245 |
+
or self.n_frames_per_sequence > 0
|
246 |
+
)
|
247 |
+
|
248 |
+
def seq_frame_index_to_dataset_index(
|
249 |
+
self,
|
250 |
+
seq_frame_index: List[List[Union[Tuple[str, int, str], Tuple[str, int]]]],
|
251 |
+
allow_missing_indices: bool = False,
|
252 |
+
remove_missing_indices: bool = False,
|
253 |
+
suppress_missing_index_warning: bool = True,
|
254 |
+
) -> List[List[Union[Optional[int], int]]]:
|
255 |
+
"""
|
256 |
+
Obtain indices into the dataset object given a list of frame ids.
|
257 |
+
|
258 |
+
Args:
|
259 |
+
seq_frame_index: The list of frame ids specified as
|
260 |
+
`List[List[Tuple[sequence_name:str, frame_number:int]]]`. Optionally,
|
261 |
+
Image paths relative to the dataset_root can be stored specified as well:
|
262 |
+
`List[List[Tuple[sequence_name:str, frame_number:int, image_path:str]]]`
|
263 |
+
allow_missing_indices: If `False`, throws an IndexError upon reaching the first
|
264 |
+
entry from `seq_frame_index` which is missing in the dataset.
|
265 |
+
Otherwise, depending on `remove_missing_indices`, either returns `None`
|
266 |
+
in place of missing entries or removes the indices of missing entries.
|
267 |
+
remove_missing_indices: Active when `allow_missing_indices=True`.
|
268 |
+
If `False`, returns `None` in place of `seq_frame_index` entries that
|
269 |
+
are not present in the dataset.
|
270 |
+
If `True` removes missing indices from the returned indices.
|
271 |
+
suppress_missing_index_warning:
|
272 |
+
Active if `allow_missing_indices==True`. Suppressess a warning message
|
273 |
+
in case an entry from `seq_frame_index` is missing in the dataset
|
274 |
+
(expected in certain cases - e.g. when setting
|
275 |
+
`self.remove_empty_masks=True`).
|
276 |
+
|
277 |
+
Returns:
|
278 |
+
dataset_idx: Indices of dataset entries corresponding to`seq_frame_index`.
|
279 |
+
"""
|
280 |
+
_dataset_seq_frame_n_index = {
|
281 |
+
seq: {
|
282 |
+
# pyre-ignore[16]
|
283 |
+
self.frame_annots[idx]["frame_annotation"].frame_number: idx
|
284 |
+
for idx in seq_idx
|
285 |
+
}
|
286 |
+
# pyre-ignore[16]
|
287 |
+
for seq, seq_idx in self._seq_to_idx.items()
|
288 |
+
}
|
289 |
+
|
290 |
+
def _get_dataset_idx(
|
291 |
+
seq_name: str, frame_no: int, path: Optional[str] = None
|
292 |
+
) -> Optional[int]:
|
293 |
+
idx_seq = _dataset_seq_frame_n_index.get(seq_name, None)
|
294 |
+
idx = idx_seq.get(frame_no, None) if idx_seq is not None else None
|
295 |
+
if idx is None:
|
296 |
+
msg = (
|
297 |
+
f"sequence_name={seq_name} / frame_number={frame_no}"
|
298 |
+
" not in the dataset!"
|
299 |
+
)
|
300 |
+
if not allow_missing_indices:
|
301 |
+
raise IndexError(msg)
|
302 |
+
if not suppress_missing_index_warning:
|
303 |
+
warnings.warn(msg)
|
304 |
+
return idx
|
305 |
+
if path is not None:
|
306 |
+
# Check that the loaded frame path is consistent
|
307 |
+
# with the one stored in self.frame_annots.
|
308 |
+
assert os.path.normpath(
|
309 |
+
# pyre-ignore[16]
|
310 |
+
self.frame_annots[idx]["frame_annotation"].image.path
|
311 |
+
) == os.path.normpath(
|
312 |
+
path
|
313 |
+
), f"Inconsistent frame indices {seq_name, frame_no, path}."
|
314 |
+
return idx
|
315 |
+
|
316 |
+
dataset_idx = [
|
317 |
+
[_get_dataset_idx(*b) for b in batch] # pyre-ignore [6]
|
318 |
+
for batch in seq_frame_index
|
319 |
+
]
|
320 |
+
|
321 |
+
if allow_missing_indices and remove_missing_indices:
|
322 |
+
# remove all None indices, and also batches with only None entries
|
323 |
+
valid_dataset_idx = [
|
324 |
+
[b for b in batch if b is not None] for batch in dataset_idx
|
325 |
+
]
|
326 |
+
return [ # pyre-ignore[7]
|
327 |
+
batch for batch in valid_dataset_idx if len(batch) > 0
|
328 |
+
]
|
329 |
+
|
330 |
+
return dataset_idx
|
331 |
+
|
332 |
+
def subset_from_frame_index(
|
333 |
+
self,
|
334 |
+
frame_index: List[Union[Tuple[str, int], Tuple[str, int, str]]],
|
335 |
+
allow_missing_indices: bool = True,
|
336 |
+
) -> "JsonIndexDataset":
|
337 |
+
"""
|
338 |
+
Generate a dataset subset given the list of frames specified in `frame_index`.
|
339 |
+
|
340 |
+
Args:
|
341 |
+
frame_index: The list of frame indentifiers (as stored in the metadata)
|
342 |
+
specified as `List[Tuple[sequence_name:str, frame_number:int]]`. Optionally,
|
343 |
+
Image paths relative to the dataset_root can be stored specified as well:
|
344 |
+
`List[Tuple[sequence_name:str, frame_number:int, image_path:str]]`,
|
345 |
+
in the latter case, if imaga_path do not match the stored paths, an error
|
346 |
+
is raised.
|
347 |
+
allow_missing_indices: If `False`, throws an IndexError upon reaching the first
|
348 |
+
entry from `frame_index` which is missing in the dataset.
|
349 |
+
Otherwise, generates a subset consisting of frames entries that actually
|
350 |
+
exist in the dataset.
|
351 |
+
"""
|
352 |
+
# Get the indices into the frame annots.
|
353 |
+
dataset_indices = self.seq_frame_index_to_dataset_index(
|
354 |
+
[frame_index],
|
355 |
+
allow_missing_indices=self.is_filtered() and allow_missing_indices,
|
356 |
+
)[0]
|
357 |
+
valid_dataset_indices = [i for i in dataset_indices if i is not None]
|
358 |
+
|
359 |
+
# Deep copy the whole dataset except frame_annots, which are large so we
|
360 |
+
# deep copy only the requested subset of frame_annots.
|
361 |
+
memo = {id(self.frame_annots): None} # pyre-ignore[16]
|
362 |
+
dataset_new = copy.deepcopy(self, memo)
|
363 |
+
dataset_new.frame_annots = copy.deepcopy(
|
364 |
+
[self.frame_annots[i] for i in valid_dataset_indices]
|
365 |
+
)
|
366 |
+
|
367 |
+
# This will kill all unneeded sequence annotations.
|
368 |
+
dataset_new._invalidate_indexes(filter_seq_annots=True)
|
369 |
+
|
370 |
+
# Finally annotate the frame annotations with the name of the subset
|
371 |
+
# stored in meta.
|
372 |
+
for frame_annot in dataset_new.frame_annots:
|
373 |
+
frame_annotation = frame_annot["frame_annotation"]
|
374 |
+
if frame_annotation.meta is not None:
|
375 |
+
frame_annot["subset"] = frame_annotation.meta.get("frame_type", None)
|
376 |
+
|
377 |
+
# A sanity check - this will crash in case some entries from frame_index are missing
|
378 |
+
# in dataset_new.
|
379 |
+
valid_frame_index = [
|
380 |
+
fi for fi, di in zip(frame_index, dataset_indices) if di is not None
|
381 |
+
]
|
382 |
+
dataset_new.seq_frame_index_to_dataset_index(
|
383 |
+
[valid_frame_index], allow_missing_indices=False
|
384 |
+
)
|
385 |
+
|
386 |
+
return dataset_new
|
387 |
+
|
388 |
+
def __str__(self) -> str:
|
389 |
+
# pyre-ignore[16]
|
390 |
+
return f"JsonIndexDataset #frames={len(self.frame_annots)}"
|
391 |
+
|
392 |
+
def __len__(self) -> int:
|
393 |
+
# pyre-ignore[16]
|
394 |
+
return len(self.frame_annots)
|
395 |
+
|
396 |
+
def _get_frame_type(self, entry: FrameAnnotsEntry) -> Optional[str]:
|
397 |
+
return entry["subset"]
|
398 |
+
|
399 |
+
def get_all_train_cameras(self) -> CamerasBase:
|
400 |
+
"""
|
401 |
+
Returns the cameras corresponding to all the known frames.
|
402 |
+
"""
|
403 |
+
logger.info("Loading all train cameras.")
|
404 |
+
cameras = []
|
405 |
+
# pyre-ignore[16]
|
406 |
+
for frame_idx, frame_annot in enumerate(tqdm(self.frame_annots)):
|
407 |
+
frame_type = self._get_frame_type(frame_annot)
|
408 |
+
if frame_type is None:
|
409 |
+
raise ValueError("subsets not loaded")
|
410 |
+
if is_known_frame_scalar(frame_type):
|
411 |
+
cameras.append(self[frame_idx].camera)
|
412 |
+
return join_cameras_as_batch(cameras)
|
413 |
+
|
414 |
+
def __getitem__(self, index) -> FrameData:
|
415 |
+
# pyre-ignore[16]
|
416 |
+
if index >= len(self.frame_annots):
|
417 |
+
raise IndexError(f"index {index} out of range {len(self.frame_annots)}")
|
418 |
+
|
419 |
+
entry = self.frame_annots[index]["frame_annotation"]
|
420 |
+
# pyre-ignore[16]
|
421 |
+
point_cloud = self.seq_annots[entry.sequence_name].point_cloud
|
422 |
+
frame_data = FrameData(
|
423 |
+
frame_number=_safe_as_tensor(entry.frame_number, torch.long),
|
424 |
+
frame_timestamp=_safe_as_tensor(entry.frame_timestamp, torch.float),
|
425 |
+
sequence_name=entry.sequence_name,
|
426 |
+
sequence_category=self.seq_annots[entry.sequence_name].category,
|
427 |
+
camera_quality_score=_safe_as_tensor(
|
428 |
+
self.seq_annots[entry.sequence_name].viewpoint_quality_score,
|
429 |
+
torch.float,
|
430 |
+
),
|
431 |
+
point_cloud_quality_score=_safe_as_tensor(
|
432 |
+
point_cloud.quality_score, torch.float
|
433 |
+
)
|
434 |
+
if point_cloud is not None
|
435 |
+
else None,
|
436 |
+
)
|
437 |
+
|
438 |
+
# The rest of the fields are optional
|
439 |
+
frame_data.frame_type = self._get_frame_type(self.frame_annots[index])
|
440 |
+
|
441 |
+
(
|
442 |
+
frame_data.fg_probability,
|
443 |
+
frame_data.mask_path,
|
444 |
+
frame_data.bbox_xywh,
|
445 |
+
clamp_bbox_xyxy,
|
446 |
+
frame_data.crop_bbox_xywh,
|
447 |
+
) = self._load_crop_fg_probability(entry)
|
448 |
+
|
449 |
+
scale = 1.0
|
450 |
+
if self.load_images and entry.image is not None:
|
451 |
+
# original image size
|
452 |
+
frame_data.image_size_hw = _safe_as_tensor(entry.image.size, torch.long)
|
453 |
+
|
454 |
+
(
|
455 |
+
frame_data.image_rgb,
|
456 |
+
frame_data.image_path,
|
457 |
+
frame_data.mask_crop,
|
458 |
+
scale,
|
459 |
+
) = self._load_crop_images(
|
460 |
+
entry, frame_data.fg_probability, clamp_bbox_xyxy
|
461 |
+
)
|
462 |
+
|
463 |
+
if self.load_depths and entry.depth is not None:
|
464 |
+
(
|
465 |
+
frame_data.depth_map,
|
466 |
+
frame_data.depth_path,
|
467 |
+
frame_data.depth_mask,
|
468 |
+
) = self._load_mask_depth(entry, clamp_bbox_xyxy, frame_data.fg_probability)
|
469 |
+
|
470 |
+
if entry.viewpoint is not None:
|
471 |
+
frame_data.camera = self._get_pytorch3d_camera(
|
472 |
+
entry,
|
473 |
+
scale,
|
474 |
+
clamp_bbox_xyxy,
|
475 |
+
)
|
476 |
+
|
477 |
+
if self.load_point_clouds and point_cloud is not None:
|
478 |
+
pcl_path = self._fix_point_cloud_path(point_cloud.path)
|
479 |
+
frame_data.sequence_point_cloud = _load_pointcloud(
|
480 |
+
self._local_path(pcl_path), max_points=self.max_points
|
481 |
+
)
|
482 |
+
frame_data.sequence_point_cloud_path = pcl_path
|
483 |
+
|
484 |
+
return frame_data
|
485 |
+
|
486 |
+
def _fix_point_cloud_path(self, path: str) -> str:
|
487 |
+
"""
|
488 |
+
Fix up a point cloud path from the dataset.
|
489 |
+
Some files in Co3Dv2 have an accidental absolute path stored.
|
490 |
+
"""
|
491 |
+
unwanted_prefix = (
|
492 |
+
"/large_experiments/p3/replay/datasets/co3d/co3d45k_220512/export_v23/"
|
493 |
+
)
|
494 |
+
if path.startswith(unwanted_prefix):
|
495 |
+
path = path[len(unwanted_prefix) :]
|
496 |
+
return os.path.join(self.dataset_root, path)
|
497 |
+
|
498 |
+
def _load_crop_fg_probability(
|
499 |
+
self, entry: types.FrameAnnotation
|
500 |
+
) -> Tuple[
|
501 |
+
Optional[torch.Tensor],
|
502 |
+
Optional[str],
|
503 |
+
Optional[torch.Tensor],
|
504 |
+
Optional[torch.Tensor],
|
505 |
+
Optional[torch.Tensor],
|
506 |
+
]:
|
507 |
+
fg_probability = None
|
508 |
+
full_path = None
|
509 |
+
bbox_xywh = None
|
510 |
+
clamp_bbox_xyxy = None
|
511 |
+
crop_box_xywh = None
|
512 |
+
|
513 |
+
if (self.load_masks or self.box_crop) and entry.mask is not None:
|
514 |
+
full_path = os.path.join(self.dataset_root, entry.mask.path)
|
515 |
+
mask = _load_mask(self._local_path(full_path))
|
516 |
+
|
517 |
+
if mask.shape[-2:] != entry.image.size:
|
518 |
+
raise ValueError(
|
519 |
+
f"bad mask size: {mask.shape[-2:]} vs {entry.image.size}!"
|
520 |
+
)
|
521 |
+
|
522 |
+
bbox_xywh = torch.tensor(_get_bbox_from_mask(mask, self.box_crop_mask_thr))
|
523 |
+
|
524 |
+
if self.box_crop:
|
525 |
+
clamp_bbox_xyxy = _clamp_box_to_image_bounds_and_round(
|
526 |
+
_get_clamp_bbox(
|
527 |
+
bbox_xywh,
|
528 |
+
image_path=entry.image.path,
|
529 |
+
box_crop_context=self.box_crop_context,
|
530 |
+
),
|
531 |
+
image_size_hw=tuple(mask.shape[-2:]),
|
532 |
+
)
|
533 |
+
crop_box_xywh = _bbox_xyxy_to_xywh(clamp_bbox_xyxy)
|
534 |
+
|
535 |
+
mask = _crop_around_box(mask, clamp_bbox_xyxy, full_path)
|
536 |
+
|
537 |
+
fg_probability, _, _ = self._resize_image(mask, mode="nearest")
|
538 |
+
|
539 |
+
return fg_probability, full_path, bbox_xywh, clamp_bbox_xyxy, crop_box_xywh
|
540 |
+
|
541 |
+
def _load_crop_images(
|
542 |
+
self,
|
543 |
+
entry: types.FrameAnnotation,
|
544 |
+
fg_probability: Optional[torch.Tensor],
|
545 |
+
clamp_bbox_xyxy: Optional[torch.Tensor],
|
546 |
+
) -> Tuple[torch.Tensor, str, torch.Tensor, float]:
|
547 |
+
assert self.dataset_root is not None and entry.image is not None
|
548 |
+
path = os.path.join(self.dataset_root, entry.image.path)
|
549 |
+
image_rgb = _load_image(self._local_path(path))
|
550 |
+
|
551 |
+
if image_rgb.shape[-2:] != entry.image.size:
|
552 |
+
raise ValueError(
|
553 |
+
f"bad image size: {image_rgb.shape[-2:]} vs {entry.image.size}!"
|
554 |
+
)
|
555 |
+
|
556 |
+
if self.box_crop:
|
557 |
+
assert clamp_bbox_xyxy is not None
|
558 |
+
image_rgb = _crop_around_box(image_rgb, clamp_bbox_xyxy, path)
|
559 |
+
|
560 |
+
image_rgb, scale, mask_crop = self._resize_image(image_rgb)
|
561 |
+
|
562 |
+
if self.mask_images:
|
563 |
+
assert fg_probability is not None
|
564 |
+
image_rgb *= fg_probability
|
565 |
+
|
566 |
+
return image_rgb, path, mask_crop, scale
|
567 |
+
|
568 |
+
def _load_mask_depth(
|
569 |
+
self,
|
570 |
+
entry: types.FrameAnnotation,
|
571 |
+
clamp_bbox_xyxy: Optional[torch.Tensor],
|
572 |
+
fg_probability: Optional[torch.Tensor],
|
573 |
+
) -> Tuple[torch.Tensor, str, torch.Tensor]:
|
574 |
+
entry_depth = entry.depth
|
575 |
+
assert entry_depth is not None
|
576 |
+
path = os.path.join(self.dataset_root, entry_depth.path)
|
577 |
+
depth_map = _load_depth(self._local_path(path), entry_depth.scale_adjustment)
|
578 |
+
|
579 |
+
if self.box_crop:
|
580 |
+
assert clamp_bbox_xyxy is not None
|
581 |
+
depth_bbox_xyxy = _rescale_bbox(
|
582 |
+
clamp_bbox_xyxy, entry.image.size, depth_map.shape[-2:]
|
583 |
+
)
|
584 |
+
depth_map = _crop_around_box(depth_map, depth_bbox_xyxy, path)
|
585 |
+
|
586 |
+
depth_map, _, _ = self._resize_image(depth_map, mode="nearest")
|
587 |
+
|
588 |
+
if self.mask_depths:
|
589 |
+
assert fg_probability is not None
|
590 |
+
depth_map *= fg_probability
|
591 |
+
|
592 |
+
if self.load_depth_masks:
|
593 |
+
assert entry_depth.mask_path is not None
|
594 |
+
mask_path = os.path.join(self.dataset_root, entry_depth.mask_path)
|
595 |
+
depth_mask = _load_depth_mask(self._local_path(mask_path))
|
596 |
+
|
597 |
+
if self.box_crop:
|
598 |
+
assert clamp_bbox_xyxy is not None
|
599 |
+
depth_mask_bbox_xyxy = _rescale_bbox(
|
600 |
+
clamp_bbox_xyxy, entry.image.size, depth_mask.shape[-2:]
|
601 |
+
)
|
602 |
+
depth_mask = _crop_around_box(
|
603 |
+
depth_mask, depth_mask_bbox_xyxy, mask_path
|
604 |
+
)
|
605 |
+
|
606 |
+
depth_mask, _, _ = self._resize_image(depth_mask, mode="nearest")
|
607 |
+
else:
|
608 |
+
depth_mask = torch.ones_like(depth_map)
|
609 |
+
|
610 |
+
return depth_map, path, depth_mask
|
611 |
+
|
612 |
+
def _get_pytorch3d_camera(
|
613 |
+
self,
|
614 |
+
entry: types.FrameAnnotation,
|
615 |
+
scale: float,
|
616 |
+
clamp_bbox_xyxy: Optional[torch.Tensor],
|
617 |
+
) -> PerspectiveCameras:
|
618 |
+
entry_viewpoint = entry.viewpoint
|
619 |
+
assert entry_viewpoint is not None
|
620 |
+
# principal point and focal length
|
621 |
+
principal_point = torch.tensor(
|
622 |
+
entry_viewpoint.principal_point, dtype=torch.float
|
623 |
+
)
|
624 |
+
focal_length = torch.tensor(entry_viewpoint.focal_length, dtype=torch.float)
|
625 |
+
|
626 |
+
half_image_size_wh_orig = (
|
627 |
+
torch.tensor(list(reversed(entry.image.size)), dtype=torch.float) / 2.0
|
628 |
+
)
|
629 |
+
|
630 |
+
# first, we convert from the dataset's NDC convention to pixels
|
631 |
+
format = entry_viewpoint.intrinsics_format
|
632 |
+
if format.lower() == "ndc_norm_image_bounds":
|
633 |
+
# this is e.g. currently used in CO3D for storing intrinsics
|
634 |
+
rescale = half_image_size_wh_orig
|
635 |
+
elif format.lower() == "ndc_isotropic":
|
636 |
+
rescale = half_image_size_wh_orig.min()
|
637 |
+
else:
|
638 |
+
raise ValueError(f"Unknown intrinsics format: {format}")
|
639 |
+
|
640 |
+
# principal point and focal length in pixels
|
641 |
+
principal_point_px = half_image_size_wh_orig - principal_point * rescale
|
642 |
+
focal_length_px = focal_length * rescale
|
643 |
+
if self.box_crop:
|
644 |
+
assert clamp_bbox_xyxy is not None
|
645 |
+
principal_point_px -= clamp_bbox_xyxy[:2]
|
646 |
+
|
647 |
+
# now, convert from pixels to PyTorch3D v0.5+ NDC convention
|
648 |
+
if self.image_height is None or self.image_width is None:
|
649 |
+
out_size = list(reversed(entry.image.size))
|
650 |
+
else:
|
651 |
+
out_size = [self.image_width, self.image_height]
|
652 |
+
|
653 |
+
half_image_size_output = torch.tensor(out_size, dtype=torch.float) / 2.0
|
654 |
+
half_min_image_size_output = half_image_size_output.min()
|
655 |
+
|
656 |
+
# rescaled principal point and focal length in ndc
|
657 |
+
principal_point = (
|
658 |
+
half_image_size_output - principal_point_px * scale
|
659 |
+
) / half_min_image_size_output
|
660 |
+
focal_length = focal_length_px * scale / half_min_image_size_output
|
661 |
+
|
662 |
+
return PerspectiveCameras(
|
663 |
+
focal_length=focal_length[None],
|
664 |
+
principal_point=principal_point[None],
|
665 |
+
R=torch.tensor(entry_viewpoint.R, dtype=torch.float)[None],
|
666 |
+
T=torch.tensor(entry_viewpoint.T, dtype=torch.float)[None],
|
667 |
+
)
|
668 |
+
|
669 |
+
def _load_frames(self) -> None:
|
670 |
+
logger.info(f"Loading Co3D frames from {self.frame_annotations_file}.")
|
671 |
+
local_file = self._local_path(self.frame_annotations_file)
|
672 |
+
with gzip.open(local_file, "rt", encoding="utf8") as zipfile:
|
673 |
+
frame_annots_list = types.load_dataclass(
|
674 |
+
zipfile, List[self.frame_annotations_type]
|
675 |
+
)
|
676 |
+
if not frame_annots_list:
|
677 |
+
raise ValueError("Empty dataset!")
|
678 |
+
# pyre-ignore[16]
|
679 |
+
self.frame_annots = [
|
680 |
+
FrameAnnotsEntry(frame_annotation=a, subset=None) for a in frame_annots_list
|
681 |
+
]
|
682 |
+
|
683 |
+
def _load_sequences(self) -> None:
|
684 |
+
logger.info(f"Loading Co3D sequences from {self.sequence_annotations_file}.")
|
685 |
+
local_file = self._local_path(self.sequence_annotations_file)
|
686 |
+
with gzip.open(local_file, "rt", encoding="utf8") as zipfile:
|
687 |
+
seq_annots = types.load_dataclass(zipfile, List[types.SequenceAnnotation])
|
688 |
+
if not seq_annots:
|
689 |
+
raise ValueError("Empty sequences file!")
|
690 |
+
# pyre-ignore[16]
|
691 |
+
self.seq_annots = {entry.sequence_name: entry for entry in seq_annots}
|
692 |
+
|
693 |
+
def _load_subset_lists(self) -> None:
|
694 |
+
logger.info(f"Loading Co3D subset lists from {self.subset_lists_file}.")
|
695 |
+
if not self.subset_lists_file:
|
696 |
+
return
|
697 |
+
|
698 |
+
with open(self._local_path(self.subset_lists_file), "r") as f:
|
699 |
+
subset_to_seq_frame = json.load(f)
|
700 |
+
|
701 |
+
frame_path_to_subset = {
|
702 |
+
path: subset
|
703 |
+
for subset, frames in subset_to_seq_frame.items()
|
704 |
+
for _, _, path in frames
|
705 |
+
}
|
706 |
+
# pyre-ignore[16]
|
707 |
+
for frame in self.frame_annots:
|
708 |
+
frame["subset"] = frame_path_to_subset.get(
|
709 |
+
frame["frame_annotation"].image.path, None
|
710 |
+
)
|
711 |
+
if frame["subset"] is None:
|
712 |
+
warnings.warn(
|
713 |
+
"Subset lists are given but don't include "
|
714 |
+
+ frame["frame_annotation"].image.path
|
715 |
+
)
|
716 |
+
|
717 |
+
def _sort_frames(self) -> None:
|
718 |
+
# Sort frames to have them grouped by sequence, ordered by timestamp
|
719 |
+
# pyre-ignore[16]
|
720 |
+
self.frame_annots = sorted(
|
721 |
+
self.frame_annots,
|
722 |
+
key=lambda f: (
|
723 |
+
f["frame_annotation"].sequence_name,
|
724 |
+
f["frame_annotation"].frame_timestamp or 0,
|
725 |
+
),
|
726 |
+
)
|
727 |
+
|
728 |
+
def _filter_db(self) -> None:
|
729 |
+
if self.remove_empty_masks:
|
730 |
+
logger.info("Removing images with empty masks.")
|
731 |
+
# pyre-ignore[16]
|
732 |
+
old_len = len(self.frame_annots)
|
733 |
+
|
734 |
+
msg = "remove_empty_masks needs every MaskAnnotation.mass to be set."
|
735 |
+
|
736 |
+
def positive_mass(frame_annot: types.FrameAnnotation) -> bool:
|
737 |
+
mask = frame_annot.mask
|
738 |
+
if mask is None:
|
739 |
+
return False
|
740 |
+
if mask.mass is None:
|
741 |
+
raise ValueError(msg)
|
742 |
+
return mask.mass > 1
|
743 |
+
|
744 |
+
self.frame_annots = [
|
745 |
+
frame
|
746 |
+
for frame in self.frame_annots
|
747 |
+
if positive_mass(frame["frame_annotation"])
|
748 |
+
]
|
749 |
+
logger.info("... filtered %d -> %d" % (old_len, len(self.frame_annots)))
|
750 |
+
|
751 |
+
# this has to be called after joining with categories!!
|
752 |
+
subsets = self.subsets
|
753 |
+
if subsets:
|
754 |
+
if not self.subset_lists_file:
|
755 |
+
raise ValueError(
|
756 |
+
"Subset filter is on but subset_lists_file was not given"
|
757 |
+
)
|
758 |
+
|
759 |
+
logger.info(f"Limiting Co3D dataset to the '{subsets}' subsets.")
|
760 |
+
|
761 |
+
# truncate the list of subsets to the valid one
|
762 |
+
self.frame_annots = [
|
763 |
+
entry for entry in self.frame_annots if entry["subset"] in subsets
|
764 |
+
]
|
765 |
+
if len(self.frame_annots) == 0:
|
766 |
+
raise ValueError(f"There are no frames in the '{subsets}' subsets!")
|
767 |
+
|
768 |
+
self._invalidate_indexes(filter_seq_annots=True)
|
769 |
+
|
770 |
+
if len(self.limit_category_to) > 0:
|
771 |
+
logger.info(f"Limiting dataset to categories: {self.limit_category_to}")
|
772 |
+
# pyre-ignore[16]
|
773 |
+
self.seq_annots = {
|
774 |
+
name: entry
|
775 |
+
for name, entry in self.seq_annots.items()
|
776 |
+
if entry.category in self.limit_category_to
|
777 |
+
}
|
778 |
+
|
779 |
+
# sequence filters
|
780 |
+
for prefix in ("pick", "exclude"):
|
781 |
+
orig_len = len(self.seq_annots)
|
782 |
+
attr = f"{prefix}_sequence"
|
783 |
+
arr = getattr(self, attr)
|
784 |
+
if len(arr) > 0:
|
785 |
+
logger.info(f"{attr}: {str(arr)}")
|
786 |
+
self.seq_annots = {
|
787 |
+
name: entry
|
788 |
+
for name, entry in self.seq_annots.items()
|
789 |
+
if (name in arr) == (prefix == "pick")
|
790 |
+
}
|
791 |
+
logger.info("... filtered %d -> %d" % (orig_len, len(self.seq_annots)))
|
792 |
+
|
793 |
+
if self.limit_sequences_to > 0:
|
794 |
+
self.seq_annots = dict(
|
795 |
+
islice(self.seq_annots.items(), self.limit_sequences_to)
|
796 |
+
)
|
797 |
+
|
798 |
+
# retain only frames from retained sequences
|
799 |
+
self.frame_annots = [
|
800 |
+
f
|
801 |
+
for f in self.frame_annots
|
802 |
+
if f["frame_annotation"].sequence_name in self.seq_annots
|
803 |
+
]
|
804 |
+
|
805 |
+
self._invalidate_indexes()
|
806 |
+
|
807 |
+
if self.n_frames_per_sequence > 0:
|
808 |
+
logger.info(f"Taking max {self.n_frames_per_sequence} per sequence.")
|
809 |
+
keep_idx = []
|
810 |
+
# pyre-ignore[16]
|
811 |
+
for seq, seq_indices in self._seq_to_idx.items():
|
812 |
+
# infer the seed from the sequence name, this is reproducible
|
813 |
+
# and makes the selection differ for different sequences
|
814 |
+
seed = _seq_name_to_seed(seq) + self.seed
|
815 |
+
seq_idx_shuffled = random.Random(seed).sample(
|
816 |
+
sorted(seq_indices), len(seq_indices)
|
817 |
+
)
|
818 |
+
keep_idx.extend(seq_idx_shuffled[: self.n_frames_per_sequence])
|
819 |
+
|
820 |
+
logger.info(
|
821 |
+
"... filtered %d -> %d" % (len(self.frame_annots), len(keep_idx))
|
822 |
+
)
|
823 |
+
self.frame_annots = [self.frame_annots[i] for i in keep_idx]
|
824 |
+
self._invalidate_indexes(filter_seq_annots=False)
|
825 |
+
# sequences are not decimated, so self.seq_annots is valid
|
826 |
+
|
827 |
+
if self.limit_to > 0 and self.limit_to < len(self.frame_annots):
|
828 |
+
logger.info(
|
829 |
+
"limit_to: filtered %d -> %d" % (len(self.frame_annots), self.limit_to)
|
830 |
+
)
|
831 |
+
self.frame_annots = self.frame_annots[: self.limit_to]
|
832 |
+
self._invalidate_indexes(filter_seq_annots=True)
|
833 |
+
|
834 |
+
def _invalidate_indexes(self, filter_seq_annots: bool = False) -> None:
|
835 |
+
# update _seq_to_idx and filter seq_meta according to frame_annots change
|
836 |
+
# if filter_seq_annots, also uldates seq_annots based on the changed _seq_to_idx
|
837 |
+
self._invalidate_seq_to_idx()
|
838 |
+
|
839 |
+
if filter_seq_annots:
|
840 |
+
# pyre-ignore[16]
|
841 |
+
self.seq_annots = {
|
842 |
+
k: v
|
843 |
+
for k, v in self.seq_annots.items()
|
844 |
+
# pyre-ignore[16]
|
845 |
+
if k in self._seq_to_idx
|
846 |
+
}
|
847 |
+
|
848 |
+
def _invalidate_seq_to_idx(self) -> None:
|
849 |
+
seq_to_idx = defaultdict(list)
|
850 |
+
# pyre-ignore[16]
|
851 |
+
for idx, entry in enumerate(self.frame_annots):
|
852 |
+
seq_to_idx[entry["frame_annotation"].sequence_name].append(idx)
|
853 |
+
# pyre-ignore[16]
|
854 |
+
self._seq_to_idx = seq_to_idx
|
855 |
+
|
856 |
+
def _resize_image(
|
857 |
+
self, image, mode="bilinear"
|
858 |
+
) -> Tuple[torch.Tensor, float, torch.Tensor]:
|
859 |
+
image_height, image_width = self.image_height, self.image_width
|
860 |
+
if image_height is None or image_width is None:
|
861 |
+
# skip the resizing
|
862 |
+
imre_ = torch.from_numpy(image)
|
863 |
+
return imre_, 1.0, torch.ones_like(imre_[:1])
|
864 |
+
# takes numpy array, returns pytorch tensor
|
865 |
+
minscale = min(
|
866 |
+
image_height / image.shape[-2],
|
867 |
+
image_width / image.shape[-1],
|
868 |
+
)
|
869 |
+
imre = torch.nn.functional.interpolate(
|
870 |
+
torch.from_numpy(image)[None],
|
871 |
+
scale_factor=minscale,
|
872 |
+
mode=mode,
|
873 |
+
align_corners=False if mode == "bilinear" else None,
|
874 |
+
recompute_scale_factor=True,
|
875 |
+
)[0]
|
876 |
+
# pyre-fixme[19]: Expected 1 positional argument.
|
877 |
+
imre_ = torch.zeros(image.shape[0], self.image_height, self.image_width)
|
878 |
+
imre_[:, 0 : imre.shape[1], 0 : imre.shape[2]] = imre
|
879 |
+
# pyre-fixme[6]: For 2nd param expected `int` but got `Optional[int]`.
|
880 |
+
# pyre-fixme[6]: For 3rd param expected `int` but got `Optional[int]`.
|
881 |
+
mask = torch.zeros(1, self.image_height, self.image_width)
|
882 |
+
mask[:, 0 : imre.shape[1], 0 : imre.shape[2]] = 1.0
|
883 |
+
return imre_, minscale, mask
|
884 |
+
|
885 |
+
def _local_path(self, path: str) -> str:
|
886 |
+
if self.path_manager is None:
|
887 |
+
return path
|
888 |
+
return self.path_manager.get_local_path(path)
|
889 |
+
|
890 |
+
def get_frame_numbers_and_timestamps(
|
891 |
+
self, idxs: Sequence[int]
|
892 |
+
) -> List[Tuple[int, float]]:
|
893 |
+
out: List[Tuple[int, float]] = []
|
894 |
+
for idx in idxs:
|
895 |
+
# pyre-ignore[16]
|
896 |
+
frame_annotation = self.frame_annots[idx]["frame_annotation"]
|
897 |
+
out.append(
|
898 |
+
(frame_annotation.frame_number, frame_annotation.frame_timestamp)
|
899 |
+
)
|
900 |
+
return out
|
901 |
+
|
902 |
+
def category_to_sequence_names(self) -> Dict[str, List[str]]:
|
903 |
+
c2seq = defaultdict(list)
|
904 |
+
# pyre-ignore
|
905 |
+
for sequence_name, sa in self.seq_annots.items():
|
906 |
+
c2seq[sa.category].append(sequence_name)
|
907 |
+
return dict(c2seq)
|
908 |
+
|
909 |
+
def get_eval_batches(self) -> Optional[List[List[int]]]:
|
910 |
+
return self.eval_batches
|
911 |
+
|
912 |
+
|
913 |
+
def _seq_name_to_seed(seq_name) -> int:
|
914 |
+
return int(hashlib.sha1(seq_name.encode("utf-8")).hexdigest(), 16)
|
915 |
+
|
916 |
+
|
917 |
+
def _load_image(path) -> np.ndarray:
|
918 |
+
with Image.open(path) as pil_im:
|
919 |
+
im = np.array(pil_im.convert("RGB"))
|
920 |
+
im = im.transpose((2, 0, 1))
|
921 |
+
im = im.astype(np.float32) / 255.0
|
922 |
+
return im
|
923 |
+
|
924 |
+
|
925 |
+
def _load_16big_png_depth(depth_png) -> np.ndarray:
|
926 |
+
with Image.open(depth_png) as depth_pil:
|
927 |
+
# the image is stored with 16-bit depth but PIL reads it as I (32 bit).
|
928 |
+
# we cast it to uint16, then reinterpret as float16, then cast to float32
|
929 |
+
depth = (
|
930 |
+
np.frombuffer(np.array(depth_pil, dtype=np.uint16), dtype=np.float16)
|
931 |
+
.astype(np.float32)
|
932 |
+
.reshape((depth_pil.size[1], depth_pil.size[0]))
|
933 |
+
)
|
934 |
+
return depth
|
935 |
+
|
936 |
+
|
937 |
+
def _load_1bit_png_mask(file: str) -> np.ndarray:
|
938 |
+
with Image.open(file) as pil_im:
|
939 |
+
mask = (np.array(pil_im.convert("L")) > 0.0).astype(np.float32)
|
940 |
+
return mask
|
941 |
+
|
942 |
+
|
943 |
+
def _load_depth_mask(path: str) -> np.ndarray:
|
944 |
+
if not path.lower().endswith(".png"):
|
945 |
+
raise ValueError('unsupported depth mask file name "%s"' % path)
|
946 |
+
m = _load_1bit_png_mask(path)
|
947 |
+
return m[None] # fake feature channel
|
948 |
+
|
949 |
+
|
950 |
+
def _load_depth(path, scale_adjustment) -> np.ndarray:
|
951 |
+
if not path.lower().endswith(".png"):
|
952 |
+
raise ValueError('unsupported depth file name "%s"' % path)
|
953 |
+
|
954 |
+
d = _load_16big_png_depth(path) * scale_adjustment
|
955 |
+
d[~np.isfinite(d)] = 0.0
|
956 |
+
return d[None] # fake feature channel
|
957 |
+
|
958 |
+
|
959 |
+
def _load_mask(path) -> np.ndarray:
|
960 |
+
with Image.open(path) as pil_im:
|
961 |
+
mask = np.array(pil_im)
|
962 |
+
mask = mask.astype(np.float32) / 255.0
|
963 |
+
return mask[None] # fake feature channel
|
964 |
+
|
965 |
+
|
966 |
+
def _get_1d_bounds(arr) -> Tuple[int, int]:
|
967 |
+
nz = np.flatnonzero(arr)
|
968 |
+
return nz[0], nz[-1] + 1
|
969 |
+
|
970 |
+
|
971 |
+
def _get_bbox_from_mask(
|
972 |
+
mask, thr, decrease_quant: float = 0.05
|
973 |
+
) -> Tuple[int, int, int, int]:
|
974 |
+
# bbox in xywh
|
975 |
+
masks_for_box = np.zeros_like(mask)
|
976 |
+
while masks_for_box.sum() <= 1.0:
|
977 |
+
masks_for_box = (mask > thr).astype(np.float32)
|
978 |
+
thr -= decrease_quant
|
979 |
+
if thr <= 0.0:
|
980 |
+
warnings.warn(f"Empty masks_for_bbox (thr={thr}) => using full image.")
|
981 |
+
|
982 |
+
x0, x1 = _get_1d_bounds(masks_for_box.sum(axis=-2))
|
983 |
+
y0, y1 = _get_1d_bounds(masks_for_box.sum(axis=-1))
|
984 |
+
|
985 |
+
return x0, y0, x1 - x0, y1 - y0
|
986 |
+
|
987 |
+
|
988 |
+
def _get_clamp_bbox(
|
989 |
+
bbox: torch.Tensor,
|
990 |
+
box_crop_context: float = 0.0,
|
991 |
+
image_path: str = "",
|
992 |
+
) -> torch.Tensor:
|
993 |
+
# box_crop_context: rate of expansion for bbox
|
994 |
+
# returns possibly expanded bbox xyxy as float
|
995 |
+
|
996 |
+
bbox = bbox.clone() # do not edit bbox in place
|
997 |
+
|
998 |
+
# increase box size
|
999 |
+
if box_crop_context > 0.0:
|
1000 |
+
c = box_crop_context
|
1001 |
+
bbox = bbox.float()
|
1002 |
+
bbox[0] -= bbox[2] * c / 2
|
1003 |
+
bbox[1] -= bbox[3] * c / 2
|
1004 |
+
bbox[2] += bbox[2] * c
|
1005 |
+
bbox[3] += bbox[3] * c
|
1006 |
+
|
1007 |
+
if (bbox[2:] <= 1.0).any():
|
1008 |
+
raise ValueError(
|
1009 |
+
f"squashed image {image_path}!! The bounding box contains no pixels."
|
1010 |
+
)
|
1011 |
+
|
1012 |
+
bbox[2:] = torch.clamp(bbox[2:], 2) # set min height, width to 2 along both axes
|
1013 |
+
bbox_xyxy = _bbox_xywh_to_xyxy(bbox, clamp_size=2)
|
1014 |
+
|
1015 |
+
return bbox_xyxy
|
1016 |
+
|
1017 |
+
|
1018 |
+
def _crop_around_box(tensor, bbox, impath: str = ""):
|
1019 |
+
# bbox is xyxy, where the upper bound is corrected with +1
|
1020 |
+
bbox = _clamp_box_to_image_bounds_and_round(
|
1021 |
+
bbox,
|
1022 |
+
image_size_hw=tensor.shape[-2:],
|
1023 |
+
)
|
1024 |
+
tensor = tensor[..., bbox[1] : bbox[3], bbox[0] : bbox[2]]
|
1025 |
+
assert all(c > 0 for c in tensor.shape), f"squashed image {impath}"
|
1026 |
+
return tensor
|
1027 |
+
|
1028 |
+
|
1029 |
+
def _clamp_box_to_image_bounds_and_round(
|
1030 |
+
bbox_xyxy: torch.Tensor,
|
1031 |
+
image_size_hw: Tuple[int, int],
|
1032 |
+
) -> torch.LongTensor:
|
1033 |
+
bbox_xyxy = bbox_xyxy.clone()
|
1034 |
+
bbox_xyxy[[0, 2]] = torch.clamp(bbox_xyxy[[0, 2]], 0, image_size_hw[-1])
|
1035 |
+
bbox_xyxy[[1, 3]] = torch.clamp(bbox_xyxy[[1, 3]], 0, image_size_hw[-2])
|
1036 |
+
if not isinstance(bbox_xyxy, torch.LongTensor):
|
1037 |
+
bbox_xyxy = bbox_xyxy.round().long()
|
1038 |
+
return bbox_xyxy # pyre-ignore [7]
|
1039 |
+
|
1040 |
+
|
1041 |
+
def _rescale_bbox(bbox: torch.Tensor, orig_res, new_res) -> torch.Tensor:
|
1042 |
+
assert bbox is not None
|
1043 |
+
assert np.prod(orig_res) > 1e-8
|
1044 |
+
# average ratio of dimensions
|
1045 |
+
rel_size = (new_res[0] / orig_res[0] + new_res[1] / orig_res[1]) / 2.0
|
1046 |
+
return bbox * rel_size
|
1047 |
+
|
1048 |
+
|
1049 |
+
def _bbox_xyxy_to_xywh(xyxy: torch.Tensor) -> torch.Tensor:
|
1050 |
+
wh = xyxy[2:] - xyxy[:2]
|
1051 |
+
xywh = torch.cat([xyxy[:2], wh])
|
1052 |
+
return xywh
|
1053 |
+
|
1054 |
+
|
1055 |
+
def _bbox_xywh_to_xyxy(
|
1056 |
+
xywh: torch.Tensor, clamp_size: Optional[int] = None
|
1057 |
+
) -> torch.Tensor:
|
1058 |
+
xyxy = xywh.clone()
|
1059 |
+
if clamp_size is not None:
|
1060 |
+
xyxy[2:] = torch.clamp(xyxy[2:], clamp_size)
|
1061 |
+
xyxy[2:] += xyxy[:2]
|
1062 |
+
return xyxy
|
1063 |
+
|
1064 |
+
|
1065 |
+
def _safe_as_tensor(data, dtype):
|
1066 |
+
if data is None:
|
1067 |
+
return None
|
1068 |
+
return torch.tensor(data, dtype=dtype)
|
1069 |
+
|
1070 |
+
|
1071 |
+
# NOTE this cache is per-worker; they are implemented as processes.
|
1072 |
+
# each batch is loaded and collated by a single worker;
|
1073 |
+
# since sequences tend to co-occur within batches, this is useful.
|
1074 |
+
@functools.lru_cache(maxsize=256)
|
1075 |
+
def _load_pointcloud(pcl_path: Union[str, Path], max_points: int = 0) -> Pointclouds:
|
1076 |
+
pcl = IO().load_pointcloud(pcl_path)
|
1077 |
+
if max_points > 0:
|
1078 |
+
pcl = pcl.subsample(max_points)
|
1079 |
+
|
1080 |
+
return pcl
|
sgm/data/latent_objaverse.py
ADDED
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
from pathlib import Path
|
3 |
+
from PIL import Image
|
4 |
+
import json
|
5 |
+
import torch
|
6 |
+
from torch.utils.data import Dataset, DataLoader, default_collate
|
7 |
+
from torchvision.transforms import ToTensor, Normalize, Compose, Resize
|
8 |
+
from pytorch_lightning import LightningDataModule
|
9 |
+
from einops import rearrange
|
10 |
+
|
11 |
+
|
12 |
+
class LatentObjaverseSpiral(Dataset):
|
13 |
+
def __init__(
|
14 |
+
self,
|
15 |
+
root_dir,
|
16 |
+
split="train",
|
17 |
+
transform=None,
|
18 |
+
random_front=False,
|
19 |
+
max_item=None,
|
20 |
+
cond_aug_mean=-3.0,
|
21 |
+
cond_aug_std=0.5,
|
22 |
+
condition_on_elevation=False,
|
23 |
+
**unused_kwargs,
|
24 |
+
):
|
25 |
+
print("Using LVIS subset with precomputed Latents")
|
26 |
+
self.root_dir = Path(root_dir)
|
27 |
+
self.split = split
|
28 |
+
self.random_front = random_front
|
29 |
+
self.transform = transform
|
30 |
+
|
31 |
+
self.latent_dir = Path("/mnt/vepfs/3Ddataset/render_results/latents512")
|
32 |
+
|
33 |
+
self.ids = json.load(open("./assets/lvis_uids.json", "r"))
|
34 |
+
self.n_views = 18
|
35 |
+
valid_ids = []
|
36 |
+
for idx in self.ids:
|
37 |
+
if (self.root_dir / idx).exists():
|
38 |
+
valid_ids.append(idx)
|
39 |
+
self.ids = valid_ids
|
40 |
+
print("=" * 30)
|
41 |
+
print("Number of valid ids: ", len(self.ids))
|
42 |
+
print("=" * 30)
|
43 |
+
|
44 |
+
self.cond_aug_mean = cond_aug_mean
|
45 |
+
self.cond_aug_std = cond_aug_std
|
46 |
+
self.condition_on_elevation = condition_on_elevation
|
47 |
+
|
48 |
+
if max_item is not None:
|
49 |
+
self.ids = self.ids[:max_item]
|
50 |
+
|
51 |
+
## debug
|
52 |
+
self.ids = self.ids * 10000
|
sgm/data/mnist.py
ADDED
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pytorch_lightning as pl
|
2 |
+
import torchvision
|
3 |
+
from torch.utils.data import DataLoader, Dataset
|
4 |
+
from torchvision import transforms
|
5 |
+
|
6 |
+
|
7 |
+
class MNISTDataDictWrapper(Dataset):
|
8 |
+
def __init__(self, dset):
|
9 |
+
super().__init__()
|
10 |
+
self.dset = dset
|
11 |
+
|
12 |
+
def __getitem__(self, i):
|
13 |
+
x, y = self.dset[i]
|
14 |
+
return {"jpg": x, "cls": y}
|
15 |
+
|
16 |
+
def __len__(self):
|
17 |
+
return len(self.dset)
|
18 |
+
|
19 |
+
|
20 |
+
class MNISTLoader(pl.LightningDataModule):
|
21 |
+
def __init__(self, batch_size, num_workers=0, prefetch_factor=2, shuffle=True):
|
22 |
+
super().__init__()
|
23 |
+
|
24 |
+
transform = transforms.Compose(
|
25 |
+
[transforms.ToTensor(), transforms.Lambda(lambda x: x * 2.0 - 1.0)]
|
26 |
+
)
|
27 |
+
|
28 |
+
self.batch_size = batch_size
|
29 |
+
self.num_workers = num_workers
|
30 |
+
self.prefetch_factor = prefetch_factor if num_workers > 0 else 0
|
31 |
+
self.shuffle = shuffle
|
32 |
+
self.train_dataset = MNISTDataDictWrapper(
|
33 |
+
torchvision.datasets.MNIST(
|
34 |
+
root=".data/", train=True, download=True, transform=transform
|
35 |
+
)
|
36 |
+
)
|
37 |
+
self.test_dataset = MNISTDataDictWrapper(
|
38 |
+
torchvision.datasets.MNIST(
|
39 |
+
root=".data/", train=False, download=True, transform=transform
|
40 |
+
)
|
41 |
+
)
|
42 |
+
|
43 |
+
def prepare_data(self):
|
44 |
+
pass
|
45 |
+
|
46 |
+
def train_dataloader(self):
|
47 |
+
return DataLoader(
|
48 |
+
self.train_dataset,
|
49 |
+
batch_size=self.batch_size,
|
50 |
+
shuffle=self.shuffle,
|
51 |
+
num_workers=self.num_workers,
|
52 |
+
prefetch_factor=self.prefetch_factor,
|
53 |
+
)
|
54 |
+
|
55 |
+
def test_dataloader(self):
|
56 |
+
return DataLoader(
|
57 |
+
self.test_dataset,
|
58 |
+
batch_size=self.batch_size,
|
59 |
+
shuffle=self.shuffle,
|
60 |
+
num_workers=self.num_workers,
|
61 |
+
prefetch_factor=self.prefetch_factor,
|
62 |
+
)
|
63 |
+
|
64 |
+
def val_dataloader(self):
|
65 |
+
return DataLoader(
|
66 |
+
self.test_dataset,
|
67 |
+
batch_size=self.batch_size,
|
68 |
+
shuffle=self.shuffle,
|
69 |
+
num_workers=self.num_workers,
|
70 |
+
prefetch_factor=self.prefetch_factor,
|
71 |
+
)
|
72 |
+
|
73 |
+
|
74 |
+
if __name__ == "__main__":
|
75 |
+
dset = MNISTDataDictWrapper(
|
76 |
+
torchvision.datasets.MNIST(
|
77 |
+
root=".data/",
|
78 |
+
train=False,
|
79 |
+
download=True,
|
80 |
+
transform=transforms.Compose(
|
81 |
+
[transforms.ToTensor(), transforms.Lambda(lambda x: x * 2.0 - 1.0)]
|
82 |
+
),
|
83 |
+
)
|
84 |
+
)
|
85 |
+
ex = dset[0]
|
sgm/data/mvimagenet.py
ADDED
@@ -0,0 +1,408 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
from torch.utils.data import Dataset, DataLoader, default_collate
|
4 |
+
from pathlib import Path
|
5 |
+
from PIL import Image
|
6 |
+
from scipy.spatial.transform import Rotation
|
7 |
+
import rembg
|
8 |
+
from rembg import remove, new_session
|
9 |
+
from einops import rearrange
|
10 |
+
|
11 |
+
from torchvision.transforms import ToTensor, Normalize, Compose, Resize
|
12 |
+
from torchvision.transforms.functional import to_tensor
|
13 |
+
from pytorch_lightning import LightningDataModule
|
14 |
+
|
15 |
+
from sgm.data.colmap import read_cameras_binary, read_images_binary
|
16 |
+
from sgm.data.objaverse import video_collate_fn, FLATTEN_FIELDS, flatten_for_video
|
17 |
+
|
18 |
+
|
19 |
+
def qvec2rotmat(qvec):
|
20 |
+
return np.array(
|
21 |
+
[
|
22 |
+
[
|
23 |
+
1 - 2 * qvec[2] ** 2 - 2 * qvec[3] ** 2,
|
24 |
+
2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3],
|
25 |
+
2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2],
|
26 |
+
],
|
27 |
+
[
|
28 |
+
2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3],
|
29 |
+
1 - 2 * qvec[1] ** 2 - 2 * qvec[3] ** 2,
|
30 |
+
2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1],
|
31 |
+
],
|
32 |
+
[
|
33 |
+
2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2],
|
34 |
+
2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1],
|
35 |
+
1 - 2 * qvec[1] ** 2 - 2 * qvec[2] ** 2,
|
36 |
+
],
|
37 |
+
]
|
38 |
+
)
|
39 |
+
|
40 |
+
|
41 |
+
def qt2c2w(q, t):
|
42 |
+
# NOTE: remember to convert to opengl coordinate system
|
43 |
+
# rot = Rotation.from_quat(q).as_matrix()
|
44 |
+
rot = qvec2rotmat(q)
|
45 |
+
c2w = np.eye(4)
|
46 |
+
c2w[:3, :3] = np.transpose(rot)
|
47 |
+
c2w[:3, 3] = -np.transpose(rot) @ t
|
48 |
+
c2w[..., 1:3] *= -1
|
49 |
+
return c2w
|
50 |
+
|
51 |
+
|
52 |
+
def random_crop():
|
53 |
+
pass
|
54 |
+
|
55 |
+
|
56 |
+
class MVImageNet(Dataset):
|
57 |
+
def __init__(
|
58 |
+
self,
|
59 |
+
root_dir,
|
60 |
+
split,
|
61 |
+
transform,
|
62 |
+
reso: int = 256,
|
63 |
+
mask_type: str = "random",
|
64 |
+
cond_aug_mean=-3.0,
|
65 |
+
cond_aug_std=0.5,
|
66 |
+
condition_on_elevation=False,
|
67 |
+
fps_id=0.0,
|
68 |
+
motion_bucket_id=300.0,
|
69 |
+
num_frames: int = 24,
|
70 |
+
use_mask: bool = True,
|
71 |
+
load_pixelnerf: bool = False,
|
72 |
+
scale_pose: bool = False,
|
73 |
+
max_n_cond: int = 1,
|
74 |
+
min_n_cond: int = 1,
|
75 |
+
cond_on_multi: bool = False,
|
76 |
+
) -> None:
|
77 |
+
super().__init__()
|
78 |
+
|
79 |
+
self.root_dir = Path(root_dir)
|
80 |
+
self.split = split
|
81 |
+
|
82 |
+
avails = self.root_dir.glob("*/*")
|
83 |
+
self.ids = list(
|
84 |
+
map(
|
85 |
+
lambda x: str(x.relative_to(self.root_dir)),
|
86 |
+
filter(lambda x: x.is_dir(), avails),
|
87 |
+
)
|
88 |
+
)
|
89 |
+
|
90 |
+
self.transform = transform
|
91 |
+
self.reso = reso
|
92 |
+
self.num_frames = num_frames
|
93 |
+
self.cond_aug_mean = cond_aug_mean
|
94 |
+
self.cond_aug_std = cond_aug_std
|
95 |
+
self.condition_on_elevation = condition_on_elevation
|
96 |
+
self.fps_id = fps_id
|
97 |
+
self.motion_bucket_id = motion_bucket_id
|
98 |
+
self.mask_type = mask_type
|
99 |
+
self.use_mask = use_mask
|
100 |
+
self.load_pixelnerf = load_pixelnerf
|
101 |
+
self.scale_pose = scale_pose
|
102 |
+
self.max_n_cond = max_n_cond
|
103 |
+
self.min_n_cond = min_n_cond
|
104 |
+
self.cond_on_multi = cond_on_multi
|
105 |
+
|
106 |
+
if self.cond_on_multi:
|
107 |
+
assert self.min_n_cond == self.max_n_cond
|
108 |
+
self.session = new_session()
|
109 |
+
|
110 |
+
def __getitem__(self, index: int):
|
111 |
+
# mvimgnet starts with idx==1
|
112 |
+
idx_list = np.arange(0, self.num_frames)
|
113 |
+
this_image_dir = self.root_dir / self.ids[index] / "images"
|
114 |
+
this_camera_dir = self.root_dir / self.ids[index] / "sparse/0"
|
115 |
+
|
116 |
+
# while not this_camera_dir.exists():
|
117 |
+
# index = (index + 1) % len(self.ids)
|
118 |
+
# this_image_dir = self.root_dir / self.ids[index] / "images"
|
119 |
+
# this_camera_dir = self.root_dir / self.ids[index] / "sparse/0"
|
120 |
+
if not this_camera_dir.exists():
|
121 |
+
index = 0
|
122 |
+
this_image_dir = self.root_dir / self.ids[index] / "images"
|
123 |
+
this_camera_dir = self.root_dir / self.ids[index] / "sparse/0"
|
124 |
+
|
125 |
+
this_images = read_images_binary(this_camera_dir / "images.bin")
|
126 |
+
# filenames = list(map(lambda x: f"{x:03d}", this_images.keys()))
|
127 |
+
filenames = list(this_images.keys())
|
128 |
+
|
129 |
+
if len(filenames) == 0:
|
130 |
+
index = 0
|
131 |
+
this_image_dir = self.root_dir / self.ids[index] / "images"
|
132 |
+
this_camera_dir = self.root_dir / self.ids[index] / "sparse/0"
|
133 |
+
this_images = read_images_binary(this_camera_dir / "images.bin")
|
134 |
+
# filenames = list(map(lambda x: f"{x:03d}", this_images.keys()))
|
135 |
+
filenames = list(this_images.keys())
|
136 |
+
|
137 |
+
filenames = list(
|
138 |
+
filter(lambda x: (this_image_dir / this_images[x].name).exists(), filenames)
|
139 |
+
)
|
140 |
+
|
141 |
+
filenames = sorted(filenames, key=lambda x: this_images[x].name)
|
142 |
+
|
143 |
+
# # debug
|
144 |
+
# names = []
|
145 |
+
# for v in filenames:
|
146 |
+
# names.append(this_images[v].name)
|
147 |
+
# breakpoint()
|
148 |
+
|
149 |
+
while len(filenames) < self.num_frames:
|
150 |
+
num_surpass = self.num_frames - len(filenames)
|
151 |
+
filenames += list(reversed(filenames[-num_surpass:]))
|
152 |
+
|
153 |
+
if len(filenames) < self.num_frames:
|
154 |
+
print(f"\n\n{self.ids[index]}\n\n")
|
155 |
+
|
156 |
+
frames = []
|
157 |
+
cameras = []
|
158 |
+
downsampled_rgb = []
|
159 |
+
for view_idx in idx_list:
|
160 |
+
this_id = filenames[view_idx]
|
161 |
+
frame = Image.open(this_image_dir / this_images[this_id].name)
|
162 |
+
w, h = frame.size
|
163 |
+
|
164 |
+
if self.mask_type == "random":
|
165 |
+
image_size = min(h, w)
|
166 |
+
left = np.random.randint(0, w - image_size + 1)
|
167 |
+
right = left + image_size
|
168 |
+
top = np.random.randint(0, h - image_size + 1)
|
169 |
+
bottom = top + image_size
|
170 |
+
## need to assign left, right, top, bottom, image_size
|
171 |
+
elif self.mask_type == "object":
|
172 |
+
pass
|
173 |
+
elif self.mask_type == "rembg":
|
174 |
+
image_size = min(h, w)
|
175 |
+
if (
|
176 |
+
cached := this_image_dir
|
177 |
+
/ f"{this_images[this_id].name[:-4]}_rembg.png"
|
178 |
+
).exists():
|
179 |
+
try:
|
180 |
+
mask = np.asarray(Image.open(cached, formats=["png"]))[..., 3]
|
181 |
+
except:
|
182 |
+
mask = remove(frame, session=self.session)
|
183 |
+
mask.save(cached)
|
184 |
+
mask = np.asarray(mask)[..., 3]
|
185 |
+
else:
|
186 |
+
mask = remove(frame, session=self.session)
|
187 |
+
mask.save(cached)
|
188 |
+
mask = np.asarray(mask)[..., 3]
|
189 |
+
# in h,w order
|
190 |
+
y, x = np.array(mask.nonzero())
|
191 |
+
bbox_cx = x.mean()
|
192 |
+
bbox_cy = y.mean()
|
193 |
+
|
194 |
+
if bbox_cy - image_size / 2 < 0:
|
195 |
+
top = 0
|
196 |
+
elif bbox_cy + image_size / 2 > h:
|
197 |
+
top = h - image_size
|
198 |
+
else:
|
199 |
+
top = int(bbox_cy - image_size / 2)
|
200 |
+
|
201 |
+
if bbox_cx - image_size / 2 < 0:
|
202 |
+
left = 0
|
203 |
+
elif bbox_cx + image_size / 2 > w:
|
204 |
+
left = w - image_size
|
205 |
+
else:
|
206 |
+
left = int(bbox_cx - image_size / 2)
|
207 |
+
|
208 |
+
# top = max(int(bbox_cy - image_size / 2), 0)
|
209 |
+
# left = max(int(bbox_cx - image_size / 2), 0)
|
210 |
+
bottom = top + image_size
|
211 |
+
right = left + image_size
|
212 |
+
else:
|
213 |
+
raise ValueError(f"Unknown mask type: {self.mask_type}")
|
214 |
+
|
215 |
+
frame = frame.crop((left, top, right, bottom))
|
216 |
+
frame = frame.resize((self.reso, self.reso))
|
217 |
+
frames.append(self.transform(frame))
|
218 |
+
|
219 |
+
if self.load_pixelnerf:
|
220 |
+
# extrinsics
|
221 |
+
extrinsics = this_images[this_id]
|
222 |
+
c2w = qt2c2w(extrinsics.qvec, extrinsics.tvec)
|
223 |
+
# intrinsics
|
224 |
+
intrinsics = read_cameras_binary(this_camera_dir / "cameras.bin")
|
225 |
+
assert len(intrinsics) == 1
|
226 |
+
intrinsics = intrinsics[1]
|
227 |
+
f, cx, cy, _ = intrinsics.params
|
228 |
+
f *= 1 / image_size
|
229 |
+
cx -= left
|
230 |
+
cy -= top
|
231 |
+
cx *= 1 / image_size
|
232 |
+
cy *= 1 / image_size # all are relative values
|
233 |
+
intrinsics = np.array([[f, 0, cx], [0, f, cy], [0, 0, 1]])
|
234 |
+
|
235 |
+
this_camera = np.zeros(25)
|
236 |
+
this_camera[:16] = c2w.reshape(-1)
|
237 |
+
this_camera[16:] = intrinsics.reshape(-1)
|
238 |
+
|
239 |
+
cameras.append(this_camera)
|
240 |
+
downsampled = frame.resize((self.reso // 8, self.reso // 8))
|
241 |
+
downsampled_rgb.append((self.transform(downsampled) + 1.0) * 0.5)
|
242 |
+
|
243 |
+
data = dict()
|
244 |
+
|
245 |
+
cond_aug = np.exp(
|
246 |
+
np.random.randn(1)[0] * self.cond_aug_std + self.cond_aug_mean
|
247 |
+
)
|
248 |
+
frames = torch.stack(frames)
|
249 |
+
cond = frames[0]
|
250 |
+
# setting all things in data
|
251 |
+
data["frames"] = frames
|
252 |
+
data["cond_frames_without_noise"] = cond
|
253 |
+
data["cond_aug"] = torch.as_tensor([cond_aug] * self.num_frames)
|
254 |
+
data["cond_frames"] = cond + cond_aug * torch.randn_like(cond)
|
255 |
+
data["fps_id"] = torch.as_tensor([self.fps_id] * self.num_frames)
|
256 |
+
data["motion_bucket_id"] = torch.as_tensor(
|
257 |
+
[self.motion_bucket_id] * self.num_frames
|
258 |
+
)
|
259 |
+
data["num_video_frames"] = self.num_frames
|
260 |
+
data["image_only_indicator"] = torch.as_tensor([0.0] * self.num_frames)
|
261 |
+
|
262 |
+
if self.load_pixelnerf:
|
263 |
+
# TODO: normalize camera poses
|
264 |
+
data["pixelnerf_input"] = dict()
|
265 |
+
data["pixelnerf_input"]["frames"] = frames
|
266 |
+
data["pixelnerf_input"]["rgb"] = torch.stack(downsampled_rgb)
|
267 |
+
|
268 |
+
cameras = torch.from_numpy(np.stack(cameras)).float()
|
269 |
+
if self.scale_pose:
|
270 |
+
c2ws = cameras[..., :16].reshape(-1, 4, 4)
|
271 |
+
center = c2ws[:, :3, 3].mean(0)
|
272 |
+
radius = (c2ws[:, :3, 3] - center).norm(dim=-1).max()
|
273 |
+
scale = 1.5 / radius
|
274 |
+
c2ws[..., :3, 3] = (c2ws[..., :3, 3] - center) * scale
|
275 |
+
cameras[..., :16] = c2ws.reshape(-1, 16)
|
276 |
+
|
277 |
+
# if self.max_n_cond > 1:
|
278 |
+
# # TODO implement this
|
279 |
+
# n_cond = np.random.randint(1, self.max_n_cond + 1)
|
280 |
+
# # debug
|
281 |
+
# source_index = [0]
|
282 |
+
# if n_cond > 1:
|
283 |
+
# source_index += np.random.choice(
|
284 |
+
# np.arange(1, self.num_frames),
|
285 |
+
# self.max_n_cond - 1,
|
286 |
+
# replace=False,
|
287 |
+
# ).tolist()
|
288 |
+
# data["pixelnerf_input"]["source_index"] = torch.as_tensor(
|
289 |
+
# source_index
|
290 |
+
# )
|
291 |
+
# data["pixelnerf_input"]["n_cond"] = n_cond
|
292 |
+
# data["pixelnerf_input"]["source_images"] = frames[source_index]
|
293 |
+
# data["pixelnerf_input"]["source_cameras"] = cameras[source_index]
|
294 |
+
|
295 |
+
data["pixelnerf_input"]["cameras"] = cameras
|
296 |
+
|
297 |
+
return data
|
298 |
+
|
299 |
+
def __len__(self):
|
300 |
+
return len(self.ids)
|
301 |
+
|
302 |
+
def collate_fn(self, batch):
|
303 |
+
# a hack to add source index and keep consistent within a batch
|
304 |
+
if self.max_n_cond > 1:
|
305 |
+
# TODO implement this
|
306 |
+
n_cond = np.random.randint(self.min_n_cond, self.max_n_cond + 1)
|
307 |
+
# debug
|
308 |
+
# source_index = [0]
|
309 |
+
if n_cond > 1:
|
310 |
+
for b in batch:
|
311 |
+
source_index = [0] + np.random.choice(
|
312 |
+
np.arange(1, self.num_frames),
|
313 |
+
self.max_n_cond - 1,
|
314 |
+
replace=False,
|
315 |
+
).tolist()
|
316 |
+
b["pixelnerf_input"]["source_index"] = torch.as_tensor(source_index)
|
317 |
+
b["pixelnerf_input"]["n_cond"] = n_cond
|
318 |
+
b["pixelnerf_input"]["source_images"] = b["frames"][source_index]
|
319 |
+
b["pixelnerf_input"]["source_cameras"] = b["pixelnerf_input"][
|
320 |
+
"cameras"
|
321 |
+
][source_index]
|
322 |
+
|
323 |
+
if self.cond_on_multi:
|
324 |
+
b["cond_frames_without_noise"] = b["frames"][source_index]
|
325 |
+
|
326 |
+
ret = video_collate_fn(batch)
|
327 |
+
|
328 |
+
if self.cond_on_multi:
|
329 |
+
ret["cond_frames_without_noise"] = rearrange(ret["cond_frames_without_noise"], "b t ... -> (b t) ...")
|
330 |
+
|
331 |
+
return ret
|
332 |
+
|
333 |
+
|
334 |
+
class MVImageNetFixedCond(MVImageNet):
|
335 |
+
def __init__(self, *args, **kwargs):
|
336 |
+
super().__init__(*args, **kwargs)
|
337 |
+
|
338 |
+
|
339 |
+
class MVImageNetDataset(LightningDataModule):
|
340 |
+
def __init__(
|
341 |
+
self,
|
342 |
+
root_dir,
|
343 |
+
batch_size=2,
|
344 |
+
shuffle=True,
|
345 |
+
num_workers=10,
|
346 |
+
prefetch_factor=2,
|
347 |
+
**kwargs,
|
348 |
+
):
|
349 |
+
super().__init__()
|
350 |
+
|
351 |
+
self.batch_size = batch_size
|
352 |
+
self.num_workers = num_workers
|
353 |
+
self.prefetch_factor = prefetch_factor
|
354 |
+
self.shuffle = shuffle
|
355 |
+
|
356 |
+
self.transform = Compose(
|
357 |
+
[
|
358 |
+
ToTensor(),
|
359 |
+
Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
|
360 |
+
]
|
361 |
+
)
|
362 |
+
|
363 |
+
self.train_dataset = MVImageNet(
|
364 |
+
root_dir=root_dir,
|
365 |
+
split="train",
|
366 |
+
transform=self.transform,
|
367 |
+
**kwargs,
|
368 |
+
)
|
369 |
+
|
370 |
+
self.test_dataset = MVImageNet(
|
371 |
+
root_dir=root_dir,
|
372 |
+
split="test",
|
373 |
+
transform=self.transform,
|
374 |
+
**kwargs,
|
375 |
+
)
|
376 |
+
|
377 |
+
def train_dataloader(self):
|
378 |
+
def worker_init_fn(worker_id):
|
379 |
+
np.random.seed(np.random.get_state()[1][0])
|
380 |
+
|
381 |
+
return DataLoader(
|
382 |
+
self.train_dataset,
|
383 |
+
batch_size=self.batch_size,
|
384 |
+
shuffle=self.shuffle,
|
385 |
+
num_workers=self.num_workers,
|
386 |
+
prefetch_factor=self.prefetch_factor,
|
387 |
+
collate_fn=self.train_dataset.collate_fn,
|
388 |
+
)
|
389 |
+
|
390 |
+
def test_dataloader(self):
|
391 |
+
return DataLoader(
|
392 |
+
self.test_dataset,
|
393 |
+
batch_size=self.batch_size,
|
394 |
+
shuffle=self.shuffle,
|
395 |
+
num_workers=self.num_workers,
|
396 |
+
prefetch_factor=self.prefetch_factor,
|
397 |
+
collate_fn=self.test_dataset.collate_fn,
|
398 |
+
)
|
399 |
+
|
400 |
+
def val_dataloader(self):
|
401 |
+
return DataLoader(
|
402 |
+
self.test_dataset,
|
403 |
+
batch_size=self.batch_size,
|
404 |
+
shuffle=self.shuffle,
|
405 |
+
num_workers=self.num_workers,
|
406 |
+
prefetch_factor=self.prefetch_factor,
|
407 |
+
collate_fn=video_collate_fn,
|
408 |
+
)
|
sgm/data/objaverse.py
ADDED
@@ -0,0 +1,882 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
from pathlib import Path
|
3 |
+
from PIL import Image
|
4 |
+
import json
|
5 |
+
import torch
|
6 |
+
import torch.nn.functional as F
|
7 |
+
from torch.utils.data import Dataset, DataLoader, default_collate
|
8 |
+
from torchvision.transforms import ToTensor, Normalize, Compose, Resize
|
9 |
+
from torchvision.transforms.functional import to_tensor
|
10 |
+
from pytorch_lightning import LightningDataModule
|
11 |
+
from einops import rearrange
|
12 |
+
|
13 |
+
|
14 |
+
def read_camera_matrix_single(json_file):
|
15 |
+
# for gobjaverse
|
16 |
+
with open(json_file, "r", encoding="utf8") as reader:
|
17 |
+
json_content = json.load(reader)
|
18 |
+
|
19 |
+
# negative sign for opencv to opengl
|
20 |
+
camera_matrix = torch.zeros(3, 4)
|
21 |
+
camera_matrix[:3, 0] = torch.tensor(json_content["x"])
|
22 |
+
camera_matrix[:3, 1] = -torch.tensor(json_content["y"])
|
23 |
+
camera_matrix[:3, 2] = -torch.tensor(json_content["z"])
|
24 |
+
camera_matrix[:3, 3] = torch.tensor(json_content["origin"])
|
25 |
+
"""
|
26 |
+
camera_matrix = np.eye(4)
|
27 |
+
camera_matrix[:3, 0] = np.array(json_content['x'])
|
28 |
+
camera_matrix[:3, 1] = np.array(json_content['y'])
|
29 |
+
camera_matrix[:3, 2] = np.array(json_content['z'])
|
30 |
+
camera_matrix[:3, 3] = np.array(json_content['origin'])
|
31 |
+
# print(camera_matrix)
|
32 |
+
"""
|
33 |
+
|
34 |
+
return camera_matrix
|
35 |
+
|
36 |
+
|
37 |
+
def read_camera_instrinsics_single(json_file, h: int, w: int, scale: float = 1.0):
|
38 |
+
with open(json_file, "r", encoding="utf8") as reader:
|
39 |
+
json_content = json.load(reader)
|
40 |
+
|
41 |
+
h = int(h * scale)
|
42 |
+
w = int(w * scale)
|
43 |
+
|
44 |
+
y_fov = json_content["y_fov"]
|
45 |
+
x_fov = json_content["x_fov"]
|
46 |
+
|
47 |
+
fy = h / 2 / np.tan(y_fov / 2)
|
48 |
+
fx = w / 2 / np.tan(x_fov / 2)
|
49 |
+
|
50 |
+
cx = w // 2
|
51 |
+
cy = h // 2
|
52 |
+
|
53 |
+
intrinsics = torch.tensor(
|
54 |
+
[
|
55 |
+
[fx, fy],
|
56 |
+
[cx, cy],
|
57 |
+
[w, h],
|
58 |
+
],
|
59 |
+
dtype=torch.float32,
|
60 |
+
)
|
61 |
+
return intrinsics
|
62 |
+
|
63 |
+
|
64 |
+
def compose_extrinsic_RT(RT: torch.Tensor):
|
65 |
+
"""
|
66 |
+
Compose the standard form extrinsic matrix from RT.
|
67 |
+
Batched I/O.
|
68 |
+
"""
|
69 |
+
return torch.cat(
|
70 |
+
[
|
71 |
+
RT,
|
72 |
+
torch.tensor([[[0, 0, 0, 1]]], dtype=torch.float32).repeat(
|
73 |
+
RT.shape[0], 1, 1
|
74 |
+
),
|
75 |
+
],
|
76 |
+
dim=1,
|
77 |
+
)
|
78 |
+
|
79 |
+
|
80 |
+
def get_normalized_camera_intrinsics(intrinsics: torch.Tensor):
|
81 |
+
"""
|
82 |
+
intrinsics: (N, 3, 2), [[fx, fy], [cx, cy], [width, height]]
|
83 |
+
Return batched fx, fy, cx, cy
|
84 |
+
"""
|
85 |
+
fx, fy = intrinsics[:, 0, 0], intrinsics[:, 0, 1]
|
86 |
+
cx, cy = intrinsics[:, 1, 0], intrinsics[:, 1, 1]
|
87 |
+
width, height = intrinsics[:, 2, 0], intrinsics[:, 2, 1]
|
88 |
+
fx, fy = fx / width, fy / height
|
89 |
+
cx, cy = cx / width, cy / height
|
90 |
+
return fx, fy, cx, cy
|
91 |
+
|
92 |
+
|
93 |
+
def build_camera_standard(RT: torch.Tensor, intrinsics: torch.Tensor):
|
94 |
+
"""
|
95 |
+
RT: (N, 3, 4)
|
96 |
+
intrinsics: (N, 3, 2), [[fx, fy], [cx, cy], [width, height]]
|
97 |
+
"""
|
98 |
+
E = compose_extrinsic_RT(RT)
|
99 |
+
fx, fy, cx, cy = get_normalized_camera_intrinsics(intrinsics)
|
100 |
+
I = torch.stack(
|
101 |
+
[
|
102 |
+
torch.stack([fx, torch.zeros_like(fx), cx], dim=-1),
|
103 |
+
torch.stack([torch.zeros_like(fy), fy, cy], dim=-1),
|
104 |
+
torch.tensor([[0, 0, 1]], dtype=torch.float32).repeat(RT.shape[0], 1),
|
105 |
+
],
|
106 |
+
dim=1,
|
107 |
+
)
|
108 |
+
return torch.cat(
|
109 |
+
[
|
110 |
+
E.reshape(-1, 16),
|
111 |
+
I.reshape(-1, 9),
|
112 |
+
],
|
113 |
+
dim=-1,
|
114 |
+
)
|
115 |
+
|
116 |
+
|
117 |
+
def calc_elevation(c2w):
|
118 |
+
## works for single or batched c2w
|
119 |
+
## assume world up is (0, 0, 1)
|
120 |
+
pos = c2w[..., :3, 3]
|
121 |
+
|
122 |
+
return np.arcsin(pos[..., 2] / np.linalg.norm(pos, axis=-1, keepdims=False))
|
123 |
+
|
124 |
+
|
125 |
+
def read_camera_matrix_single(json_file):
|
126 |
+
with open(json_file, "r", encoding="utf8") as reader:
|
127 |
+
json_content = json.load(reader)
|
128 |
+
|
129 |
+
# negative sign for opencv to opengl
|
130 |
+
# camera_matrix = np.zeros([3, 4])
|
131 |
+
# camera_matrix[:3, 0] = np.array(json_content["x"])
|
132 |
+
# camera_matrix[:3, 1] = -np.array(json_content["y"])
|
133 |
+
# camera_matrix[:3, 2] = -np.array(json_content["z"])
|
134 |
+
# camera_matrix[:3, 3] = np.array(json_content["origin"])
|
135 |
+
camera_matrix = torch.zeros([3, 4])
|
136 |
+
camera_matrix[:3, 0] = torch.tensor(json_content["x"])
|
137 |
+
camera_matrix[:3, 1] = -torch.tensor(json_content["y"])
|
138 |
+
camera_matrix[:3, 2] = -torch.tensor(json_content["z"])
|
139 |
+
camera_matrix[:3, 3] = torch.tensor(json_content["origin"])
|
140 |
+
"""
|
141 |
+
camera_matrix = np.eye(4)
|
142 |
+
camera_matrix[:3, 0] = np.array(json_content['x'])
|
143 |
+
camera_matrix[:3, 1] = np.array(json_content['y'])
|
144 |
+
camera_matrix[:3, 2] = np.array(json_content['z'])
|
145 |
+
camera_matrix[:3, 3] = np.array(json_content['origin'])
|
146 |
+
# print(camera_matrix)
|
147 |
+
"""
|
148 |
+
|
149 |
+
return camera_matrix
|
150 |
+
|
151 |
+
|
152 |
+
def blend_white_bg(image):
|
153 |
+
new_image = Image.new("RGB", image.size, (255, 255, 255))
|
154 |
+
new_image.paste(image, mask=image.split()[3])
|
155 |
+
|
156 |
+
return new_image
|
157 |
+
|
158 |
+
|
159 |
+
def flatten_for_video(input):
|
160 |
+
return input.flatten()
|
161 |
+
|
162 |
+
|
163 |
+
FLATTEN_FIELDS = ["fps_id", "motion_bucket_id", "cond_aug", "elevation"]
|
164 |
+
|
165 |
+
|
166 |
+
def video_collate_fn(batch: list[dict], *args, **kwargs):
|
167 |
+
out = {}
|
168 |
+
for key in batch[0].keys():
|
169 |
+
if key in FLATTEN_FIELDS:
|
170 |
+
out[key] = default_collate([item[key] for item in batch])
|
171 |
+
out[key] = flatten_for_video(out[key])
|
172 |
+
elif key == "num_video_frames":
|
173 |
+
out[key] = batch[0][key]
|
174 |
+
elif key in ["frames", "latents", "rgb"]:
|
175 |
+
out[key] = default_collate([item[key] for item in batch])
|
176 |
+
out[key] = rearrange(out[key], "b t c h w -> (b t) c h w")
|
177 |
+
else:
|
178 |
+
out[key] = default_collate([item[key] for item in batch])
|
179 |
+
|
180 |
+
if "pixelnerf_input" in out:
|
181 |
+
out["pixelnerf_input"]["rgb"] = rearrange(
|
182 |
+
out["pixelnerf_input"]["rgb"], "b t c h w -> (b t) c h w"
|
183 |
+
)
|
184 |
+
|
185 |
+
return out
|
186 |
+
|
187 |
+
|
188 |
+
class GObjaverse(Dataset):
|
189 |
+
def __init__(
|
190 |
+
self,
|
191 |
+
root_dir,
|
192 |
+
split="train",
|
193 |
+
transform=None,
|
194 |
+
random_front=False,
|
195 |
+
max_item=None,
|
196 |
+
cond_aug_mean=-3.0,
|
197 |
+
cond_aug_std=0.5,
|
198 |
+
condition_on_elevation=False,
|
199 |
+
fps_id=0.0,
|
200 |
+
motion_bucket_id=300.0,
|
201 |
+
use_latents=False,
|
202 |
+
load_caps=False,
|
203 |
+
front_view_selection="random",
|
204 |
+
load_pixelnerf=False,
|
205 |
+
debug_base_idx=None,
|
206 |
+
scale_pose: bool = False,
|
207 |
+
max_n_cond: int = 1,
|
208 |
+
**unused_kwargs,
|
209 |
+
):
|
210 |
+
self.root_dir = Path(root_dir)
|
211 |
+
self.split = split
|
212 |
+
self.random_front = random_front
|
213 |
+
self.transform = transform
|
214 |
+
self.use_latents = use_latents
|
215 |
+
|
216 |
+
self.ids = json.load(open(self.root_dir / "valid_uids.json", "r"))
|
217 |
+
self.n_views = 24
|
218 |
+
|
219 |
+
self.load_caps = load_caps
|
220 |
+
if self.load_caps:
|
221 |
+
self.caps = json.load(open(self.root_dir / "text_captions_cap3d.json", "r"))
|
222 |
+
|
223 |
+
self.cond_aug_mean = cond_aug_mean
|
224 |
+
self.cond_aug_std = cond_aug_std
|
225 |
+
self.condition_on_elevation = condition_on_elevation
|
226 |
+
self.fps_id = fps_id
|
227 |
+
self.motion_bucket_id = motion_bucket_id
|
228 |
+
self.load_pixelnerf = load_pixelnerf
|
229 |
+
self.scale_pose = scale_pose
|
230 |
+
self.max_n_cond = max_n_cond
|
231 |
+
|
232 |
+
if self.use_latents:
|
233 |
+
self.latents_dir = self.root_dir / "latents256"
|
234 |
+
self.clip_dir = self.root_dir / "clip_emb256"
|
235 |
+
|
236 |
+
self.front_view_selection = front_view_selection
|
237 |
+
if self.front_view_selection == "random":
|
238 |
+
pass
|
239 |
+
elif self.front_view_selection == "fixed":
|
240 |
+
pass
|
241 |
+
elif self.front_view_selection.startswith("clip_score"):
|
242 |
+
self.clip_scores = torch.load(self.root_dir / "clip_score_per_view.pt")
|
243 |
+
self.ids = list(self.clip_scores.keys())
|
244 |
+
else:
|
245 |
+
raise ValueError(
|
246 |
+
f"Unknown front view selection method {self.front_view_selection}"
|
247 |
+
)
|
248 |
+
|
249 |
+
if max_item is not None:
|
250 |
+
self.ids = self.ids[:max_item]
|
251 |
+
## debug
|
252 |
+
self.ids = self.ids * 10000
|
253 |
+
|
254 |
+
if debug_base_idx is not None:
|
255 |
+
print(f"debug mode with base idx: {debug_base_idx}")
|
256 |
+
self.debug_base_idx = debug_base_idx
|
257 |
+
|
258 |
+
def __getitem__(self, idx: int):
|
259 |
+
if hasattr(self, "debug_base_idx"):
|
260 |
+
idx = (idx + self.debug_base_idx) % len(self.ids)
|
261 |
+
data = {}
|
262 |
+
idx_list = np.arange(self.n_views)
|
263 |
+
# if self.random_front:
|
264 |
+
# roll_idx = np.random.randint(self.n_views)
|
265 |
+
# idx_list = np.roll(idx_list, roll_idx)
|
266 |
+
if self.front_view_selection == "random":
|
267 |
+
roll_idx = np.random.randint(self.n_views)
|
268 |
+
idx_list = np.roll(idx_list, roll_idx)
|
269 |
+
elif self.front_view_selection == "fixed":
|
270 |
+
pass
|
271 |
+
elif self.front_view_selection == "clip_score_softmax":
|
272 |
+
this_clip_score = (
|
273 |
+
F.softmax(self.clip_scores[self.ids[idx]], dim=-1).cpu().numpy()
|
274 |
+
)
|
275 |
+
roll_idx = np.random.choice(idx_list, p=this_clip_score)
|
276 |
+
idx_list = np.roll(idx_list, roll_idx)
|
277 |
+
elif self.front_view_selection == "clip_score_max":
|
278 |
+
this_clip_score = (
|
279 |
+
F.softmax(self.clip_scores[self.ids[idx]], dim=-1).cpu().numpy()
|
280 |
+
)
|
281 |
+
roll_idx = np.argmax(this_clip_score)
|
282 |
+
idx_list = np.roll(idx_list, roll_idx)
|
283 |
+
frames = []
|
284 |
+
if not self.use_latents:
|
285 |
+
try:
|
286 |
+
for view_idx in idx_list:
|
287 |
+
frame = Image.open(
|
288 |
+
self.root_dir
|
289 |
+
/ "gobjaverse"
|
290 |
+
/ self.ids[idx]
|
291 |
+
/ f"{view_idx:05d}/{view_idx:05d}.png"
|
292 |
+
)
|
293 |
+
frames.append(self.transform(frame))
|
294 |
+
except:
|
295 |
+
idx = 0
|
296 |
+
frames = []
|
297 |
+
for view_idx in idx_list:
|
298 |
+
frame = Image.open(
|
299 |
+
self.root_dir
|
300 |
+
/ "gobjaverse"
|
301 |
+
/ self.ids[idx]
|
302 |
+
/ f"{view_idx:05d}/{view_idx:05d}.png"
|
303 |
+
)
|
304 |
+
frames.append(self.transform(frame))
|
305 |
+
# a workaround for some bugs in gobjaverse
|
306 |
+
# use idx=0 and the repeat will be resolved when gathering results, valid number of items can be checked by the len of results
|
307 |
+
frames = torch.stack(frames, dim=0)
|
308 |
+
cond = frames[0]
|
309 |
+
|
310 |
+
cond_aug = np.exp(
|
311 |
+
np.random.randn(1)[0] * self.cond_aug_std + self.cond_aug_mean
|
312 |
+
)
|
313 |
+
|
314 |
+
data.update(
|
315 |
+
{
|
316 |
+
"frames": frames,
|
317 |
+
"cond_frames_without_noise": cond,
|
318 |
+
"cond_aug": torch.as_tensor([cond_aug] * self.n_views),
|
319 |
+
"cond_frames": cond + cond_aug * torch.randn_like(cond),
|
320 |
+
"fps_id": torch.as_tensor([self.fps_id] * self.n_views),
|
321 |
+
"motion_bucket_id": torch.as_tensor(
|
322 |
+
[self.motion_bucket_id] * self.n_views
|
323 |
+
),
|
324 |
+
"num_video_frames": 24,
|
325 |
+
"image_only_indicator": torch.as_tensor([0.0] * self.n_views),
|
326 |
+
}
|
327 |
+
)
|
328 |
+
else:
|
329 |
+
latents = torch.load(self.latents_dir / f"{self.ids[idx]}.pt")[idx_list]
|
330 |
+
clip_emb = torch.load(self.clip_dir / f"{self.ids[idx]}.pt")[idx_list][0]
|
331 |
+
|
332 |
+
cond = latents[0]
|
333 |
+
|
334 |
+
cond_aug = np.exp(
|
335 |
+
np.random.randn(1)[0] * self.cond_aug_std + self.cond_aug_mean
|
336 |
+
)
|
337 |
+
|
338 |
+
data.update(
|
339 |
+
{
|
340 |
+
"latents": latents,
|
341 |
+
"cond_frames_without_noise": clip_emb,
|
342 |
+
"cond_aug": torch.as_tensor([cond_aug] * self.n_views),
|
343 |
+
"cond_frames": cond + cond_aug * torch.randn_like(cond),
|
344 |
+
"fps_id": torch.as_tensor([self.fps_id] * self.n_views),
|
345 |
+
"motion_bucket_id": torch.as_tensor(
|
346 |
+
[self.motion_bucket_id] * self.n_views
|
347 |
+
),
|
348 |
+
"num_video_frames": 24,
|
349 |
+
"image_only_indicator": torch.as_tensor([0.0] * self.n_views),
|
350 |
+
}
|
351 |
+
)
|
352 |
+
|
353 |
+
if self.condition_on_elevation:
|
354 |
+
sample_c2w = read_camera_matrix_single(
|
355 |
+
self.root_dir / self.ids[idx] / f"00000/00000.json"
|
356 |
+
)
|
357 |
+
elevation = calc_elevation(sample_c2w)
|
358 |
+
data["elevation"] = torch.as_tensor([elevation] * self.n_views)
|
359 |
+
|
360 |
+
if self.load_pixelnerf:
|
361 |
+
assert "frames" in data, f"pixelnerf cannot work with latents only mode"
|
362 |
+
data["pixelnerf_input"] = {}
|
363 |
+
RTs = []
|
364 |
+
intrinsics = []
|
365 |
+
for view_idx in idx_list:
|
366 |
+
meta = (
|
367 |
+
self.root_dir
|
368 |
+
/ "gobjaverse"
|
369 |
+
/ self.ids[idx]
|
370 |
+
/ f"{view_idx:05d}/{view_idx:05d}.json"
|
371 |
+
)
|
372 |
+
RTs.append(read_camera_matrix_single(meta)[:3])
|
373 |
+
intrinsics.append(read_camera_instrinsics_single(meta, 256, 256))
|
374 |
+
RTs = torch.stack(RTs, dim=0)
|
375 |
+
intrinsics = torch.stack(intrinsics, dim=0)
|
376 |
+
cameras = build_camera_standard(RTs, intrinsics)
|
377 |
+
data["pixelnerf_input"]["cameras"] = cameras
|
378 |
+
|
379 |
+
downsampled = []
|
380 |
+
for view_idx in idx_list:
|
381 |
+
frame = Image.open(
|
382 |
+
self.root_dir
|
383 |
+
/ "gobjaverse"
|
384 |
+
/ self.ids[idx]
|
385 |
+
/ f"{view_idx:05d}/{view_idx:05d}.png"
|
386 |
+
).resize((32, 32))
|
387 |
+
downsampled.append(to_tensor(blend_white_bg(frame)))
|
388 |
+
data["pixelnerf_input"]["rgb"] = torch.stack(downsampled, dim=0)
|
389 |
+
data["pixelnerf_input"]["frames"] = data["frames"]
|
390 |
+
if self.scale_pose:
|
391 |
+
c2ws = cameras[..., :16].reshape(-1, 4, 4)
|
392 |
+
center = c2ws[:, :3, 3].mean(0)
|
393 |
+
radius = (c2ws[:, :3, 3] - center).norm(dim=-1).max()
|
394 |
+
scale = 1.5 / radius
|
395 |
+
c2ws[..., :3, 3] = (c2ws[..., :3, 3] - center) * scale
|
396 |
+
cameras[..., :16] = c2ws.reshape(-1, 16)
|
397 |
+
|
398 |
+
if self.load_caps:
|
399 |
+
data["caption"] = self.caps[self.ids[idx]]
|
400 |
+
data["ids"] = self.ids[idx]
|
401 |
+
|
402 |
+
return data
|
403 |
+
|
404 |
+
def __len__(self):
|
405 |
+
return len(self.ids)
|
406 |
+
|
407 |
+
def collate_fn(self, batch):
|
408 |
+
if self.max_n_cond > 1:
|
409 |
+
n_cond = np.random.randint(1, self.max_n_cond + 1)
|
410 |
+
if n_cond > 1:
|
411 |
+
for b in batch:
|
412 |
+
source_index = [0] + np.random.choice(
|
413 |
+
np.arange(1, self.n_views),
|
414 |
+
self.max_n_cond - 1,
|
415 |
+
replace=False,
|
416 |
+
).tolist()
|
417 |
+
b["pixelnerf_input"]["source_index"] = torch.as_tensor(source_index)
|
418 |
+
b["pixelnerf_input"]["n_cond"] = n_cond
|
419 |
+
b["pixelnerf_input"]["source_images"] = b["frames"][source_index]
|
420 |
+
b["pixelnerf_input"]["source_cameras"] = b["pixelnerf_input"][
|
421 |
+
"cameras"
|
422 |
+
][source_index]
|
423 |
+
|
424 |
+
return video_collate_fn(batch)
|
425 |
+
|
426 |
+
|
427 |
+
class ObjaverseSpiral(Dataset):
|
428 |
+
def __init__(
|
429 |
+
self,
|
430 |
+
root_dir,
|
431 |
+
split="train",
|
432 |
+
transform=None,
|
433 |
+
random_front=False,
|
434 |
+
max_item=None,
|
435 |
+
cond_aug_mean=-3.0,
|
436 |
+
cond_aug_std=0.5,
|
437 |
+
condition_on_elevation=False,
|
438 |
+
**unused_kwargs,
|
439 |
+
):
|
440 |
+
self.root_dir = Path(root_dir)
|
441 |
+
self.split = split
|
442 |
+
self.random_front = random_front
|
443 |
+
self.transform = transform
|
444 |
+
|
445 |
+
self.ids = json.load(open(self.root_dir / f"{split}_ids.json", "r"))
|
446 |
+
self.n_views = 24
|
447 |
+
valid_ids = []
|
448 |
+
for idx in self.ids:
|
449 |
+
if (self.root_dir / idx).exists():
|
450 |
+
valid_ids.append(idx)
|
451 |
+
self.ids = valid_ids
|
452 |
+
|
453 |
+
self.cond_aug_mean = cond_aug_mean
|
454 |
+
self.cond_aug_std = cond_aug_std
|
455 |
+
self.condition_on_elevation = condition_on_elevation
|
456 |
+
|
457 |
+
if max_item is not None:
|
458 |
+
self.ids = self.ids[:max_item]
|
459 |
+
|
460 |
+
## debug
|
461 |
+
self.ids = self.ids * 10000
|
462 |
+
|
463 |
+
def __getitem__(self, idx: int):
|
464 |
+
frames = []
|
465 |
+
idx_list = np.arange(self.n_views)
|
466 |
+
if self.random_front:
|
467 |
+
roll_idx = np.random.randint(self.n_views)
|
468 |
+
idx_list = np.roll(idx_list, roll_idx)
|
469 |
+
for view_idx in idx_list:
|
470 |
+
frame = Image.open(
|
471 |
+
self.root_dir / self.ids[idx] / f"{view_idx:05d}/{view_idx:05d}.png"
|
472 |
+
)
|
473 |
+
frames.append(self.transform(frame))
|
474 |
+
|
475 |
+
# data = {"jpg": torch.stack(frames, dim=0)} # [T, C, H, W]
|
476 |
+
frames = torch.stack(frames, dim=0)
|
477 |
+
cond = frames[0]
|
478 |
+
|
479 |
+
cond_aug = np.exp(
|
480 |
+
np.random.randn(1)[0] * self.cond_aug_std + self.cond_aug_mean
|
481 |
+
)
|
482 |
+
|
483 |
+
data = {
|
484 |
+
"frames": frames,
|
485 |
+
"cond_frames_without_noise": cond,
|
486 |
+
"cond_aug": torch.as_tensor([cond_aug] * self.n_views),
|
487 |
+
"cond_frames": cond + cond_aug * torch.randn_like(cond),
|
488 |
+
"fps_id": torch.as_tensor([1.0] * self.n_views),
|
489 |
+
"motion_bucket_id": torch.as_tensor([300.0] * self.n_views),
|
490 |
+
"num_video_frames": 24,
|
491 |
+
"image_only_indicator": torch.as_tensor([0.0] * self.n_views),
|
492 |
+
}
|
493 |
+
|
494 |
+
if self.condition_on_elevation:
|
495 |
+
sample_c2w = read_camera_matrix_single(
|
496 |
+
self.root_dir / self.ids[idx] / f"00000/00000.json"
|
497 |
+
)
|
498 |
+
elevation = calc_elevation(sample_c2w)
|
499 |
+
data["elevation"] = torch.as_tensor([elevation] * self.n_views)
|
500 |
+
|
501 |
+
return data
|
502 |
+
|
503 |
+
def __len__(self):
|
504 |
+
return len(self.ids)
|
505 |
+
|
506 |
+
|
507 |
+
class ObjaverseLVISSpiral(Dataset):
|
508 |
+
def __init__(
|
509 |
+
self,
|
510 |
+
root_dir,
|
511 |
+
split="train",
|
512 |
+
transform=None,
|
513 |
+
random_front=False,
|
514 |
+
max_item=None,
|
515 |
+
cond_aug_mean=-3.0,
|
516 |
+
cond_aug_std=0.5,
|
517 |
+
condition_on_elevation=False,
|
518 |
+
use_precomputed_latents=False,
|
519 |
+
**unused_kwargs,
|
520 |
+
):
|
521 |
+
print("Using LVIS subset")
|
522 |
+
self.root_dir = Path(root_dir)
|
523 |
+
self.latent_dir = Path("/mnt/vepfs/3Ddataset/render_results/latents512")
|
524 |
+
self.split = split
|
525 |
+
self.random_front = random_front
|
526 |
+
self.transform = transform
|
527 |
+
self.use_precomputed_latents = use_precomputed_latents
|
528 |
+
|
529 |
+
self.ids = json.load(open("./assets/lvis_uids.json", "r"))
|
530 |
+
self.n_views = 18
|
531 |
+
valid_ids = []
|
532 |
+
for idx in self.ids:
|
533 |
+
if (self.root_dir / idx).exists():
|
534 |
+
valid_ids.append(idx)
|
535 |
+
self.ids = valid_ids
|
536 |
+
print("=" * 30)
|
537 |
+
print("Number of valid ids: ", len(self.ids))
|
538 |
+
print("=" * 30)
|
539 |
+
|
540 |
+
self.cond_aug_mean = cond_aug_mean
|
541 |
+
self.cond_aug_std = cond_aug_std
|
542 |
+
self.condition_on_elevation = condition_on_elevation
|
543 |
+
|
544 |
+
if max_item is not None:
|
545 |
+
self.ids = self.ids[:max_item]
|
546 |
+
|
547 |
+
## debug
|
548 |
+
self.ids = self.ids * 10000
|
549 |
+
|
550 |
+
def __getitem__(self, idx: int):
|
551 |
+
frames = []
|
552 |
+
idx_list = np.arange(self.n_views)
|
553 |
+
if self.random_front:
|
554 |
+
roll_idx = np.random.randint(self.n_views)
|
555 |
+
idx_list = np.roll(idx_list, roll_idx)
|
556 |
+
for view_idx in idx_list:
|
557 |
+
frame = Image.open(
|
558 |
+
self.root_dir
|
559 |
+
/ self.ids[idx]
|
560 |
+
/ "elevations_0"
|
561 |
+
/ f"colors_{view_idx * 2}.png"
|
562 |
+
)
|
563 |
+
frames.append(self.transform(frame))
|
564 |
+
|
565 |
+
frames = torch.stack(frames, dim=0)
|
566 |
+
cond = frames[0]
|
567 |
+
|
568 |
+
cond_aug = np.exp(
|
569 |
+
np.random.randn(1)[0] * self.cond_aug_std + self.cond_aug_mean
|
570 |
+
)
|
571 |
+
|
572 |
+
data = {
|
573 |
+
"frames": frames,
|
574 |
+
"cond_frames_without_noise": cond,
|
575 |
+
"cond_aug": torch.as_tensor([cond_aug] * self.n_views),
|
576 |
+
"cond_frames": cond + cond_aug * torch.randn_like(cond),
|
577 |
+
"fps_id": torch.as_tensor([0.0] * self.n_views),
|
578 |
+
"motion_bucket_id": torch.as_tensor([300.0] * self.n_views),
|
579 |
+
"num_video_frames": self.n_views,
|
580 |
+
"image_only_indicator": torch.as_tensor([0.0] * self.n_views),
|
581 |
+
}
|
582 |
+
|
583 |
+
if self.use_precomputed_latents:
|
584 |
+
data["latents"] = torch.load(self.latent_dir / f"{self.ids[idx]}.pt")
|
585 |
+
|
586 |
+
if self.condition_on_elevation:
|
587 |
+
# sample_c2w = read_camera_matrix_single(
|
588 |
+
# self.root_dir / self.ids[idx] / f"00000/00000.json"
|
589 |
+
# )
|
590 |
+
# elevation = calc_elevation(sample_c2w)
|
591 |
+
# data["elevation"] = torch.as_tensor([elevation] * self.n_views)
|
592 |
+
assert False, "currently assumes elevation 0"
|
593 |
+
|
594 |
+
return data
|
595 |
+
|
596 |
+
def __len__(self):
|
597 |
+
return len(self.ids)
|
598 |
+
|
599 |
+
|
600 |
+
class ObjaverseALLSpiral(ObjaverseLVISSpiral):
|
601 |
+
def __init__(
|
602 |
+
self,
|
603 |
+
root_dir,
|
604 |
+
split="train",
|
605 |
+
transform=None,
|
606 |
+
random_front=False,
|
607 |
+
max_item=None,
|
608 |
+
cond_aug_mean=-3.0,
|
609 |
+
cond_aug_std=0.5,
|
610 |
+
condition_on_elevation=False,
|
611 |
+
use_precomputed_latents=False,
|
612 |
+
**unused_kwargs,
|
613 |
+
):
|
614 |
+
print("Using ALL objects in Objaverse")
|
615 |
+
self.root_dir = Path(root_dir)
|
616 |
+
self.split = split
|
617 |
+
self.random_front = random_front
|
618 |
+
self.transform = transform
|
619 |
+
self.use_precomputed_latents = use_precomputed_latents
|
620 |
+
self.latent_dir = Path("/mnt/vepfs/3Ddataset/render_results/latents512")
|
621 |
+
|
622 |
+
self.ids = json.load(open("./assets/all_ids.json", "r"))
|
623 |
+
self.n_views = 18
|
624 |
+
valid_ids = []
|
625 |
+
for idx in self.ids:
|
626 |
+
if (self.root_dir / idx).exists() and (self.root_dir / idx).is_dir():
|
627 |
+
valid_ids.append(idx)
|
628 |
+
self.ids = valid_ids
|
629 |
+
print("=" * 30)
|
630 |
+
print("Number of valid ids: ", len(self.ids))
|
631 |
+
print("=" * 30)
|
632 |
+
|
633 |
+
self.cond_aug_mean = cond_aug_mean
|
634 |
+
self.cond_aug_std = cond_aug_std
|
635 |
+
self.condition_on_elevation = condition_on_elevation
|
636 |
+
|
637 |
+
if max_item is not None:
|
638 |
+
self.ids = self.ids[:max_item]
|
639 |
+
|
640 |
+
## debug
|
641 |
+
self.ids = self.ids * 10000
|
642 |
+
|
643 |
+
|
644 |
+
class ObjaverseWithPose(Dataset):
|
645 |
+
def __init__(
|
646 |
+
self,
|
647 |
+
root_dir,
|
648 |
+
split="train",
|
649 |
+
transform=None,
|
650 |
+
random_front=False,
|
651 |
+
max_item=None,
|
652 |
+
cond_aug_mean=-3.0,
|
653 |
+
cond_aug_std=0.5,
|
654 |
+
condition_on_elevation=False,
|
655 |
+
use_precomputed_latents=False,
|
656 |
+
**unused_kwargs,
|
657 |
+
):
|
658 |
+
print("Using Objaverse with poses")
|
659 |
+
self.root_dir = Path(root_dir)
|
660 |
+
self.split = split
|
661 |
+
self.random_front = random_front
|
662 |
+
self.transform = transform
|
663 |
+
self.use_precomputed_latents = use_precomputed_latents
|
664 |
+
self.latent_dir = Path("/mnt/vepfs/3Ddataset/render_results/latents512")
|
665 |
+
|
666 |
+
self.ids = json.load(open("./assets/all_ids.json", "r"))
|
667 |
+
self.n_views = 18
|
668 |
+
valid_ids = []
|
669 |
+
for idx in self.ids:
|
670 |
+
if (self.root_dir / idx).exists() and (self.root_dir / idx).is_dir():
|
671 |
+
valid_ids.append(idx)
|
672 |
+
self.ids = valid_ids
|
673 |
+
print("=" * 30)
|
674 |
+
print("Number of valid ids: ", len(self.ids))
|
675 |
+
print("=" * 30)
|
676 |
+
|
677 |
+
self.cond_aug_mean = cond_aug_mean
|
678 |
+
self.cond_aug_std = cond_aug_std
|
679 |
+
self.condition_on_elevation = condition_on_elevation
|
680 |
+
|
681 |
+
def __getitem__(self, idx: int):
|
682 |
+
frames = []
|
683 |
+
idx_list = np.arange(self.n_views)
|
684 |
+
if self.random_front:
|
685 |
+
roll_idx = np.random.randint(self.n_views)
|
686 |
+
idx_list = np.roll(idx_list, roll_idx)
|
687 |
+
for view_idx in idx_list:
|
688 |
+
frame = Image.open(
|
689 |
+
self.root_dir
|
690 |
+
/ self.ids[idx]
|
691 |
+
/ "elevations_0"
|
692 |
+
/ f"colors_{view_idx * 2}.png"
|
693 |
+
)
|
694 |
+
frames.append(self.transform(frame))
|
695 |
+
|
696 |
+
frames = torch.stack(frames, dim=0)
|
697 |
+
cond = frames[0]
|
698 |
+
|
699 |
+
cond_aug = np.exp(
|
700 |
+
np.random.randn(1)[0] * self.cond_aug_std + self.cond_aug_mean
|
701 |
+
)
|
702 |
+
|
703 |
+
data = {
|
704 |
+
"frames": frames,
|
705 |
+
"cond_frames_without_noise": cond,
|
706 |
+
"cond_aug": torch.as_tensor([cond_aug] * self.n_views),
|
707 |
+
"cond_frames": cond + cond_aug * torch.randn_like(cond),
|
708 |
+
"fps_id": torch.as_tensor([0.0] * self.n_views),
|
709 |
+
"motion_bucket_id": torch.as_tensor([300.0] * self.n_views),
|
710 |
+
"num_video_frames": self.n_views,
|
711 |
+
"image_only_indicator": torch.as_tensor([0.0] * self.n_views),
|
712 |
+
}
|
713 |
+
|
714 |
+
if self.use_precomputed_latents:
|
715 |
+
data["latents"] = torch.load(self.latent_dir / f"{self.ids[idx]}.pt")
|
716 |
+
|
717 |
+
if self.condition_on_elevation:
|
718 |
+
assert False, "currently assumes elevation 0"
|
719 |
+
|
720 |
+
return data
|
721 |
+
|
722 |
+
|
723 |
+
class LatentObjaverse(Dataset):
|
724 |
+
def __init__(
|
725 |
+
self,
|
726 |
+
root_dir,
|
727 |
+
split="train",
|
728 |
+
random_front=False,
|
729 |
+
subset="lvis",
|
730 |
+
fps_id=1.0,
|
731 |
+
motion_bucket_id=300.0,
|
732 |
+
cond_aug_mean=-3.0,
|
733 |
+
cond_aug_std=0.5,
|
734 |
+
**unused_kwargs,
|
735 |
+
):
|
736 |
+
self.root_dir = Path(root_dir)
|
737 |
+
self.split = split
|
738 |
+
self.random_front = random_front
|
739 |
+
self.ids = json.load(open(Path("./assets") / f"{subset}_ids.json", "r"))
|
740 |
+
self.clip_emb_dir = self.root_dir / ".." / "clip_emb512"
|
741 |
+
self.n_views = 18
|
742 |
+
self.fps_id = fps_id
|
743 |
+
self.motion_bucket_id = motion_bucket_id
|
744 |
+
self.cond_aug_mean = cond_aug_mean
|
745 |
+
self.cond_aug_std = cond_aug_std
|
746 |
+
if self.random_front:
|
747 |
+
print("Using a random view as front view")
|
748 |
+
|
749 |
+
valid_ids = []
|
750 |
+
for idx in self.ids:
|
751 |
+
if (self.root_dir / f"{idx}.pt").exists() and (
|
752 |
+
self.clip_emb_dir / f"{idx}.pt"
|
753 |
+
).exists():
|
754 |
+
valid_ids.append(idx)
|
755 |
+
self.ids = valid_ids
|
756 |
+
print("=" * 30)
|
757 |
+
print("Number of valid ids: ", len(self.ids))
|
758 |
+
print("=" * 30)
|
759 |
+
|
760 |
+
def __getitem__(self, idx: int):
|
761 |
+
uid = self.ids[idx]
|
762 |
+
idx_list = torch.arange(self.n_views)
|
763 |
+
latents = torch.load(self.root_dir / f"{uid}.pt")
|
764 |
+
clip_emb = torch.load(self.clip_emb_dir / f"{uid}.pt")
|
765 |
+
if self.random_front:
|
766 |
+
idx_list = torch.roll(idx_list, np.random.randint(self.n_views))
|
767 |
+
latents = latents[idx_list]
|
768 |
+
clip_emb = clip_emb[idx_list][0]
|
769 |
+
|
770 |
+
cond_aug = np.exp(
|
771 |
+
np.random.randn(1)[0] * self.cond_aug_std + self.cond_aug_mean
|
772 |
+
)
|
773 |
+
cond = latents[0]
|
774 |
+
|
775 |
+
data = {
|
776 |
+
"latents": latents,
|
777 |
+
"cond_frames_without_noise": clip_emb,
|
778 |
+
"cond_frames": cond + cond_aug * torch.randn_like(cond),
|
779 |
+
"fps_id": torch.as_tensor([self.fps_id] * self.n_views),
|
780 |
+
"motion_bucket_id": torch.as_tensor([self.motion_bucket_id] * self.n_views),
|
781 |
+
"cond_aug": torch.as_tensor([cond_aug] * self.n_views),
|
782 |
+
"num_video_frames": self.n_views,
|
783 |
+
"image_only_indicator": torch.as_tensor([0.0] * self.n_views),
|
784 |
+
}
|
785 |
+
|
786 |
+
return data
|
787 |
+
|
788 |
+
def __len__(self):
|
789 |
+
return len(self.ids)
|
790 |
+
|
791 |
+
|
792 |
+
class ObjaverseSpiralDataset(LightningDataModule):
|
793 |
+
def __init__(
|
794 |
+
self,
|
795 |
+
root_dir,
|
796 |
+
random_front=False,
|
797 |
+
batch_size=2,
|
798 |
+
num_workers=10,
|
799 |
+
prefetch_factor=2,
|
800 |
+
shuffle=True,
|
801 |
+
max_item=None,
|
802 |
+
dataset_cls="richdreamer",
|
803 |
+
reso: int = 256,
|
804 |
+
**kwargs,
|
805 |
+
) -> None:
|
806 |
+
super().__init__()
|
807 |
+
|
808 |
+
self.batch_size = batch_size
|
809 |
+
self.num_workers = num_workers
|
810 |
+
self.prefetch_factor = prefetch_factor
|
811 |
+
self.shuffle = shuffle
|
812 |
+
self.max_item = max_item
|
813 |
+
|
814 |
+
self.transform = Compose(
|
815 |
+
[
|
816 |
+
blend_white_bg,
|
817 |
+
Resize((reso, reso)),
|
818 |
+
ToTensor(),
|
819 |
+
Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
|
820 |
+
]
|
821 |
+
)
|
822 |
+
|
823 |
+
data_cls = {
|
824 |
+
"richdreamer": ObjaverseSpiral,
|
825 |
+
"lvis": ObjaverseLVISSpiral,
|
826 |
+
"shengshu_all": ObjaverseALLSpiral,
|
827 |
+
"latent": LatentObjaverse,
|
828 |
+
"gobjaverse": GObjaverse,
|
829 |
+
}[dataset_cls]
|
830 |
+
|
831 |
+
self.train_dataset = data_cls(
|
832 |
+
root_dir=root_dir,
|
833 |
+
split="train",
|
834 |
+
random_front=random_front,
|
835 |
+
transform=self.transform,
|
836 |
+
max_item=self.max_item,
|
837 |
+
**kwargs,
|
838 |
+
)
|
839 |
+
self.test_dataset = data_cls(
|
840 |
+
root_dir=root_dir,
|
841 |
+
split="val",
|
842 |
+
random_front=random_front,
|
843 |
+
transform=self.transform,
|
844 |
+
max_item=self.max_item,
|
845 |
+
**kwargs,
|
846 |
+
)
|
847 |
+
|
848 |
+
def train_dataloader(self):
|
849 |
+
return DataLoader(
|
850 |
+
self.train_dataset,
|
851 |
+
batch_size=self.batch_size,
|
852 |
+
shuffle=self.shuffle,
|
853 |
+
num_workers=self.num_workers,
|
854 |
+
prefetch_factor=self.prefetch_factor,
|
855 |
+
collate_fn=video_collate_fn
|
856 |
+
if not hasattr(self.train_dataset, "collate_fn")
|
857 |
+
else self.train_dataset.collate_fn,
|
858 |
+
)
|
859 |
+
|
860 |
+
def test_dataloader(self):
|
861 |
+
return DataLoader(
|
862 |
+
self.test_dataset,
|
863 |
+
batch_size=self.batch_size,
|
864 |
+
shuffle=self.shuffle,
|
865 |
+
num_workers=self.num_workers,
|
866 |
+
prefetch_factor=self.prefetch_factor,
|
867 |
+
collate_fn=video_collate_fn
|
868 |
+
if not hasattr(self.test_dataset, "collate_fn")
|
869 |
+
else self.train_dataset.collate_fn,
|
870 |
+
)
|
871 |
+
|
872 |
+
def val_dataloader(self):
|
873 |
+
return DataLoader(
|
874 |
+
self.test_dataset,
|
875 |
+
batch_size=self.batch_size,
|
876 |
+
shuffle=self.shuffle,
|
877 |
+
num_workers=self.num_workers,
|
878 |
+
prefetch_factor=self.prefetch_factor,
|
879 |
+
collate_fn=video_collate_fn
|
880 |
+
if not hasattr(self.test_dataset, "collate_fn")
|
881 |
+
else self.train_dataset.collate_fn,
|
882 |
+
)
|
sgm/inference/api.py
ADDED
@@ -0,0 +1,385 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pathlib
|
2 |
+
from dataclasses import asdict, dataclass
|
3 |
+
from enum import Enum
|
4 |
+
from typing import Optional
|
5 |
+
|
6 |
+
from omegaconf import OmegaConf
|
7 |
+
|
8 |
+
from sgm.inference.helpers import (Img2ImgDiscretizationWrapper, do_img2img,
|
9 |
+
do_sample)
|
10 |
+
from sgm.modules.diffusionmodules.sampling import (DPMPP2MSampler,
|
11 |
+
DPMPP2SAncestralSampler,
|
12 |
+
EulerAncestralSampler,
|
13 |
+
EulerEDMSampler,
|
14 |
+
HeunEDMSampler,
|
15 |
+
LinearMultistepSampler)
|
16 |
+
from sgm.util import load_model_from_config
|
17 |
+
|
18 |
+
|
19 |
+
class ModelArchitecture(str, Enum):
|
20 |
+
SD_2_1 = "stable-diffusion-v2-1"
|
21 |
+
SD_2_1_768 = "stable-diffusion-v2-1-768"
|
22 |
+
SDXL_V0_9_BASE = "stable-diffusion-xl-v0-9-base"
|
23 |
+
SDXL_V0_9_REFINER = "stable-diffusion-xl-v0-9-refiner"
|
24 |
+
SDXL_V1_BASE = "stable-diffusion-xl-v1-base"
|
25 |
+
SDXL_V1_REFINER = "stable-diffusion-xl-v1-refiner"
|
26 |
+
|
27 |
+
|
28 |
+
class Sampler(str, Enum):
|
29 |
+
EULER_EDM = "EulerEDMSampler"
|
30 |
+
HEUN_EDM = "HeunEDMSampler"
|
31 |
+
EULER_ANCESTRAL = "EulerAncestralSampler"
|
32 |
+
DPMPP2S_ANCESTRAL = "DPMPP2SAncestralSampler"
|
33 |
+
DPMPP2M = "DPMPP2MSampler"
|
34 |
+
LINEAR_MULTISTEP = "LinearMultistepSampler"
|
35 |
+
|
36 |
+
|
37 |
+
class Discretization(str, Enum):
|
38 |
+
LEGACY_DDPM = "LegacyDDPMDiscretization"
|
39 |
+
EDM = "EDMDiscretization"
|
40 |
+
|
41 |
+
|
42 |
+
class Guider(str, Enum):
|
43 |
+
VANILLA = "VanillaCFG"
|
44 |
+
IDENTITY = "IdentityGuider"
|
45 |
+
|
46 |
+
|
47 |
+
class Thresholder(str, Enum):
|
48 |
+
NONE = "None"
|
49 |
+
|
50 |
+
|
51 |
+
@dataclass
|
52 |
+
class SamplingParams:
|
53 |
+
width: int = 1024
|
54 |
+
height: int = 1024
|
55 |
+
steps: int = 50
|
56 |
+
sampler: Sampler = Sampler.DPMPP2M
|
57 |
+
discretization: Discretization = Discretization.LEGACY_DDPM
|
58 |
+
guider: Guider = Guider.VANILLA
|
59 |
+
thresholder: Thresholder = Thresholder.NONE
|
60 |
+
scale: float = 6.0
|
61 |
+
aesthetic_score: float = 5.0
|
62 |
+
negative_aesthetic_score: float = 5.0
|
63 |
+
img2img_strength: float = 1.0
|
64 |
+
orig_width: int = 1024
|
65 |
+
orig_height: int = 1024
|
66 |
+
crop_coords_top: int = 0
|
67 |
+
crop_coords_left: int = 0
|
68 |
+
sigma_min: float = 0.0292
|
69 |
+
sigma_max: float = 14.6146
|
70 |
+
rho: float = 3.0
|
71 |
+
s_churn: float = 0.0
|
72 |
+
s_tmin: float = 0.0
|
73 |
+
s_tmax: float = 999.0
|
74 |
+
s_noise: float = 1.0
|
75 |
+
eta: float = 1.0
|
76 |
+
order: int = 4
|
77 |
+
|
78 |
+
|
79 |
+
@dataclass
|
80 |
+
class SamplingSpec:
|
81 |
+
width: int
|
82 |
+
height: int
|
83 |
+
channels: int
|
84 |
+
factor: int
|
85 |
+
is_legacy: bool
|
86 |
+
config: str
|
87 |
+
ckpt: str
|
88 |
+
is_guided: bool
|
89 |
+
|
90 |
+
|
91 |
+
model_specs = {
|
92 |
+
ModelArchitecture.SD_2_1: SamplingSpec(
|
93 |
+
height=512,
|
94 |
+
width=512,
|
95 |
+
channels=4,
|
96 |
+
factor=8,
|
97 |
+
is_legacy=True,
|
98 |
+
config="sd_2_1.yaml",
|
99 |
+
ckpt="v2-1_512-ema-pruned.safetensors",
|
100 |
+
is_guided=True,
|
101 |
+
),
|
102 |
+
ModelArchitecture.SD_2_1_768: SamplingSpec(
|
103 |
+
height=768,
|
104 |
+
width=768,
|
105 |
+
channels=4,
|
106 |
+
factor=8,
|
107 |
+
is_legacy=True,
|
108 |
+
config="sd_2_1_768.yaml",
|
109 |
+
ckpt="v2-1_768-ema-pruned.safetensors",
|
110 |
+
is_guided=True,
|
111 |
+
),
|
112 |
+
ModelArchitecture.SDXL_V0_9_BASE: SamplingSpec(
|
113 |
+
height=1024,
|
114 |
+
width=1024,
|
115 |
+
channels=4,
|
116 |
+
factor=8,
|
117 |
+
is_legacy=False,
|
118 |
+
config="sd_xl_base.yaml",
|
119 |
+
ckpt="sd_xl_base_0.9.safetensors",
|
120 |
+
is_guided=True,
|
121 |
+
),
|
122 |
+
ModelArchitecture.SDXL_V0_9_REFINER: SamplingSpec(
|
123 |
+
height=1024,
|
124 |
+
width=1024,
|
125 |
+
channels=4,
|
126 |
+
factor=8,
|
127 |
+
is_legacy=True,
|
128 |
+
config="sd_xl_refiner.yaml",
|
129 |
+
ckpt="sd_xl_refiner_0.9.safetensors",
|
130 |
+
is_guided=True,
|
131 |
+
),
|
132 |
+
ModelArchitecture.SDXL_V1_BASE: SamplingSpec(
|
133 |
+
height=1024,
|
134 |
+
width=1024,
|
135 |
+
channels=4,
|
136 |
+
factor=8,
|
137 |
+
is_legacy=False,
|
138 |
+
config="sd_xl_base.yaml",
|
139 |
+
ckpt="sd_xl_base_1.0.safetensors",
|
140 |
+
is_guided=True,
|
141 |
+
),
|
142 |
+
ModelArchitecture.SDXL_V1_REFINER: SamplingSpec(
|
143 |
+
height=1024,
|
144 |
+
width=1024,
|
145 |
+
channels=4,
|
146 |
+
factor=8,
|
147 |
+
is_legacy=True,
|
148 |
+
config="sd_xl_refiner.yaml",
|
149 |
+
ckpt="sd_xl_refiner_1.0.safetensors",
|
150 |
+
is_guided=True,
|
151 |
+
),
|
152 |
+
}
|
153 |
+
|
154 |
+
|
155 |
+
class SamplingPipeline:
|
156 |
+
def __init__(
|
157 |
+
self,
|
158 |
+
model_id: ModelArchitecture,
|
159 |
+
model_path="checkpoints",
|
160 |
+
config_path="configs/inference",
|
161 |
+
device="cuda",
|
162 |
+
use_fp16=True,
|
163 |
+
) -> None:
|
164 |
+
if model_id not in model_specs:
|
165 |
+
raise ValueError(f"Model {model_id} not supported")
|
166 |
+
self.model_id = model_id
|
167 |
+
self.specs = model_specs[self.model_id]
|
168 |
+
self.config = str(pathlib.Path(config_path, self.specs.config))
|
169 |
+
self.ckpt = str(pathlib.Path(model_path, self.specs.ckpt))
|
170 |
+
self.device = device
|
171 |
+
self.model = self._load_model(device=device, use_fp16=use_fp16)
|
172 |
+
|
173 |
+
def _load_model(self, device="cuda", use_fp16=True):
|
174 |
+
config = OmegaConf.load(self.config)
|
175 |
+
model = load_model_from_config(config, self.ckpt)
|
176 |
+
if model is None:
|
177 |
+
raise ValueError(f"Model {self.model_id} could not be loaded")
|
178 |
+
model.to(device)
|
179 |
+
if use_fp16:
|
180 |
+
model.conditioner.half()
|
181 |
+
model.model.half()
|
182 |
+
return model
|
183 |
+
|
184 |
+
def text_to_image(
|
185 |
+
self,
|
186 |
+
params: SamplingParams,
|
187 |
+
prompt: str,
|
188 |
+
negative_prompt: str = "",
|
189 |
+
samples: int = 1,
|
190 |
+
return_latents: bool = False,
|
191 |
+
):
|
192 |
+
sampler = get_sampler_config(params)
|
193 |
+
value_dict = asdict(params)
|
194 |
+
value_dict["prompt"] = prompt
|
195 |
+
value_dict["negative_prompt"] = negative_prompt
|
196 |
+
value_dict["target_width"] = params.width
|
197 |
+
value_dict["target_height"] = params.height
|
198 |
+
return do_sample(
|
199 |
+
self.model,
|
200 |
+
sampler,
|
201 |
+
value_dict,
|
202 |
+
samples,
|
203 |
+
params.height,
|
204 |
+
params.width,
|
205 |
+
self.specs.channels,
|
206 |
+
self.specs.factor,
|
207 |
+
force_uc_zero_embeddings=["txt"] if not self.specs.is_legacy else [],
|
208 |
+
return_latents=return_latents,
|
209 |
+
filter=None,
|
210 |
+
)
|
211 |
+
|
212 |
+
def image_to_image(
|
213 |
+
self,
|
214 |
+
params: SamplingParams,
|
215 |
+
image,
|
216 |
+
prompt: str,
|
217 |
+
negative_prompt: str = "",
|
218 |
+
samples: int = 1,
|
219 |
+
return_latents: bool = False,
|
220 |
+
):
|
221 |
+
sampler = get_sampler_config(params)
|
222 |
+
|
223 |
+
if params.img2img_strength < 1.0:
|
224 |
+
sampler.discretization = Img2ImgDiscretizationWrapper(
|
225 |
+
sampler.discretization,
|
226 |
+
strength=params.img2img_strength,
|
227 |
+
)
|
228 |
+
height, width = image.shape[2], image.shape[3]
|
229 |
+
value_dict = asdict(params)
|
230 |
+
value_dict["prompt"] = prompt
|
231 |
+
value_dict["negative_prompt"] = negative_prompt
|
232 |
+
value_dict["target_width"] = width
|
233 |
+
value_dict["target_height"] = height
|
234 |
+
return do_img2img(
|
235 |
+
image,
|
236 |
+
self.model,
|
237 |
+
sampler,
|
238 |
+
value_dict,
|
239 |
+
samples,
|
240 |
+
force_uc_zero_embeddings=["txt"] if not self.specs.is_legacy else [],
|
241 |
+
return_latents=return_latents,
|
242 |
+
filter=None,
|
243 |
+
)
|
244 |
+
|
245 |
+
def refiner(
|
246 |
+
self,
|
247 |
+
params: SamplingParams,
|
248 |
+
image,
|
249 |
+
prompt: str,
|
250 |
+
negative_prompt: Optional[str] = None,
|
251 |
+
samples: int = 1,
|
252 |
+
return_latents: bool = False,
|
253 |
+
):
|
254 |
+
sampler = get_sampler_config(params)
|
255 |
+
value_dict = {
|
256 |
+
"orig_width": image.shape[3] * 8,
|
257 |
+
"orig_height": image.shape[2] * 8,
|
258 |
+
"target_width": image.shape[3] * 8,
|
259 |
+
"target_height": image.shape[2] * 8,
|
260 |
+
"prompt": prompt,
|
261 |
+
"negative_prompt": negative_prompt,
|
262 |
+
"crop_coords_top": 0,
|
263 |
+
"crop_coords_left": 0,
|
264 |
+
"aesthetic_score": 6.0,
|
265 |
+
"negative_aesthetic_score": 2.5,
|
266 |
+
}
|
267 |
+
|
268 |
+
return do_img2img(
|
269 |
+
image,
|
270 |
+
self.model,
|
271 |
+
sampler,
|
272 |
+
value_dict,
|
273 |
+
samples,
|
274 |
+
skip_encode=True,
|
275 |
+
return_latents=return_latents,
|
276 |
+
filter=None,
|
277 |
+
)
|
278 |
+
|
279 |
+
|
280 |
+
def get_guider_config(params: SamplingParams):
|
281 |
+
if params.guider == Guider.IDENTITY:
|
282 |
+
guider_config = {
|
283 |
+
"target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"
|
284 |
+
}
|
285 |
+
elif params.guider == Guider.VANILLA:
|
286 |
+
scale = params.scale
|
287 |
+
|
288 |
+
thresholder = params.thresholder
|
289 |
+
|
290 |
+
if thresholder == Thresholder.NONE:
|
291 |
+
dyn_thresh_config = {
|
292 |
+
"target": "sgm.modules.diffusionmodules.sampling_utils.NoDynamicThresholding"
|
293 |
+
}
|
294 |
+
else:
|
295 |
+
raise NotImplementedError
|
296 |
+
|
297 |
+
guider_config = {
|
298 |
+
"target": "sgm.modules.diffusionmodules.guiders.VanillaCFG",
|
299 |
+
"params": {"scale": scale, "dyn_thresh_config": dyn_thresh_config},
|
300 |
+
}
|
301 |
+
else:
|
302 |
+
raise NotImplementedError
|
303 |
+
return guider_config
|
304 |
+
|
305 |
+
|
306 |
+
def get_discretization_config(params: SamplingParams):
|
307 |
+
if params.discretization == Discretization.LEGACY_DDPM:
|
308 |
+
discretization_config = {
|
309 |
+
"target": "sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization",
|
310 |
+
}
|
311 |
+
elif params.discretization == Discretization.EDM:
|
312 |
+
discretization_config = {
|
313 |
+
"target": "sgm.modules.diffusionmodules.discretizer.EDMDiscretization",
|
314 |
+
"params": {
|
315 |
+
"sigma_min": params.sigma_min,
|
316 |
+
"sigma_max": params.sigma_max,
|
317 |
+
"rho": params.rho,
|
318 |
+
},
|
319 |
+
}
|
320 |
+
else:
|
321 |
+
raise ValueError(f"unknown discretization {params.discretization}")
|
322 |
+
return discretization_config
|
323 |
+
|
324 |
+
|
325 |
+
def get_sampler_config(params: SamplingParams):
|
326 |
+
discretization_config = get_discretization_config(params)
|
327 |
+
guider_config = get_guider_config(params)
|
328 |
+
sampler = None
|
329 |
+
if params.sampler == Sampler.EULER_EDM:
|
330 |
+
return EulerEDMSampler(
|
331 |
+
num_steps=params.steps,
|
332 |
+
discretization_config=discretization_config,
|
333 |
+
guider_config=guider_config,
|
334 |
+
s_churn=params.s_churn,
|
335 |
+
s_tmin=params.s_tmin,
|
336 |
+
s_tmax=params.s_tmax,
|
337 |
+
s_noise=params.s_noise,
|
338 |
+
verbose=True,
|
339 |
+
)
|
340 |
+
if params.sampler == Sampler.HEUN_EDM:
|
341 |
+
return HeunEDMSampler(
|
342 |
+
num_steps=params.steps,
|
343 |
+
discretization_config=discretization_config,
|
344 |
+
guider_config=guider_config,
|
345 |
+
s_churn=params.s_churn,
|
346 |
+
s_tmin=params.s_tmin,
|
347 |
+
s_tmax=params.s_tmax,
|
348 |
+
s_noise=params.s_noise,
|
349 |
+
verbose=True,
|
350 |
+
)
|
351 |
+
if params.sampler == Sampler.EULER_ANCESTRAL:
|
352 |
+
return EulerAncestralSampler(
|
353 |
+
num_steps=params.steps,
|
354 |
+
discretization_config=discretization_config,
|
355 |
+
guider_config=guider_config,
|
356 |
+
eta=params.eta,
|
357 |
+
s_noise=params.s_noise,
|
358 |
+
verbose=True,
|
359 |
+
)
|
360 |
+
if params.sampler == Sampler.DPMPP2S_ANCESTRAL:
|
361 |
+
return DPMPP2SAncestralSampler(
|
362 |
+
num_steps=params.steps,
|
363 |
+
discretization_config=discretization_config,
|
364 |
+
guider_config=guider_config,
|
365 |
+
eta=params.eta,
|
366 |
+
s_noise=params.s_noise,
|
367 |
+
verbose=True,
|
368 |
+
)
|
369 |
+
if params.sampler == Sampler.DPMPP2M:
|
370 |
+
return DPMPP2MSampler(
|
371 |
+
num_steps=params.steps,
|
372 |
+
discretization_config=discretization_config,
|
373 |
+
guider_config=guider_config,
|
374 |
+
verbose=True,
|
375 |
+
)
|
376 |
+
if params.sampler == Sampler.LINEAR_MULTISTEP:
|
377 |
+
return LinearMultistepSampler(
|
378 |
+
num_steps=params.steps,
|
379 |
+
discretization_config=discretization_config,
|
380 |
+
guider_config=guider_config,
|
381 |
+
order=params.order,
|
382 |
+
verbose=True,
|
383 |
+
)
|
384 |
+
|
385 |
+
raise ValueError(f"unknown sampler {params.sampler}!")
|
sgm/inference/helpers.py
ADDED
@@ -0,0 +1,305 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import os
|
3 |
+
from typing import List, Optional, Union
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
from einops import rearrange
|
8 |
+
from imwatermark import WatermarkEncoder
|
9 |
+
from omegaconf import ListConfig
|
10 |
+
from PIL import Image
|
11 |
+
from torch import autocast
|
12 |
+
|
13 |
+
from sgm.util import append_dims
|
14 |
+
|
15 |
+
|
16 |
+
class WatermarkEmbedder:
|
17 |
+
def __init__(self, watermark):
|
18 |
+
self.watermark = watermark
|
19 |
+
self.num_bits = len(WATERMARK_BITS)
|
20 |
+
self.encoder = WatermarkEncoder()
|
21 |
+
self.encoder.set_watermark("bits", self.watermark)
|
22 |
+
|
23 |
+
def __call__(self, image: torch.Tensor) -> torch.Tensor:
|
24 |
+
"""
|
25 |
+
Adds a predefined watermark to the input image
|
26 |
+
|
27 |
+
Args:
|
28 |
+
image: ([N,] B, RGB, H, W) in range [0, 1]
|
29 |
+
|
30 |
+
Returns:
|
31 |
+
same as input but watermarked
|
32 |
+
"""
|
33 |
+
squeeze = len(image.shape) == 4
|
34 |
+
if squeeze:
|
35 |
+
image = image[None, ...]
|
36 |
+
n = image.shape[0]
|
37 |
+
image_np = rearrange(
|
38 |
+
(255 * image).detach().cpu(), "n b c h w -> (n b) h w c"
|
39 |
+
).numpy()[:, :, :, ::-1]
|
40 |
+
# torch (b, c, h, w) in [0, 1] -> numpy (b, h, w, c) [0, 255]
|
41 |
+
# watermarking libary expects input as cv2 BGR format
|
42 |
+
for k in range(image_np.shape[0]):
|
43 |
+
image_np[k] = self.encoder.encode(image_np[k], "dwtDct")
|
44 |
+
image = torch.from_numpy(
|
45 |
+
rearrange(image_np[:, :, :, ::-1], "(n b) h w c -> n b c h w", n=n)
|
46 |
+
).to(image.device)
|
47 |
+
image = torch.clamp(image / 255, min=0.0, max=1.0)
|
48 |
+
if squeeze:
|
49 |
+
image = image[0]
|
50 |
+
return image
|
51 |
+
|
52 |
+
|
53 |
+
# A fixed 48-bit message that was choosen at random
|
54 |
+
# WATERMARK_MESSAGE = 0xB3EC907BB19E
|
55 |
+
WATERMARK_MESSAGE = 0b101100111110110010010000011110111011000110011110
|
56 |
+
# bin(x)[2:] gives bits of x as str, use int to convert them to 0/1
|
57 |
+
WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]]
|
58 |
+
embed_watermark = WatermarkEmbedder(WATERMARK_BITS)
|
59 |
+
|
60 |
+
|
61 |
+
def get_unique_embedder_keys_from_conditioner(conditioner):
|
62 |
+
return list({x.input_key for x in conditioner.embedders})
|
63 |
+
|
64 |
+
|
65 |
+
def perform_save_locally(save_path, samples):
|
66 |
+
os.makedirs(os.path.join(save_path), exist_ok=True)
|
67 |
+
base_count = len(os.listdir(os.path.join(save_path)))
|
68 |
+
samples = embed_watermark(samples)
|
69 |
+
for sample in samples:
|
70 |
+
sample = 255.0 * rearrange(sample.cpu().numpy(), "c h w -> h w c")
|
71 |
+
Image.fromarray(sample.astype(np.uint8)).save(
|
72 |
+
os.path.join(save_path, f"{base_count:09}.png")
|
73 |
+
)
|
74 |
+
base_count += 1
|
75 |
+
|
76 |
+
|
77 |
+
class Img2ImgDiscretizationWrapper:
|
78 |
+
"""
|
79 |
+
wraps a discretizer, and prunes the sigmas
|
80 |
+
params:
|
81 |
+
strength: float between 0.0 and 1.0. 1.0 means full sampling (all sigmas are returned)
|
82 |
+
"""
|
83 |
+
|
84 |
+
def __init__(self, discretization, strength: float = 1.0):
|
85 |
+
self.discretization = discretization
|
86 |
+
self.strength = strength
|
87 |
+
assert 0.0 <= self.strength <= 1.0
|
88 |
+
|
89 |
+
def __call__(self, *args, **kwargs):
|
90 |
+
# sigmas start large first, and decrease then
|
91 |
+
sigmas = self.discretization(*args, **kwargs)
|
92 |
+
print(f"sigmas after discretization, before pruning img2img: ", sigmas)
|
93 |
+
sigmas = torch.flip(sigmas, (0,))
|
94 |
+
sigmas = sigmas[: max(int(self.strength * len(sigmas)), 1)]
|
95 |
+
print("prune index:", max(int(self.strength * len(sigmas)), 1))
|
96 |
+
sigmas = torch.flip(sigmas, (0,))
|
97 |
+
print(f"sigmas after pruning: ", sigmas)
|
98 |
+
return sigmas
|
99 |
+
|
100 |
+
|
101 |
+
def do_sample(
|
102 |
+
model,
|
103 |
+
sampler,
|
104 |
+
value_dict,
|
105 |
+
num_samples,
|
106 |
+
H,
|
107 |
+
W,
|
108 |
+
C,
|
109 |
+
F,
|
110 |
+
force_uc_zero_embeddings: Optional[List] = None,
|
111 |
+
batch2model_input: Optional[List] = None,
|
112 |
+
return_latents=False,
|
113 |
+
filter=None,
|
114 |
+
device="cuda",
|
115 |
+
):
|
116 |
+
if force_uc_zero_embeddings is None:
|
117 |
+
force_uc_zero_embeddings = []
|
118 |
+
if batch2model_input is None:
|
119 |
+
batch2model_input = []
|
120 |
+
|
121 |
+
with torch.no_grad():
|
122 |
+
with autocast(device) as precision_scope:
|
123 |
+
with model.ema_scope():
|
124 |
+
num_samples = [num_samples]
|
125 |
+
batch, batch_uc = get_batch(
|
126 |
+
get_unique_embedder_keys_from_conditioner(model.conditioner),
|
127 |
+
value_dict,
|
128 |
+
num_samples,
|
129 |
+
)
|
130 |
+
for key in batch:
|
131 |
+
if isinstance(batch[key], torch.Tensor):
|
132 |
+
print(key, batch[key].shape)
|
133 |
+
elif isinstance(batch[key], list):
|
134 |
+
print(key, [len(l) for l in batch[key]])
|
135 |
+
else:
|
136 |
+
print(key, batch[key])
|
137 |
+
c, uc = model.conditioner.get_unconditional_conditioning(
|
138 |
+
batch,
|
139 |
+
batch_uc=batch_uc,
|
140 |
+
force_uc_zero_embeddings=force_uc_zero_embeddings,
|
141 |
+
)
|
142 |
+
|
143 |
+
for k in c:
|
144 |
+
if not k == "crossattn":
|
145 |
+
c[k], uc[k] = map(
|
146 |
+
lambda y: y[k][: math.prod(num_samples)].to(device), (c, uc)
|
147 |
+
)
|
148 |
+
|
149 |
+
additional_model_inputs = {}
|
150 |
+
for k in batch2model_input:
|
151 |
+
additional_model_inputs[k] = batch[k]
|
152 |
+
|
153 |
+
shape = (math.prod(num_samples), C, H // F, W // F)
|
154 |
+
randn = torch.randn(shape).to(device)
|
155 |
+
|
156 |
+
def denoiser(input, sigma, c):
|
157 |
+
return model.denoiser(
|
158 |
+
model.model, input, sigma, c, **additional_model_inputs
|
159 |
+
)
|
160 |
+
|
161 |
+
samples_z = sampler(denoiser, randn, cond=c, uc=uc)
|
162 |
+
samples_x = model.decode_first_stage(samples_z)
|
163 |
+
samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
|
164 |
+
|
165 |
+
if filter is not None:
|
166 |
+
samples = filter(samples)
|
167 |
+
|
168 |
+
if return_latents:
|
169 |
+
return samples, samples_z
|
170 |
+
return samples
|
171 |
+
|
172 |
+
|
173 |
+
def get_batch(keys, value_dict, N: Union[List, ListConfig], device="cuda"):
|
174 |
+
# Hardcoded demo setups; might undergo some changes in the future
|
175 |
+
|
176 |
+
batch = {}
|
177 |
+
batch_uc = {}
|
178 |
+
|
179 |
+
for key in keys:
|
180 |
+
if key == "txt":
|
181 |
+
batch["txt"] = (
|
182 |
+
np.repeat([value_dict["prompt"]], repeats=math.prod(N))
|
183 |
+
.reshape(N)
|
184 |
+
.tolist()
|
185 |
+
)
|
186 |
+
batch_uc["txt"] = (
|
187 |
+
np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N))
|
188 |
+
.reshape(N)
|
189 |
+
.tolist()
|
190 |
+
)
|
191 |
+
elif key == "original_size_as_tuple":
|
192 |
+
batch["original_size_as_tuple"] = (
|
193 |
+
torch.tensor([value_dict["orig_height"], value_dict["orig_width"]])
|
194 |
+
.to(device)
|
195 |
+
.repeat(*N, 1)
|
196 |
+
)
|
197 |
+
elif key == "crop_coords_top_left":
|
198 |
+
batch["crop_coords_top_left"] = (
|
199 |
+
torch.tensor(
|
200 |
+
[value_dict["crop_coords_top"], value_dict["crop_coords_left"]]
|
201 |
+
)
|
202 |
+
.to(device)
|
203 |
+
.repeat(*N, 1)
|
204 |
+
)
|
205 |
+
elif key == "aesthetic_score":
|
206 |
+
batch["aesthetic_score"] = (
|
207 |
+
torch.tensor([value_dict["aesthetic_score"]]).to(device).repeat(*N, 1)
|
208 |
+
)
|
209 |
+
batch_uc["aesthetic_score"] = (
|
210 |
+
torch.tensor([value_dict["negative_aesthetic_score"]])
|
211 |
+
.to(device)
|
212 |
+
.repeat(*N, 1)
|
213 |
+
)
|
214 |
+
|
215 |
+
elif key == "target_size_as_tuple":
|
216 |
+
batch["target_size_as_tuple"] = (
|
217 |
+
torch.tensor([value_dict["target_height"], value_dict["target_width"]])
|
218 |
+
.to(device)
|
219 |
+
.repeat(*N, 1)
|
220 |
+
)
|
221 |
+
else:
|
222 |
+
batch[key] = value_dict[key]
|
223 |
+
|
224 |
+
for key in batch.keys():
|
225 |
+
if key not in batch_uc and isinstance(batch[key], torch.Tensor):
|
226 |
+
batch_uc[key] = torch.clone(batch[key])
|
227 |
+
return batch, batch_uc
|
228 |
+
|
229 |
+
|
230 |
+
def get_input_image_tensor(image: Image.Image, device="cuda"):
|
231 |
+
w, h = image.size
|
232 |
+
print(f"loaded input image of size ({w}, {h})")
|
233 |
+
width, height = map(
|
234 |
+
lambda x: x - x % 64, (w, h)
|
235 |
+
) # resize to integer multiple of 64
|
236 |
+
image = image.resize((width, height))
|
237 |
+
image_array = np.array(image.convert("RGB"))
|
238 |
+
image_array = image_array[None].transpose(0, 3, 1, 2)
|
239 |
+
image_tensor = torch.from_numpy(image_array).to(dtype=torch.float32) / 127.5 - 1.0
|
240 |
+
return image_tensor.to(device)
|
241 |
+
|
242 |
+
|
243 |
+
def do_img2img(
|
244 |
+
img,
|
245 |
+
model,
|
246 |
+
sampler,
|
247 |
+
value_dict,
|
248 |
+
num_samples,
|
249 |
+
force_uc_zero_embeddings=[],
|
250 |
+
additional_kwargs={},
|
251 |
+
offset_noise_level: float = 0.0,
|
252 |
+
return_latents=False,
|
253 |
+
skip_encode=False,
|
254 |
+
filter=None,
|
255 |
+
device="cuda",
|
256 |
+
):
|
257 |
+
with torch.no_grad():
|
258 |
+
with autocast(device) as precision_scope:
|
259 |
+
with model.ema_scope():
|
260 |
+
batch, batch_uc = get_batch(
|
261 |
+
get_unique_embedder_keys_from_conditioner(model.conditioner),
|
262 |
+
value_dict,
|
263 |
+
[num_samples],
|
264 |
+
)
|
265 |
+
c, uc = model.conditioner.get_unconditional_conditioning(
|
266 |
+
batch,
|
267 |
+
batch_uc=batch_uc,
|
268 |
+
force_uc_zero_embeddings=force_uc_zero_embeddings,
|
269 |
+
)
|
270 |
+
|
271 |
+
for k in c:
|
272 |
+
c[k], uc[k] = map(lambda y: y[k][:num_samples].to(device), (c, uc))
|
273 |
+
|
274 |
+
for k in additional_kwargs:
|
275 |
+
c[k] = uc[k] = additional_kwargs[k]
|
276 |
+
if skip_encode:
|
277 |
+
z = img
|
278 |
+
else:
|
279 |
+
z = model.encode_first_stage(img)
|
280 |
+
noise = torch.randn_like(z)
|
281 |
+
sigmas = sampler.discretization(sampler.num_steps)
|
282 |
+
sigma = sigmas[0].to(z.device)
|
283 |
+
|
284 |
+
if offset_noise_level > 0.0:
|
285 |
+
noise = noise + offset_noise_level * append_dims(
|
286 |
+
torch.randn(z.shape[0], device=z.device), z.ndim
|
287 |
+
)
|
288 |
+
noised_z = z + noise * append_dims(sigma, z.ndim)
|
289 |
+
noised_z = noised_z / torch.sqrt(
|
290 |
+
1.0 + sigmas[0] ** 2.0
|
291 |
+
) # Note: hardcoded to DDPM-like scaling. need to generalize later.
|
292 |
+
|
293 |
+
def denoiser(x, sigma, c):
|
294 |
+
return model.denoiser(model.model, x, sigma, c)
|
295 |
+
|
296 |
+
samples_z = sampler(denoiser, noised_z, cond=c, uc=uc)
|
297 |
+
samples_x = model.decode_first_stage(samples_z)
|
298 |
+
samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
|
299 |
+
|
300 |
+
if filter is not None:
|
301 |
+
samples = filter(samples)
|
302 |
+
|
303 |
+
if return_latents:
|
304 |
+
return samples, samples_z
|
305 |
+
return samples
|
sgm/lr_scheduler.py
ADDED
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
|
3 |
+
|
4 |
+
class LambdaWarmUpCosineScheduler:
|
5 |
+
"""
|
6 |
+
note: use with a base_lr of 1.0
|
7 |
+
"""
|
8 |
+
|
9 |
+
def __init__(
|
10 |
+
self,
|
11 |
+
warm_up_steps,
|
12 |
+
lr_min,
|
13 |
+
lr_max,
|
14 |
+
lr_start,
|
15 |
+
max_decay_steps,
|
16 |
+
verbosity_interval=0,
|
17 |
+
):
|
18 |
+
self.lr_warm_up_steps = warm_up_steps
|
19 |
+
self.lr_start = lr_start
|
20 |
+
self.lr_min = lr_min
|
21 |
+
self.lr_max = lr_max
|
22 |
+
self.lr_max_decay_steps = max_decay_steps
|
23 |
+
self.last_lr = 0.0
|
24 |
+
self.verbosity_interval = verbosity_interval
|
25 |
+
|
26 |
+
def schedule(self, n, **kwargs):
|
27 |
+
if self.verbosity_interval > 0:
|
28 |
+
if n % self.verbosity_interval == 0:
|
29 |
+
print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
|
30 |
+
if n < self.lr_warm_up_steps:
|
31 |
+
lr = (
|
32 |
+
self.lr_max - self.lr_start
|
33 |
+
) / self.lr_warm_up_steps * n + self.lr_start
|
34 |
+
self.last_lr = lr
|
35 |
+
return lr
|
36 |
+
else:
|
37 |
+
t = (n - self.lr_warm_up_steps) / (
|
38 |
+
self.lr_max_decay_steps - self.lr_warm_up_steps
|
39 |
+
)
|
40 |
+
t = min(t, 1.0)
|
41 |
+
lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
|
42 |
+
1 + np.cos(t * np.pi)
|
43 |
+
)
|
44 |
+
self.last_lr = lr
|
45 |
+
return lr
|
46 |
+
|
47 |
+
def __call__(self, n, **kwargs):
|
48 |
+
return self.schedule(n, **kwargs)
|
49 |
+
|
50 |
+
|
51 |
+
class LambdaWarmUpCosineScheduler2:
|
52 |
+
"""
|
53 |
+
supports repeated iterations, configurable via lists
|
54 |
+
note: use with a base_lr of 1.0.
|
55 |
+
"""
|
56 |
+
|
57 |
+
def __init__(
|
58 |
+
self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0
|
59 |
+
):
|
60 |
+
assert (
|
61 |
+
len(warm_up_steps)
|
62 |
+
== len(f_min)
|
63 |
+
== len(f_max)
|
64 |
+
== len(f_start)
|
65 |
+
== len(cycle_lengths)
|
66 |
+
)
|
67 |
+
self.lr_warm_up_steps = warm_up_steps
|
68 |
+
self.f_start = f_start
|
69 |
+
self.f_min = f_min
|
70 |
+
self.f_max = f_max
|
71 |
+
self.cycle_lengths = cycle_lengths
|
72 |
+
self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
|
73 |
+
self.last_f = 0.0
|
74 |
+
self.verbosity_interval = verbosity_interval
|
75 |
+
|
76 |
+
def find_in_interval(self, n):
|
77 |
+
interval = 0
|
78 |
+
for cl in self.cum_cycles[1:]:
|
79 |
+
if n <= cl:
|
80 |
+
return interval
|
81 |
+
interval += 1
|
82 |
+
|
83 |
+
def schedule(self, n, **kwargs):
|
84 |
+
cycle = self.find_in_interval(n)
|
85 |
+
n = n - self.cum_cycles[cycle]
|
86 |
+
if self.verbosity_interval > 0:
|
87 |
+
if n % self.verbosity_interval == 0:
|
88 |
+
print(
|
89 |
+
f"current step: {n}, recent lr-multiplier: {self.last_f}, "
|
90 |
+
f"current cycle {cycle}"
|
91 |
+
)
|
92 |
+
if n < self.lr_warm_up_steps[cycle]:
|
93 |
+
f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[
|
94 |
+
cycle
|
95 |
+
] * n + self.f_start[cycle]
|
96 |
+
self.last_f = f
|
97 |
+
return f
|
98 |
+
else:
|
99 |
+
t = (n - self.lr_warm_up_steps[cycle]) / (
|
100 |
+
self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]
|
101 |
+
)
|
102 |
+
t = min(t, 1.0)
|
103 |
+
f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (
|
104 |
+
1 + np.cos(t * np.pi)
|
105 |
+
)
|
106 |
+
self.last_f = f
|
107 |
+
return f
|
108 |
+
|
109 |
+
def __call__(self, n, **kwargs):
|
110 |
+
return self.schedule(n, **kwargs)
|
111 |
+
|
112 |
+
|
113 |
+
class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
|
114 |
+
def schedule(self, n, **kwargs):
|
115 |
+
cycle = self.find_in_interval(n)
|
116 |
+
n = n - self.cum_cycles[cycle]
|
117 |
+
if self.verbosity_interval > 0:
|
118 |
+
if n % self.verbosity_interval == 0:
|
119 |
+
print(
|
120 |
+
f"current step: {n}, recent lr-multiplier: {self.last_f}, "
|
121 |
+
f"current cycle {cycle}"
|
122 |
+
)
|
123 |
+
|
124 |
+
if n < self.lr_warm_up_steps[cycle]:
|
125 |
+
f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[
|
126 |
+
cycle
|
127 |
+
] * n + self.f_start[cycle]
|
128 |
+
self.last_f = f
|
129 |
+
return f
|
130 |
+
else:
|
131 |
+
f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (
|
132 |
+
self.cycle_lengths[cycle] - n
|
133 |
+
) / (self.cycle_lengths[cycle])
|
134 |
+
self.last_f = f
|
135 |
+
return f
|
sgm/models/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
from .autoencoder import AutoencodingEngine
|
2 |
+
from .diffusion import DiffusionEngine
|
sgm/models/autoencoder.py
ADDED
@@ -0,0 +1,615 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import math
|
3 |
+
import re
|
4 |
+
from abc import abstractmethod
|
5 |
+
from contextlib import contextmanager
|
6 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
7 |
+
|
8 |
+
import pytorch_lightning as pl
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
from einops import rearrange
|
12 |
+
from packaging import version
|
13 |
+
|
14 |
+
from ..modules.autoencoding.regularizers import AbstractRegularizer
|
15 |
+
from ..modules.ema import LitEma
|
16 |
+
from ..util import (default, get_nested_attribute, get_obj_from_str,
|
17 |
+
instantiate_from_config)
|
18 |
+
|
19 |
+
logpy = logging.getLogger(__name__)
|
20 |
+
|
21 |
+
|
22 |
+
class AbstractAutoencoder(pl.LightningModule):
|
23 |
+
"""
|
24 |
+
This is the base class for all autoencoders, including image autoencoders, image autoencoders with discriminators,
|
25 |
+
unCLIP models, etc. Hence, it is fairly general, and specific features
|
26 |
+
(e.g. discriminator training, encoding, decoding) must be implemented in subclasses.
|
27 |
+
"""
|
28 |
+
|
29 |
+
def __init__(
|
30 |
+
self,
|
31 |
+
ema_decay: Union[None, float] = None,
|
32 |
+
monitor: Union[None, str] = None,
|
33 |
+
input_key: str = "jpg",
|
34 |
+
):
|
35 |
+
super().__init__()
|
36 |
+
|
37 |
+
self.input_key = input_key
|
38 |
+
self.use_ema = ema_decay is not None
|
39 |
+
if monitor is not None:
|
40 |
+
self.monitor = monitor
|
41 |
+
|
42 |
+
if self.use_ema:
|
43 |
+
self.model_ema = LitEma(self, decay=ema_decay)
|
44 |
+
logpy.info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
|
45 |
+
|
46 |
+
if version.parse(torch.__version__) >= version.parse("2.0.0"):
|
47 |
+
self.automatic_optimization = False
|
48 |
+
|
49 |
+
def apply_ckpt(self, ckpt: Union[None, str, dict]):
|
50 |
+
if ckpt is None:
|
51 |
+
return
|
52 |
+
if isinstance(ckpt, str):
|
53 |
+
ckpt = {
|
54 |
+
"target": "sgm.modules.checkpoint.CheckpointEngine",
|
55 |
+
"params": {"ckpt_path": ckpt},
|
56 |
+
}
|
57 |
+
engine = instantiate_from_config(ckpt)
|
58 |
+
engine(self)
|
59 |
+
|
60 |
+
@abstractmethod
|
61 |
+
def get_input(self, batch) -> Any:
|
62 |
+
raise NotImplementedError()
|
63 |
+
|
64 |
+
def on_train_batch_end(self, *args, **kwargs):
|
65 |
+
# for EMA computation
|
66 |
+
if self.use_ema:
|
67 |
+
self.model_ema(self)
|
68 |
+
|
69 |
+
@contextmanager
|
70 |
+
def ema_scope(self, context=None):
|
71 |
+
if self.use_ema:
|
72 |
+
self.model_ema.store(self.parameters())
|
73 |
+
self.model_ema.copy_to(self)
|
74 |
+
if context is not None:
|
75 |
+
logpy.info(f"{context}: Switched to EMA weights")
|
76 |
+
try:
|
77 |
+
yield None
|
78 |
+
finally:
|
79 |
+
if self.use_ema:
|
80 |
+
self.model_ema.restore(self.parameters())
|
81 |
+
if context is not None:
|
82 |
+
logpy.info(f"{context}: Restored training weights")
|
83 |
+
|
84 |
+
@abstractmethod
|
85 |
+
def encode(self, *args, **kwargs) -> torch.Tensor:
|
86 |
+
raise NotImplementedError("encode()-method of abstract base class called")
|
87 |
+
|
88 |
+
@abstractmethod
|
89 |
+
def decode(self, *args, **kwargs) -> torch.Tensor:
|
90 |
+
raise NotImplementedError("decode()-method of abstract base class called")
|
91 |
+
|
92 |
+
def instantiate_optimizer_from_config(self, params, lr, cfg):
|
93 |
+
logpy.info(f"loading >>> {cfg['target']} <<< optimizer from config")
|
94 |
+
return get_obj_from_str(cfg["target"])(
|
95 |
+
params, lr=lr, **cfg.get("params", dict())
|
96 |
+
)
|
97 |
+
|
98 |
+
def configure_optimizers(self) -> Any:
|
99 |
+
raise NotImplementedError()
|
100 |
+
|
101 |
+
|
102 |
+
class AutoencodingEngine(AbstractAutoencoder):
|
103 |
+
"""
|
104 |
+
Base class for all image autoencoders that we train, like VQGAN or AutoencoderKL
|
105 |
+
(we also restore them explicitly as special cases for legacy reasons).
|
106 |
+
Regularizations such as KL or VQ are moved to the regularizer class.
|
107 |
+
"""
|
108 |
+
|
109 |
+
def __init__(
|
110 |
+
self,
|
111 |
+
*args,
|
112 |
+
encoder_config: Dict,
|
113 |
+
decoder_config: Dict,
|
114 |
+
loss_config: Dict,
|
115 |
+
regularizer_config: Dict,
|
116 |
+
optimizer_config: Union[Dict, None] = None,
|
117 |
+
lr_g_factor: float = 1.0,
|
118 |
+
trainable_ae_params: Optional[List[List[str]]] = None,
|
119 |
+
ae_optimizer_args: Optional[List[dict]] = None,
|
120 |
+
trainable_disc_params: Optional[List[List[str]]] = None,
|
121 |
+
disc_optimizer_args: Optional[List[dict]] = None,
|
122 |
+
disc_start_iter: int = 0,
|
123 |
+
diff_boost_factor: float = 3.0,
|
124 |
+
ckpt_engine: Union[None, str, dict] = None,
|
125 |
+
ckpt_path: Optional[str] = None,
|
126 |
+
additional_decode_keys: Optional[List[str]] = None,
|
127 |
+
**kwargs,
|
128 |
+
):
|
129 |
+
super().__init__(*args, **kwargs)
|
130 |
+
self.automatic_optimization = False # pytorch lightning
|
131 |
+
|
132 |
+
self.encoder: torch.nn.Module = instantiate_from_config(encoder_config)
|
133 |
+
self.decoder: torch.nn.Module = instantiate_from_config(decoder_config)
|
134 |
+
self.loss: torch.nn.Module = instantiate_from_config(loss_config)
|
135 |
+
self.regularization: AbstractRegularizer = instantiate_from_config(
|
136 |
+
regularizer_config
|
137 |
+
)
|
138 |
+
self.optimizer_config = default(
|
139 |
+
optimizer_config, {"target": "torch.optim.Adam"}
|
140 |
+
)
|
141 |
+
self.diff_boost_factor = diff_boost_factor
|
142 |
+
self.disc_start_iter = disc_start_iter
|
143 |
+
self.lr_g_factor = lr_g_factor
|
144 |
+
self.trainable_ae_params = trainable_ae_params
|
145 |
+
if self.trainable_ae_params is not None:
|
146 |
+
self.ae_optimizer_args = default(
|
147 |
+
ae_optimizer_args,
|
148 |
+
[{} for _ in range(len(self.trainable_ae_params))],
|
149 |
+
)
|
150 |
+
assert len(self.ae_optimizer_args) == len(self.trainable_ae_params)
|
151 |
+
else:
|
152 |
+
self.ae_optimizer_args = [{}] # makes type consitent
|
153 |
+
|
154 |
+
self.trainable_disc_params = trainable_disc_params
|
155 |
+
if self.trainable_disc_params is not None:
|
156 |
+
self.disc_optimizer_args = default(
|
157 |
+
disc_optimizer_args,
|
158 |
+
[{} for _ in range(len(self.trainable_disc_params))],
|
159 |
+
)
|
160 |
+
assert len(self.disc_optimizer_args) == len(self.trainable_disc_params)
|
161 |
+
else:
|
162 |
+
self.disc_optimizer_args = [{}] # makes type consitent
|
163 |
+
|
164 |
+
if ckpt_path is not None:
|
165 |
+
assert ckpt_engine is None, "Can't set ckpt_engine and ckpt_path"
|
166 |
+
logpy.warn("Checkpoint path is deprecated, use `checkpoint_egnine` instead")
|
167 |
+
self.apply_ckpt(default(ckpt_path, ckpt_engine))
|
168 |
+
self.additional_decode_keys = set(default(additional_decode_keys, []))
|
169 |
+
|
170 |
+
def get_input(self, batch: Dict) -> torch.Tensor:
|
171 |
+
# assuming unified data format, dataloader returns a dict.
|
172 |
+
# image tensors should be scaled to -1 ... 1 and in channels-first
|
173 |
+
# format (e.g., bchw instead if bhwc)
|
174 |
+
return batch[self.input_key]
|
175 |
+
|
176 |
+
def get_autoencoder_params(self) -> list:
|
177 |
+
params = []
|
178 |
+
if hasattr(self.loss, "get_trainable_autoencoder_parameters"):
|
179 |
+
params += list(self.loss.get_trainable_autoencoder_parameters())
|
180 |
+
if hasattr(self.regularization, "get_trainable_parameters"):
|
181 |
+
params += list(self.regularization.get_trainable_parameters())
|
182 |
+
params = params + list(self.encoder.parameters())
|
183 |
+
params = params + list(self.decoder.parameters())
|
184 |
+
return params
|
185 |
+
|
186 |
+
def get_discriminator_params(self) -> list:
|
187 |
+
if hasattr(self.loss, "get_trainable_parameters"):
|
188 |
+
params = list(self.loss.get_trainable_parameters()) # e.g., discriminator
|
189 |
+
else:
|
190 |
+
params = []
|
191 |
+
return params
|
192 |
+
|
193 |
+
def get_last_layer(self):
|
194 |
+
return self.decoder.get_last_layer()
|
195 |
+
|
196 |
+
def encode(
|
197 |
+
self,
|
198 |
+
x: torch.Tensor,
|
199 |
+
return_reg_log: bool = False,
|
200 |
+
unregularized: bool = False,
|
201 |
+
) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
|
202 |
+
z = self.encoder(x)
|
203 |
+
if unregularized:
|
204 |
+
return z, dict()
|
205 |
+
z, reg_log = self.regularization(z)
|
206 |
+
if return_reg_log:
|
207 |
+
return z, reg_log
|
208 |
+
return z
|
209 |
+
|
210 |
+
def decode(self, z: torch.Tensor, **kwargs) -> torch.Tensor:
|
211 |
+
x = self.decoder(z, **kwargs)
|
212 |
+
return x
|
213 |
+
|
214 |
+
def forward(
|
215 |
+
self, x: torch.Tensor, **additional_decode_kwargs
|
216 |
+
) -> Tuple[torch.Tensor, torch.Tensor, dict]:
|
217 |
+
z, reg_log = self.encode(x, return_reg_log=True)
|
218 |
+
dec = self.decode(z, **additional_decode_kwargs)
|
219 |
+
return z, dec, reg_log
|
220 |
+
|
221 |
+
def inner_training_step(
|
222 |
+
self, batch: dict, batch_idx: int, optimizer_idx: int = 0
|
223 |
+
) -> torch.Tensor:
|
224 |
+
x = self.get_input(batch)
|
225 |
+
additional_decode_kwargs = {
|
226 |
+
key: batch[key] for key in self.additional_decode_keys.intersection(batch)
|
227 |
+
}
|
228 |
+
z, xrec, regularization_log = self(x, **additional_decode_kwargs)
|
229 |
+
if hasattr(self.loss, "forward_keys"):
|
230 |
+
extra_info = {
|
231 |
+
"z": z,
|
232 |
+
"optimizer_idx": optimizer_idx,
|
233 |
+
"global_step": self.global_step,
|
234 |
+
"last_layer": self.get_last_layer(),
|
235 |
+
"split": "train",
|
236 |
+
"regularization_log": regularization_log,
|
237 |
+
"autoencoder": self,
|
238 |
+
}
|
239 |
+
extra_info = {k: extra_info[k] for k in self.loss.forward_keys}
|
240 |
+
else:
|
241 |
+
extra_info = dict()
|
242 |
+
|
243 |
+
if optimizer_idx == 0:
|
244 |
+
# autoencode
|
245 |
+
out_loss = self.loss(x, xrec, **extra_info)
|
246 |
+
if isinstance(out_loss, tuple):
|
247 |
+
aeloss, log_dict_ae = out_loss
|
248 |
+
else:
|
249 |
+
# simple loss function
|
250 |
+
aeloss = out_loss
|
251 |
+
log_dict_ae = {"train/loss/rec": aeloss.detach()}
|
252 |
+
|
253 |
+
self.log_dict(
|
254 |
+
log_dict_ae,
|
255 |
+
prog_bar=False,
|
256 |
+
logger=True,
|
257 |
+
on_step=True,
|
258 |
+
on_epoch=True,
|
259 |
+
sync_dist=False,
|
260 |
+
)
|
261 |
+
self.log(
|
262 |
+
"loss",
|
263 |
+
aeloss.mean().detach(),
|
264 |
+
prog_bar=True,
|
265 |
+
logger=False,
|
266 |
+
on_epoch=False,
|
267 |
+
on_step=True,
|
268 |
+
)
|
269 |
+
return aeloss
|
270 |
+
elif optimizer_idx == 1:
|
271 |
+
# discriminator
|
272 |
+
discloss, log_dict_disc = self.loss(x, xrec, **extra_info)
|
273 |
+
# -> discriminator always needs to return a tuple
|
274 |
+
self.log_dict(
|
275 |
+
log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True
|
276 |
+
)
|
277 |
+
return discloss
|
278 |
+
else:
|
279 |
+
raise NotImplementedError(f"Unknown optimizer {optimizer_idx}")
|
280 |
+
|
281 |
+
def training_step(self, batch: dict, batch_idx: int):
|
282 |
+
opts = self.optimizers()
|
283 |
+
if not isinstance(opts, list):
|
284 |
+
# Non-adversarial case
|
285 |
+
opts = [opts]
|
286 |
+
optimizer_idx = batch_idx % len(opts)
|
287 |
+
if self.global_step < self.disc_start_iter:
|
288 |
+
optimizer_idx = 0
|
289 |
+
opt = opts[optimizer_idx]
|
290 |
+
opt.zero_grad()
|
291 |
+
with opt.toggle_model():
|
292 |
+
loss = self.inner_training_step(
|
293 |
+
batch, batch_idx, optimizer_idx=optimizer_idx
|
294 |
+
)
|
295 |
+
self.manual_backward(loss)
|
296 |
+
opt.step()
|
297 |
+
|
298 |
+
def validation_step(self, batch: dict, batch_idx: int) -> Dict:
|
299 |
+
log_dict = self._validation_step(batch, batch_idx)
|
300 |
+
with self.ema_scope():
|
301 |
+
log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema")
|
302 |
+
log_dict.update(log_dict_ema)
|
303 |
+
return log_dict
|
304 |
+
|
305 |
+
def _validation_step(self, batch: dict, batch_idx: int, postfix: str = "") -> Dict:
|
306 |
+
x = self.get_input(batch)
|
307 |
+
|
308 |
+
z, xrec, regularization_log = self(x)
|
309 |
+
if hasattr(self.loss, "forward_keys"):
|
310 |
+
extra_info = {
|
311 |
+
"z": z,
|
312 |
+
"optimizer_idx": 0,
|
313 |
+
"global_step": self.global_step,
|
314 |
+
"last_layer": self.get_last_layer(),
|
315 |
+
"split": "val" + postfix,
|
316 |
+
"regularization_log": regularization_log,
|
317 |
+
"autoencoder": self,
|
318 |
+
}
|
319 |
+
extra_info = {k: extra_info[k] for k in self.loss.forward_keys}
|
320 |
+
else:
|
321 |
+
extra_info = dict()
|
322 |
+
out_loss = self.loss(x, xrec, **extra_info)
|
323 |
+
if isinstance(out_loss, tuple):
|
324 |
+
aeloss, log_dict_ae = out_loss
|
325 |
+
else:
|
326 |
+
# simple loss function
|
327 |
+
aeloss = out_loss
|
328 |
+
log_dict_ae = {f"val{postfix}/loss/rec": aeloss.detach()}
|
329 |
+
full_log_dict = log_dict_ae
|
330 |
+
|
331 |
+
if "optimizer_idx" in extra_info:
|
332 |
+
extra_info["optimizer_idx"] = 1
|
333 |
+
discloss, log_dict_disc = self.loss(x, xrec, **extra_info)
|
334 |
+
full_log_dict.update(log_dict_disc)
|
335 |
+
self.log(
|
336 |
+
f"val{postfix}/loss/rec",
|
337 |
+
log_dict_ae[f"val{postfix}/loss/rec"],
|
338 |
+
sync_dist=True,
|
339 |
+
)
|
340 |
+
self.log_dict(full_log_dict, sync_dist=True)
|
341 |
+
return full_log_dict
|
342 |
+
|
343 |
+
def get_param_groups(
|
344 |
+
self, parameter_names: List[List[str]], optimizer_args: List[dict]
|
345 |
+
) -> Tuple[List[Dict[str, Any]], int]:
|
346 |
+
groups = []
|
347 |
+
num_params = 0
|
348 |
+
for names, args in zip(parameter_names, optimizer_args):
|
349 |
+
params = []
|
350 |
+
for pattern_ in names:
|
351 |
+
pattern_params = []
|
352 |
+
pattern = re.compile(pattern_)
|
353 |
+
for p_name, param in self.named_parameters():
|
354 |
+
if re.match(pattern, p_name):
|
355 |
+
pattern_params.append(param)
|
356 |
+
num_params += param.numel()
|
357 |
+
if len(pattern_params) == 0:
|
358 |
+
logpy.warn(f"Did not find parameters for pattern {pattern_}")
|
359 |
+
params.extend(pattern_params)
|
360 |
+
groups.append({"params": params, **args})
|
361 |
+
return groups, num_params
|
362 |
+
|
363 |
+
def configure_optimizers(self) -> List[torch.optim.Optimizer]:
|
364 |
+
if self.trainable_ae_params is None:
|
365 |
+
ae_params = self.get_autoencoder_params()
|
366 |
+
else:
|
367 |
+
ae_params, num_ae_params = self.get_param_groups(
|
368 |
+
self.trainable_ae_params, self.ae_optimizer_args
|
369 |
+
)
|
370 |
+
logpy.info(f"Number of trainable autoencoder parameters: {num_ae_params:,}")
|
371 |
+
if self.trainable_disc_params is None:
|
372 |
+
disc_params = self.get_discriminator_params()
|
373 |
+
else:
|
374 |
+
disc_params, num_disc_params = self.get_param_groups(
|
375 |
+
self.trainable_disc_params, self.disc_optimizer_args
|
376 |
+
)
|
377 |
+
logpy.info(
|
378 |
+
f"Number of trainable discriminator parameters: {num_disc_params:,}"
|
379 |
+
)
|
380 |
+
opt_ae = self.instantiate_optimizer_from_config(
|
381 |
+
ae_params,
|
382 |
+
default(self.lr_g_factor, 1.0) * self.learning_rate,
|
383 |
+
self.optimizer_config,
|
384 |
+
)
|
385 |
+
opts = [opt_ae]
|
386 |
+
if len(disc_params) > 0:
|
387 |
+
opt_disc = self.instantiate_optimizer_from_config(
|
388 |
+
disc_params, self.learning_rate, self.optimizer_config
|
389 |
+
)
|
390 |
+
opts.append(opt_disc)
|
391 |
+
|
392 |
+
return opts
|
393 |
+
|
394 |
+
@torch.no_grad()
|
395 |
+
def log_images(
|
396 |
+
self, batch: dict, additional_log_kwargs: Optional[Dict] = None, **kwargs
|
397 |
+
) -> dict:
|
398 |
+
log = dict()
|
399 |
+
additional_decode_kwargs = {}
|
400 |
+
x = self.get_input(batch)
|
401 |
+
additional_decode_kwargs.update(
|
402 |
+
{key: batch[key] for key in self.additional_decode_keys.intersection(batch)}
|
403 |
+
)
|
404 |
+
|
405 |
+
_, xrec, _ = self(x, **additional_decode_kwargs)
|
406 |
+
log["inputs"] = x
|
407 |
+
log["reconstructions"] = xrec
|
408 |
+
diff = 0.5 * torch.abs(torch.clamp(xrec, -1.0, 1.0) - x)
|
409 |
+
diff.clamp_(0, 1.0)
|
410 |
+
log["diff"] = 2.0 * diff - 1.0
|
411 |
+
# diff_boost shows location of small errors, by boosting their
|
412 |
+
# brightness.
|
413 |
+
log["diff_boost"] = (
|
414 |
+
2.0 * torch.clamp(self.diff_boost_factor * diff, 0.0, 1.0) - 1
|
415 |
+
)
|
416 |
+
if hasattr(self.loss, "log_images"):
|
417 |
+
log.update(self.loss.log_images(x, xrec))
|
418 |
+
with self.ema_scope():
|
419 |
+
_, xrec_ema, _ = self(x, **additional_decode_kwargs)
|
420 |
+
log["reconstructions_ema"] = xrec_ema
|
421 |
+
diff_ema = 0.5 * torch.abs(torch.clamp(xrec_ema, -1.0, 1.0) - x)
|
422 |
+
diff_ema.clamp_(0, 1.0)
|
423 |
+
log["diff_ema"] = 2.0 * diff_ema - 1.0
|
424 |
+
log["diff_boost_ema"] = (
|
425 |
+
2.0 * torch.clamp(self.diff_boost_factor * diff_ema, 0.0, 1.0) - 1
|
426 |
+
)
|
427 |
+
if additional_log_kwargs:
|
428 |
+
additional_decode_kwargs.update(additional_log_kwargs)
|
429 |
+
_, xrec_add, _ = self(x, **additional_decode_kwargs)
|
430 |
+
log_str = "reconstructions-" + "-".join(
|
431 |
+
[f"{key}={additional_log_kwargs[key]}" for key in additional_log_kwargs]
|
432 |
+
)
|
433 |
+
log[log_str] = xrec_add
|
434 |
+
return log
|
435 |
+
|
436 |
+
|
437 |
+
class AutoencodingEngineLegacy(AutoencodingEngine):
|
438 |
+
def __init__(self, embed_dim: int, **kwargs):
|
439 |
+
self.max_batch_size = kwargs.pop("max_batch_size", None)
|
440 |
+
ddconfig = kwargs.pop("ddconfig")
|
441 |
+
ckpt_path = kwargs.pop("ckpt_path", None)
|
442 |
+
ckpt_engine = kwargs.pop("ckpt_engine", None)
|
443 |
+
super().__init__(
|
444 |
+
encoder_config={
|
445 |
+
"target": "sgm.modules.diffusionmodules.model.Encoder",
|
446 |
+
"params": ddconfig,
|
447 |
+
},
|
448 |
+
decoder_config={
|
449 |
+
"target": "sgm.modules.diffusionmodules.model.Decoder",
|
450 |
+
"params": ddconfig,
|
451 |
+
},
|
452 |
+
**kwargs,
|
453 |
+
)
|
454 |
+
self.quant_conv = torch.nn.Conv2d(
|
455 |
+
(1 + ddconfig["double_z"]) * ddconfig["z_channels"],
|
456 |
+
(1 + ddconfig["double_z"]) * embed_dim,
|
457 |
+
1,
|
458 |
+
)
|
459 |
+
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
|
460 |
+
self.embed_dim = embed_dim
|
461 |
+
|
462 |
+
self.apply_ckpt(default(ckpt_path, ckpt_engine))
|
463 |
+
|
464 |
+
def get_autoencoder_params(self) -> list:
|
465 |
+
params = super().get_autoencoder_params()
|
466 |
+
return params
|
467 |
+
|
468 |
+
def encode(
|
469 |
+
self, x: torch.Tensor, return_reg_log: bool = False
|
470 |
+
) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
|
471 |
+
if self.max_batch_size is None:
|
472 |
+
z = self.encoder(x)
|
473 |
+
z = self.quant_conv(z)
|
474 |
+
else:
|
475 |
+
N = x.shape[0]
|
476 |
+
bs = self.max_batch_size
|
477 |
+
n_batches = int(math.ceil(N / bs))
|
478 |
+
z = list()
|
479 |
+
for i_batch in range(n_batches):
|
480 |
+
z_batch = self.encoder(x[i_batch * bs : (i_batch + 1) * bs])
|
481 |
+
z_batch = self.quant_conv(z_batch)
|
482 |
+
z.append(z_batch)
|
483 |
+
z = torch.cat(z, 0)
|
484 |
+
|
485 |
+
z, reg_log = self.regularization(z)
|
486 |
+
if return_reg_log:
|
487 |
+
return z, reg_log
|
488 |
+
return z
|
489 |
+
|
490 |
+
def decode(self, z: torch.Tensor, **decoder_kwargs) -> torch.Tensor:
|
491 |
+
if self.max_batch_size is None:
|
492 |
+
dec = self.post_quant_conv(z)
|
493 |
+
dec = self.decoder(dec, **decoder_kwargs)
|
494 |
+
else:
|
495 |
+
N = z.shape[0]
|
496 |
+
bs = self.max_batch_size
|
497 |
+
n_batches = int(math.ceil(N / bs))
|
498 |
+
dec = list()
|
499 |
+
for i_batch in range(n_batches):
|
500 |
+
dec_batch = self.post_quant_conv(z[i_batch * bs : (i_batch + 1) * bs])
|
501 |
+
dec_batch = self.decoder(dec_batch, **decoder_kwargs)
|
502 |
+
dec.append(dec_batch)
|
503 |
+
dec = torch.cat(dec, 0)
|
504 |
+
|
505 |
+
return dec
|
506 |
+
|
507 |
+
|
508 |
+
class AutoencoderKL(AutoencodingEngineLegacy):
|
509 |
+
def __init__(self, **kwargs):
|
510 |
+
if "lossconfig" in kwargs:
|
511 |
+
kwargs["loss_config"] = kwargs.pop("lossconfig")
|
512 |
+
super().__init__(
|
513 |
+
regularizer_config={
|
514 |
+
"target": (
|
515 |
+
"sgm.modules.autoencoding.regularizers"
|
516 |
+
".DiagonalGaussianRegularizer"
|
517 |
+
)
|
518 |
+
},
|
519 |
+
**kwargs,
|
520 |
+
)
|
521 |
+
|
522 |
+
|
523 |
+
class AutoencoderLegacyVQ(AutoencodingEngineLegacy):
|
524 |
+
def __init__(
|
525 |
+
self,
|
526 |
+
embed_dim: int,
|
527 |
+
n_embed: int,
|
528 |
+
sane_index_shape: bool = False,
|
529 |
+
**kwargs,
|
530 |
+
):
|
531 |
+
if "lossconfig" in kwargs:
|
532 |
+
logpy.warn(f"Parameter `lossconfig` is deprecated, use `loss_config`.")
|
533 |
+
kwargs["loss_config"] = kwargs.pop("lossconfig")
|
534 |
+
super().__init__(
|
535 |
+
regularizer_config={
|
536 |
+
"target": (
|
537 |
+
"sgm.modules.autoencoding.regularizers.quantize" ".VectorQuantizer"
|
538 |
+
),
|
539 |
+
"params": {
|
540 |
+
"n_e": n_embed,
|
541 |
+
"e_dim": embed_dim,
|
542 |
+
"sane_index_shape": sane_index_shape,
|
543 |
+
},
|
544 |
+
},
|
545 |
+
**kwargs,
|
546 |
+
)
|
547 |
+
|
548 |
+
|
549 |
+
class IdentityFirstStage(AbstractAutoencoder):
|
550 |
+
def __init__(self, *args, **kwargs):
|
551 |
+
super().__init__(*args, **kwargs)
|
552 |
+
|
553 |
+
def get_input(self, x: Any) -> Any:
|
554 |
+
return x
|
555 |
+
|
556 |
+
def encode(self, x: Any, *args, **kwargs) -> Any:
|
557 |
+
return x
|
558 |
+
|
559 |
+
def decode(self, x: Any, *args, **kwargs) -> Any:
|
560 |
+
return x
|
561 |
+
|
562 |
+
|
563 |
+
class AEIntegerWrapper(nn.Module):
|
564 |
+
def __init__(
|
565 |
+
self,
|
566 |
+
model: nn.Module,
|
567 |
+
shape: Union[None, Tuple[int, int], List[int]] = (16, 16),
|
568 |
+
regularization_key: str = "regularization",
|
569 |
+
encoder_kwargs: Optional[Dict[str, Any]] = None,
|
570 |
+
):
|
571 |
+
super().__init__()
|
572 |
+
self.model = model
|
573 |
+
assert hasattr(model, "encode") and hasattr(
|
574 |
+
model, "decode"
|
575 |
+
), "Need AE interface"
|
576 |
+
self.regularization = get_nested_attribute(model, regularization_key)
|
577 |
+
self.shape = shape
|
578 |
+
self.encoder_kwargs = default(encoder_kwargs, {"return_reg_log": True})
|
579 |
+
|
580 |
+
def encode(self, x) -> torch.Tensor:
|
581 |
+
assert (
|
582 |
+
not self.training
|
583 |
+
), f"{self.__class__.__name__} only supports inference currently"
|
584 |
+
_, log = self.model.encode(x, **self.encoder_kwargs)
|
585 |
+
assert isinstance(log, dict)
|
586 |
+
inds = log["min_encoding_indices"]
|
587 |
+
return rearrange(inds, "b ... -> b (...)")
|
588 |
+
|
589 |
+
def decode(
|
590 |
+
self, inds: torch.Tensor, shape: Union[None, tuple, list] = None
|
591 |
+
) -> torch.Tensor:
|
592 |
+
# expect inds shape (b, s) with s = h*w
|
593 |
+
shape = default(shape, self.shape) # Optional[(h, w)]
|
594 |
+
if shape is not None:
|
595 |
+
assert len(shape) == 2, f"Unhandeled shape {shape}"
|
596 |
+
inds = rearrange(inds, "b (h w) -> b h w", h=shape[0], w=shape[1])
|
597 |
+
h = self.regularization.get_codebook_entry(inds) # (b, h, w, c)
|
598 |
+
h = rearrange(h, "b h w c -> b c h w")
|
599 |
+
return self.model.decode(h)
|
600 |
+
|
601 |
+
|
602 |
+
class AutoencoderKLModeOnly(AutoencodingEngineLegacy):
|
603 |
+
def __init__(self, **kwargs):
|
604 |
+
if "lossconfig" in kwargs:
|
605 |
+
kwargs["loss_config"] = kwargs.pop("lossconfig")
|
606 |
+
super().__init__(
|
607 |
+
regularizer_config={
|
608 |
+
"target": (
|
609 |
+
"sgm.modules.autoencoding.regularizers"
|
610 |
+
".DiagonalGaussianRegularizer"
|
611 |
+
),
|
612 |
+
"params": {"sample": False},
|
613 |
+
},
|
614 |
+
**kwargs,
|
615 |
+
)
|
sgm/models/diffusion.py
ADDED
@@ -0,0 +1,358 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
from contextlib import contextmanager
|
3 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
4 |
+
|
5 |
+
import pytorch_lightning as pl
|
6 |
+
import torch
|
7 |
+
from omegaconf import ListConfig, OmegaConf
|
8 |
+
from safetensors.torch import load_file as load_safetensors
|
9 |
+
from torch.optim.lr_scheduler import LambdaLR
|
10 |
+
from einops import rearrange
|
11 |
+
|
12 |
+
from ..modules import UNCONDITIONAL_CONFIG
|
13 |
+
from ..modules.autoencoding.temporal_ae import VideoDecoder
|
14 |
+
from ..modules.diffusionmodules.wrappers import OPENAIUNETWRAPPER
|
15 |
+
from ..modules.ema import LitEma
|
16 |
+
from ..util import (
|
17 |
+
default,
|
18 |
+
disabled_train,
|
19 |
+
get_obj_from_str,
|
20 |
+
instantiate_from_config,
|
21 |
+
log_txt_as_img,
|
22 |
+
)
|
23 |
+
|
24 |
+
|
25 |
+
class DiffusionEngine(pl.LightningModule):
|
26 |
+
def __init__(
|
27 |
+
self,
|
28 |
+
network_config,
|
29 |
+
denoiser_config,
|
30 |
+
first_stage_config,
|
31 |
+
conditioner_config: Union[None, Dict, ListConfig, OmegaConf] = None,
|
32 |
+
sampler_config: Union[None, Dict, ListConfig, OmegaConf] = None,
|
33 |
+
optimizer_config: Union[None, Dict, ListConfig, OmegaConf] = None,
|
34 |
+
scheduler_config: Union[None, Dict, ListConfig, OmegaConf] = None,
|
35 |
+
loss_fn_config: Union[None, Dict, ListConfig, OmegaConf] = None,
|
36 |
+
network_wrapper: Union[None, str] = None,
|
37 |
+
ckpt_path: Union[None, str] = None,
|
38 |
+
use_ema: bool = False,
|
39 |
+
ema_decay_rate: float = 0.9999,
|
40 |
+
scale_factor: float = 1.0,
|
41 |
+
disable_first_stage_autocast=False,
|
42 |
+
input_key: str = "jpg",
|
43 |
+
log_keys: Union[List, None] = None,
|
44 |
+
no_cond_log: bool = False,
|
45 |
+
compile_model: bool = False,
|
46 |
+
en_and_decode_n_samples_a_time: Optional[int] = None,
|
47 |
+
):
|
48 |
+
super().__init__()
|
49 |
+
self.log_keys = log_keys
|
50 |
+
self.input_key = input_key
|
51 |
+
self.optimizer_config = default(
|
52 |
+
optimizer_config, {"target": "torch.optim.AdamW"}
|
53 |
+
)
|
54 |
+
model = instantiate_from_config(network_config)
|
55 |
+
self.model = get_obj_from_str(default(network_wrapper, OPENAIUNETWRAPPER))(
|
56 |
+
model, compile_model=compile_model
|
57 |
+
)
|
58 |
+
|
59 |
+
self.denoiser = instantiate_from_config(denoiser_config)
|
60 |
+
self.sampler = (
|
61 |
+
instantiate_from_config(sampler_config)
|
62 |
+
if sampler_config is not None
|
63 |
+
else None
|
64 |
+
)
|
65 |
+
self.conditioner = instantiate_from_config(
|
66 |
+
default(conditioner_config, UNCONDITIONAL_CONFIG)
|
67 |
+
)
|
68 |
+
self.scheduler_config = scheduler_config
|
69 |
+
self._init_first_stage(first_stage_config)
|
70 |
+
|
71 |
+
self.loss_fn = (
|
72 |
+
instantiate_from_config(loss_fn_config)
|
73 |
+
if loss_fn_config is not None
|
74 |
+
else None
|
75 |
+
)
|
76 |
+
|
77 |
+
self.use_ema = use_ema
|
78 |
+
if self.use_ema:
|
79 |
+
self.model_ema = LitEma(self.model, decay=ema_decay_rate)
|
80 |
+
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
|
81 |
+
|
82 |
+
self.scale_factor = scale_factor
|
83 |
+
self.disable_first_stage_autocast = disable_first_stage_autocast
|
84 |
+
self.no_cond_log = no_cond_log
|
85 |
+
|
86 |
+
if ckpt_path is not None:
|
87 |
+
self.init_from_ckpt(ckpt_path)
|
88 |
+
|
89 |
+
self.en_and_decode_n_samples_a_time = en_and_decode_n_samples_a_time
|
90 |
+
|
91 |
+
def init_from_ckpt(
|
92 |
+
self,
|
93 |
+
path: str,
|
94 |
+
) -> None:
|
95 |
+
if path.endswith("ckpt"):
|
96 |
+
sd = torch.load(path, map_location="cpu")["state_dict"]
|
97 |
+
elif path.endswith("safetensors"):
|
98 |
+
sd = load_safetensors(path)
|
99 |
+
else:
|
100 |
+
raise NotImplementedError
|
101 |
+
|
102 |
+
missing, unexpected = self.load_state_dict(sd, strict=False)
|
103 |
+
print(
|
104 |
+
f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
|
105 |
+
)
|
106 |
+
if len(missing) > 0:
|
107 |
+
print(f"Missing Keys: {missing}")
|
108 |
+
if len(unexpected) > 0:
|
109 |
+
print(f"Unexpected Keys: {unexpected}")
|
110 |
+
|
111 |
+
def _init_first_stage(self, config):
|
112 |
+
model = instantiate_from_config(config).eval()
|
113 |
+
model.train = disabled_train
|
114 |
+
for param in model.parameters():
|
115 |
+
param.requires_grad = False
|
116 |
+
self.first_stage_model = model
|
117 |
+
|
118 |
+
def get_input(self, batch):
|
119 |
+
# assuming unified data format, dataloader returns a dict.
|
120 |
+
# image tensors should be scaled to -1 ... 1 and in bchw format
|
121 |
+
return batch[self.input_key]
|
122 |
+
|
123 |
+
@torch.no_grad()
|
124 |
+
def decode_first_stage(self, z):
|
125 |
+
z = 1.0 / self.scale_factor * z
|
126 |
+
n_samples = default(self.en_and_decode_n_samples_a_time, z.shape[0])
|
127 |
+
|
128 |
+
n_rounds = math.ceil(z.shape[0] / n_samples)
|
129 |
+
all_out = []
|
130 |
+
with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
|
131 |
+
for n in range(n_rounds):
|
132 |
+
if isinstance(self.first_stage_model.decoder, VideoDecoder):
|
133 |
+
kwargs = {"timesteps": len(z[n * n_samples : (n + 1) * n_samples])}
|
134 |
+
else:
|
135 |
+
kwargs = {}
|
136 |
+
out = self.first_stage_model.decode(
|
137 |
+
z[n * n_samples : (n + 1) * n_samples], **kwargs
|
138 |
+
)
|
139 |
+
all_out.append(out)
|
140 |
+
out = torch.cat(all_out, dim=0)
|
141 |
+
return out
|
142 |
+
|
143 |
+
@torch.no_grad()
|
144 |
+
def encode_first_stage(self, x):
|
145 |
+
bs = x.shape[0]
|
146 |
+
is_video_input = False
|
147 |
+
if x.dim() == 5:
|
148 |
+
is_video_input = True
|
149 |
+
# for video diffusion
|
150 |
+
x = rearrange(x, "b t c h w -> (b t) c h w")
|
151 |
+
n_samples = default(self.en_and_decode_n_samples_a_time, x.shape[0])
|
152 |
+
n_rounds = math.ceil(x.shape[0] / n_samples)
|
153 |
+
all_out = []
|
154 |
+
with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
|
155 |
+
for n in range(n_rounds):
|
156 |
+
out = self.first_stage_model.encode(
|
157 |
+
x[n * n_samples : (n + 1) * n_samples]
|
158 |
+
)
|
159 |
+
all_out.append(out)
|
160 |
+
z = torch.cat(all_out, dim=0)
|
161 |
+
z = self.scale_factor * z
|
162 |
+
|
163 |
+
if is_video_input:
|
164 |
+
z = rearrange(z, "(b t) c h w -> b t c h w", b=bs)
|
165 |
+
|
166 |
+
return z
|
167 |
+
|
168 |
+
def forward(self, x, batch):
|
169 |
+
loss = self.loss_fn(self.model, self.denoiser, self.conditioner, x, batch)
|
170 |
+
loss_mean = loss.mean()
|
171 |
+
loss_dict = {"loss": loss_mean}
|
172 |
+
return loss_mean, loss_dict
|
173 |
+
|
174 |
+
def shared_step(self, batch: Dict) -> Any:
|
175 |
+
x = self.get_input(batch)
|
176 |
+
breakpoint()
|
177 |
+
x = self.encode_first_stage(x)
|
178 |
+
batch["global_step"] = self.global_step
|
179 |
+
loss, loss_dict = self(x, batch)
|
180 |
+
return loss, loss_dict
|
181 |
+
|
182 |
+
def training_step(self, batch, batch_idx):
|
183 |
+
loss, loss_dict = self.shared_step(batch)
|
184 |
+
|
185 |
+
self.log_dict(
|
186 |
+
loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=False
|
187 |
+
)
|
188 |
+
|
189 |
+
self.log(
|
190 |
+
"global_step",
|
191 |
+
self.global_step,
|
192 |
+
prog_bar=True,
|
193 |
+
logger=True,
|
194 |
+
on_step=True,
|
195 |
+
on_epoch=False,
|
196 |
+
)
|
197 |
+
|
198 |
+
if self.scheduler_config is not None:
|
199 |
+
lr = self.optimizers().param_groups[0]["lr"]
|
200 |
+
self.log(
|
201 |
+
"lr_abs", lr, prog_bar=True, logger=True, on_step=True, on_epoch=False
|
202 |
+
)
|
203 |
+
|
204 |
+
return loss
|
205 |
+
|
206 |
+
def on_train_start(self, *args, **kwargs):
|
207 |
+
if self.sampler is None or self.loss_fn is None:
|
208 |
+
raise ValueError("Sampler and loss function need to be set for training.")
|
209 |
+
|
210 |
+
def on_train_batch_end(self, *args, **kwargs):
|
211 |
+
if self.use_ema:
|
212 |
+
self.model_ema(self.model)
|
213 |
+
|
214 |
+
@contextmanager
|
215 |
+
def ema_scope(self, context=None):
|
216 |
+
if self.use_ema:
|
217 |
+
self.model_ema.store(self.model.parameters())
|
218 |
+
self.model_ema.copy_to(self.model)
|
219 |
+
if context is not None:
|
220 |
+
print(f"{context}: Switched to EMA weights")
|
221 |
+
try:
|
222 |
+
yield None
|
223 |
+
finally:
|
224 |
+
if self.use_ema:
|
225 |
+
self.model_ema.restore(self.model.parameters())
|
226 |
+
if context is not None:
|
227 |
+
print(f"{context}: Restored training weights")
|
228 |
+
|
229 |
+
def instantiate_optimizer_from_config(self, params, lr, cfg):
|
230 |
+
return get_obj_from_str(cfg["target"])(
|
231 |
+
params, lr=lr, **cfg.get("params", dict())
|
232 |
+
)
|
233 |
+
|
234 |
+
def configure_optimizers(self):
|
235 |
+
lr = self.learning_rate
|
236 |
+
params = list(self.model.parameters())
|
237 |
+
for embedder in self.conditioner.embedders:
|
238 |
+
if embedder.is_trainable:
|
239 |
+
params = params + list(embedder.parameters())
|
240 |
+
opt = self.instantiate_optimizer_from_config(params, lr, self.optimizer_config)
|
241 |
+
if self.scheduler_config is not None:
|
242 |
+
scheduler = instantiate_from_config(self.scheduler_config)
|
243 |
+
print("Setting up LambdaLR scheduler...")
|
244 |
+
scheduler = [
|
245 |
+
{
|
246 |
+
"scheduler": LambdaLR(opt, lr_lambda=scheduler.schedule),
|
247 |
+
"interval": "step",
|
248 |
+
"frequency": 1,
|
249 |
+
}
|
250 |
+
]
|
251 |
+
return [opt], scheduler
|
252 |
+
return opt
|
253 |
+
|
254 |
+
@torch.no_grad()
|
255 |
+
def sample(
|
256 |
+
self,
|
257 |
+
cond: Dict,
|
258 |
+
uc: Union[Dict, None] = None,
|
259 |
+
batch_size: int = 16,
|
260 |
+
shape: Union[None, Tuple, List] = None,
|
261 |
+
**kwargs,
|
262 |
+
):
|
263 |
+
randn = torch.randn(batch_size, *shape).to(self.device)
|
264 |
+
|
265 |
+
denoiser = lambda input, sigma, c: self.denoiser(
|
266 |
+
self.model, input, sigma, c, **kwargs
|
267 |
+
)
|
268 |
+
samples = self.sampler(denoiser, randn, cond, uc=uc)
|
269 |
+
return samples
|
270 |
+
|
271 |
+
@torch.no_grad()
|
272 |
+
def log_conditionings(self, batch: Dict, n: int) -> Dict:
|
273 |
+
"""
|
274 |
+
Defines heuristics to log different conditionings.
|
275 |
+
These can be lists of strings (text-to-image), tensors, ints, ...
|
276 |
+
"""
|
277 |
+
image_h, image_w = batch[self.input_key].shape[2:]
|
278 |
+
log = dict()
|
279 |
+
|
280 |
+
for embedder in self.conditioner.embedders:
|
281 |
+
if (
|
282 |
+
(self.log_keys is None) or (embedder.input_key in self.log_keys)
|
283 |
+
) and not self.no_cond_log:
|
284 |
+
x = batch[embedder.input_key][:n]
|
285 |
+
if isinstance(x, torch.Tensor):
|
286 |
+
if x.dim() == 1:
|
287 |
+
# class-conditional, convert integer to string
|
288 |
+
x = [str(x[i].item()) for i in range(x.shape[0])]
|
289 |
+
xc = log_txt_as_img((image_h, image_w), x, size=image_h // 4)
|
290 |
+
elif x.dim() == 2:
|
291 |
+
# size and crop cond and the like
|
292 |
+
x = [
|
293 |
+
"x".join([str(xx) for xx in x[i].tolist()])
|
294 |
+
for i in range(x.shape[0])
|
295 |
+
]
|
296 |
+
xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20)
|
297 |
+
else:
|
298 |
+
raise NotImplementedError()
|
299 |
+
elif isinstance(x, (List, ListConfig)):
|
300 |
+
if isinstance(x[0], str):
|
301 |
+
# strings
|
302 |
+
xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20)
|
303 |
+
else:
|
304 |
+
raise NotImplementedError()
|
305 |
+
else:
|
306 |
+
raise NotImplementedError()
|
307 |
+
log[embedder.input_key] = xc
|
308 |
+
return log
|
309 |
+
|
310 |
+
@torch.no_grad()
|
311 |
+
def log_images(
|
312 |
+
self,
|
313 |
+
batch: Dict,
|
314 |
+
N: int = 8,
|
315 |
+
sample: bool = True,
|
316 |
+
ucg_keys: List[str] = None,
|
317 |
+
**kwargs,
|
318 |
+
) -> Dict:
|
319 |
+
conditioner_input_keys = [e.input_key for e in self.conditioner.embedders]
|
320 |
+
if ucg_keys:
|
321 |
+
assert all(map(lambda x: x in conditioner_input_keys, ucg_keys)), (
|
322 |
+
"Each defined ucg key for sampling must be in the provided conditioner input keys,"
|
323 |
+
f"but we have {ucg_keys} vs. {conditioner_input_keys}"
|
324 |
+
)
|
325 |
+
else:
|
326 |
+
ucg_keys = conditioner_input_keys
|
327 |
+
log = dict()
|
328 |
+
|
329 |
+
x = self.get_input(batch)
|
330 |
+
|
331 |
+
c, uc = self.conditioner.get_unconditional_conditioning(
|
332 |
+
batch,
|
333 |
+
force_uc_zero_embeddings=ucg_keys
|
334 |
+
if len(self.conditioner.embedders) > 0
|
335 |
+
else [],
|
336 |
+
)
|
337 |
+
|
338 |
+
sampling_kwargs = {}
|
339 |
+
|
340 |
+
N = min(x.shape[0], N)
|
341 |
+
x = x.to(self.device)[:N]
|
342 |
+
log["inputs"] = x
|
343 |
+
z = self.encode_first_stage(x)
|
344 |
+
log["reconstructions"] = self.decode_first_stage(z)
|
345 |
+
log.update(self.log_conditionings(batch, N))
|
346 |
+
|
347 |
+
for k in c:
|
348 |
+
if isinstance(c[k], torch.Tensor):
|
349 |
+
c[k], uc[k] = map(lambda y: y[k][:N].to(self.device), (c, uc))
|
350 |
+
|
351 |
+
if sample:
|
352 |
+
with self.ema_scope("Plotting"):
|
353 |
+
samples = self.sample(
|
354 |
+
c, shape=z.shape[1:], uc=uc, batch_size=N, **sampling_kwargs
|
355 |
+
)
|
356 |
+
samples = self.decode_first_stage(samples)
|
357 |
+
log["samples"] = samples
|
358 |
+
return log
|
sgm/models/video3d_diffusion.py
ADDED
@@ -0,0 +1,524 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
import math
|
3 |
+
from contextlib import contextmanager
|
4 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
5 |
+
|
6 |
+
import pytorch_lightning as pl
|
7 |
+
from pytorch_lightning.loggers import WandbLogger
|
8 |
+
import torch
|
9 |
+
from omegaconf import ListConfig, OmegaConf
|
10 |
+
from safetensors.torch import load_file as load_safetensors
|
11 |
+
from torch.optim.lr_scheduler import LambdaLR
|
12 |
+
from torchvision.utils import make_grid
|
13 |
+
from einops import rearrange, repeat
|
14 |
+
|
15 |
+
from ..modules import UNCONDITIONAL_CONFIG
|
16 |
+
from ..modules.autoencoding.temporal_ae import VideoDecoder
|
17 |
+
from ..modules.diffusionmodules.wrappers import OPENAIUNETWRAPPER
|
18 |
+
from ..modules.ema import LitEma
|
19 |
+
from ..modules.encoders.modules import VideoPredictionEmbedderWithEncoder
|
20 |
+
from ..util import (
|
21 |
+
default,
|
22 |
+
disabled_train,
|
23 |
+
get_obj_from_str,
|
24 |
+
instantiate_from_config,
|
25 |
+
log_txt_as_img,
|
26 |
+
video_frames_as_grid,
|
27 |
+
)
|
28 |
+
|
29 |
+
|
30 |
+
def flatten_for_video(input):
|
31 |
+
return input.flatten()
|
32 |
+
|
33 |
+
|
34 |
+
class Video3DDiffusionEngine(pl.LightningModule):
|
35 |
+
def __init__(
|
36 |
+
self,
|
37 |
+
network_config,
|
38 |
+
denoiser_config,
|
39 |
+
first_stage_config,
|
40 |
+
conditioner_config: Union[None, Dict, ListConfig, OmegaConf] = None,
|
41 |
+
sampler_config: Union[None, Dict, ListConfig, OmegaConf] = None,
|
42 |
+
optimizer_config: Union[None, Dict, ListConfig, OmegaConf] = None,
|
43 |
+
scheduler_config: Union[None, Dict, ListConfig, OmegaConf] = None,
|
44 |
+
loss_fn_config: Union[None, Dict, ListConfig, OmegaConf] = None,
|
45 |
+
network_wrapper: Union[None, str] = None,
|
46 |
+
ckpt_path: Union[None, str] = None,
|
47 |
+
use_ema: bool = False,
|
48 |
+
ema_decay_rate: float = 0.9999,
|
49 |
+
scale_factor: float = 1.0,
|
50 |
+
disable_first_stage_autocast=False,
|
51 |
+
input_key: str = "frames", # for video inputs
|
52 |
+
log_keys: Union[List, None] = None,
|
53 |
+
no_cond_log: bool = False,
|
54 |
+
compile_model: bool = False,
|
55 |
+
en_and_decode_n_samples_a_time: Optional[int] = None,
|
56 |
+
):
|
57 |
+
super().__init__()
|
58 |
+
self.log_keys = log_keys
|
59 |
+
self.input_key = input_key
|
60 |
+
self.optimizer_config = default(
|
61 |
+
optimizer_config, {"target": "torch.optim.AdamW"}
|
62 |
+
)
|
63 |
+
model = instantiate_from_config(network_config)
|
64 |
+
self.model = get_obj_from_str(default(network_wrapper, OPENAIUNETWRAPPER))(
|
65 |
+
model, compile_model=compile_model
|
66 |
+
)
|
67 |
+
|
68 |
+
self.denoiser = instantiate_from_config(denoiser_config)
|
69 |
+
self.sampler = (
|
70 |
+
instantiate_from_config(sampler_config)
|
71 |
+
if sampler_config is not None
|
72 |
+
else None
|
73 |
+
)
|
74 |
+
self.conditioner = instantiate_from_config(
|
75 |
+
default(conditioner_config, UNCONDITIONAL_CONFIG)
|
76 |
+
)
|
77 |
+
self.scheduler_config = scheduler_config
|
78 |
+
self._init_first_stage(first_stage_config)
|
79 |
+
|
80 |
+
self.loss_fn = (
|
81 |
+
instantiate_from_config(loss_fn_config)
|
82 |
+
if loss_fn_config is not None
|
83 |
+
else None
|
84 |
+
)
|
85 |
+
|
86 |
+
self.use_ema = use_ema
|
87 |
+
if self.use_ema:
|
88 |
+
self.model_ema = LitEma(self.model, decay=ema_decay_rate)
|
89 |
+
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
|
90 |
+
|
91 |
+
self.scale_factor = scale_factor
|
92 |
+
self.disable_first_stage_autocast = disable_first_stage_autocast
|
93 |
+
self.no_cond_log = no_cond_log
|
94 |
+
|
95 |
+
if ckpt_path is not None:
|
96 |
+
self.init_from_ckpt(ckpt_path)
|
97 |
+
|
98 |
+
self.en_and_decode_n_samples_a_time = en_and_decode_n_samples_a_time
|
99 |
+
|
100 |
+
def _load_last_embedder(self, original_state_dict):
|
101 |
+
original_module_name = "conditioner.embedders.3"
|
102 |
+
state_dict = dict()
|
103 |
+
for k, v in original_state_dict.items():
|
104 |
+
m = re.match(rf"^{original_module_name}\.(.*)$", k)
|
105 |
+
if m is None:
|
106 |
+
continue
|
107 |
+
state_dict[m.group(1)] = v
|
108 |
+
|
109 |
+
idx = -1
|
110 |
+
for i in range(len(self.conditioner.embedders)):
|
111 |
+
if isinstance(
|
112 |
+
self.conditioner.embedders[i], VideoPredictionEmbedderWithEncoder
|
113 |
+
):
|
114 |
+
idx = i
|
115 |
+
|
116 |
+
print(f"Embedder [{idx}] is the frame encoder, make sure this is expected")
|
117 |
+
|
118 |
+
self.conditioner.embedders[idx].load_state_dict(state_dict)
|
119 |
+
|
120 |
+
def init_from_ckpt(
|
121 |
+
self,
|
122 |
+
path: str,
|
123 |
+
) -> None:
|
124 |
+
if path.endswith("ckpt"):
|
125 |
+
sd = torch.load(path, map_location="cpu")["state_dict"]
|
126 |
+
elif path.endswith("safetensors"):
|
127 |
+
sd = load_safetensors(path)
|
128 |
+
else:
|
129 |
+
raise NotImplementedError
|
130 |
+
|
131 |
+
self_sd = self.state_dict()
|
132 |
+
input_keys = [
|
133 |
+
"model.diffusion_model.input_blocks.0.0.weight",
|
134 |
+
"model_ema.diffusion_modelinput_blocks00weight",
|
135 |
+
]
|
136 |
+
for input_key in input_keys:
|
137 |
+
if input_key not in sd or input_key not in self_sd:
|
138 |
+
continue
|
139 |
+
|
140 |
+
input_weight = self_sd[input_key]
|
141 |
+
|
142 |
+
if input_weight.shape != sd[input_key].shape:
|
143 |
+
print("Manual init: {}".format(input_key))
|
144 |
+
input_weight.zero_()
|
145 |
+
input_weight[:, :8, :, :].copy_(sd[input_key])
|
146 |
+
|
147 |
+
deleted_keys = []
|
148 |
+
for k, v in self.state_dict().items():
|
149 |
+
# resolve shape dismatch
|
150 |
+
if k in sd:
|
151 |
+
if v.shape != sd[k].shape:
|
152 |
+
del sd[k]
|
153 |
+
deleted_keys.append(k)
|
154 |
+
|
155 |
+
if len(deleted_keys) > 0:
|
156 |
+
print(f"Deleted Keys: {deleted_keys}")
|
157 |
+
|
158 |
+
missing, unexpected = self.load_state_dict(sd, strict=False)
|
159 |
+
print(
|
160 |
+
f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
|
161 |
+
)
|
162 |
+
if len(missing) > 0:
|
163 |
+
print(f"Missing Keys: {missing}")
|
164 |
+
if len(unexpected) > 0:
|
165 |
+
print(f"Unexpected Keys: {unexpected}")
|
166 |
+
if len(deleted_keys) > 0:
|
167 |
+
print(f"Deleted Keys: {deleted_keys}")
|
168 |
+
|
169 |
+
if len(missing) > 0 or len(unexpected) > 0:
|
170 |
+
# means we are loading from a checkpoint that has the old embedder (motion bucket id and fps id)
|
171 |
+
print("Modified embedder to support 3d spiral video inputs")
|
172 |
+
try:
|
173 |
+
self._load_last_embedder(sd)
|
174 |
+
except:
|
175 |
+
print("Failed to load last embedder, make sure this is expected")
|
176 |
+
|
177 |
+
def _init_first_stage(self, config):
|
178 |
+
model = instantiate_from_config(config).eval()
|
179 |
+
model.train = disabled_train
|
180 |
+
for param in model.parameters():
|
181 |
+
param.requires_grad = False
|
182 |
+
self.first_stage_model = model
|
183 |
+
|
184 |
+
def get_input(self, batch):
|
185 |
+
# assuming unified data format, dataloader returns a dict.
|
186 |
+
# image tensors should be scaled to -1 ... 1 and in bchw format
|
187 |
+
return batch[self.input_key]
|
188 |
+
|
189 |
+
@torch.no_grad()
|
190 |
+
def decode_first_stage(self, z):
|
191 |
+
z = 1.0 / self.scale_factor * z
|
192 |
+
is_video_input = False
|
193 |
+
bs = z.shape[0]
|
194 |
+
if z.dim() == 5:
|
195 |
+
is_video_input = True
|
196 |
+
# for video diffusion
|
197 |
+
z = rearrange(z, "b t c h w -> (b t) c h w")
|
198 |
+
n_samples = default(self.en_and_decode_n_samples_a_time, z.shape[0])
|
199 |
+
|
200 |
+
n_rounds = math.ceil(z.shape[0] / n_samples)
|
201 |
+
all_out = []
|
202 |
+
with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
|
203 |
+
for n in range(n_rounds):
|
204 |
+
if isinstance(self.first_stage_model.decoder, VideoDecoder):
|
205 |
+
kwargs = {"timesteps": len(z[n * n_samples : (n + 1) * n_samples])}
|
206 |
+
else:
|
207 |
+
kwargs = {}
|
208 |
+
out = self.first_stage_model.decode(
|
209 |
+
z[n * n_samples : (n + 1) * n_samples], **kwargs
|
210 |
+
)
|
211 |
+
all_out.append(out)
|
212 |
+
out = torch.cat(all_out, dim=0)
|
213 |
+
|
214 |
+
if is_video_input:
|
215 |
+
out = rearrange(out, "(b t) c h w -> b t c h w", b=bs)
|
216 |
+
|
217 |
+
return out
|
218 |
+
|
219 |
+
@torch.no_grad()
|
220 |
+
def encode_first_stage(self, x):
|
221 |
+
if self.input_key == "latents":
|
222 |
+
return x
|
223 |
+
|
224 |
+
bs = x.shape[0]
|
225 |
+
is_video_input = False
|
226 |
+
if x.dim() == 5:
|
227 |
+
is_video_input = True
|
228 |
+
# for video diffusion
|
229 |
+
x = rearrange(x, "b t c h w -> (b t) c h w")
|
230 |
+
n_samples = default(self.en_and_decode_n_samples_a_time, x.shape[0])
|
231 |
+
n_rounds = math.ceil(x.shape[0] / n_samples)
|
232 |
+
all_out = []
|
233 |
+
with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
|
234 |
+
for n in range(n_rounds):
|
235 |
+
out = self.first_stage_model.encode(
|
236 |
+
x[n * n_samples : (n + 1) * n_samples]
|
237 |
+
)
|
238 |
+
all_out.append(out)
|
239 |
+
z = torch.cat(all_out, dim=0)
|
240 |
+
z = self.scale_factor * z
|
241 |
+
|
242 |
+
# if is_video_input:
|
243 |
+
# z = rearrange(z, "(b t) c h w -> b t c h w", b=bs)
|
244 |
+
|
245 |
+
return z
|
246 |
+
|
247 |
+
def forward(self, x, batch):
|
248 |
+
loss, model_output = self.loss_fn(
|
249 |
+
self.model,
|
250 |
+
self.denoiser,
|
251 |
+
self.conditioner,
|
252 |
+
x,
|
253 |
+
batch,
|
254 |
+
return_model_output=True,
|
255 |
+
)
|
256 |
+
loss_mean = loss.mean()
|
257 |
+
loss_dict = {"loss": loss_mean, "model_output": model_output}
|
258 |
+
return loss_mean, loss_dict
|
259 |
+
|
260 |
+
def shared_step(self, batch: Dict) -> Any:
|
261 |
+
# TODO: move this shit to collate_fn in dataloader
|
262 |
+
# if "fps_id" in batch:
|
263 |
+
# batch["fps_id"] = flatten_for_video(batch["fps_id"])
|
264 |
+
# if "motion_bucket_id" in batch:
|
265 |
+
# batch["motion_bucket_id"] = flatten_for_video(batch["motion_bucket_id"])
|
266 |
+
# if "cond_aug" in batch:
|
267 |
+
# batch["cond_aug"] = flatten_for_video(batch["cond_aug"])
|
268 |
+
x = self.get_input(batch)
|
269 |
+
x = self.encode_first_stage(x)
|
270 |
+
# ## debug
|
271 |
+
# x_recon = self.decode_first_stage(x)
|
272 |
+
# video_frames_as_grid((batch["frames"][0] + 1.0) / 2.0, "./tmp/origin.jpg")
|
273 |
+
# video_frames_as_grid((x_recon[0] + 1.0) / 2.0, "./tmp/recon.jpg")
|
274 |
+
# ## debug
|
275 |
+
batch["global_step"] = self.global_step
|
276 |
+
loss, loss_dict = self(x, batch)
|
277 |
+
return loss, loss_dict
|
278 |
+
|
279 |
+
def training_step(self, batch, batch_idx):
|
280 |
+
loss, loss_dict = self.shared_step(batch)
|
281 |
+
|
282 |
+
with torch.no_grad():
|
283 |
+
if "model_output" in loss_dict:
|
284 |
+
if batch_idx % 100 == 0:
|
285 |
+
if isinstance(self.logger, WandbLogger):
|
286 |
+
model_output = loss_dict["model_output"].detach()[
|
287 |
+
: batch["num_video_frames"]
|
288 |
+
]
|
289 |
+
recons = (
|
290 |
+
(self.decode_first_stage(model_output) + 1.0) / 2.0
|
291 |
+
).clamp(0.0, 1.0)
|
292 |
+
recon_grid = make_grid(recons, nrow=4)
|
293 |
+
self.logger.log_image(
|
294 |
+
key=f"train/model_output_recon",
|
295 |
+
images=[recon_grid],
|
296 |
+
step=self.global_step,
|
297 |
+
)
|
298 |
+
del loss_dict["model_output"]
|
299 |
+
|
300 |
+
if torch.isnan(loss).any():
|
301 |
+
print("Nan detected")
|
302 |
+
loss = None
|
303 |
+
|
304 |
+
self.log_dict(
|
305 |
+
loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=False
|
306 |
+
)
|
307 |
+
|
308 |
+
self.log(
|
309 |
+
"global_step",
|
310 |
+
self.global_step,
|
311 |
+
prog_bar=True,
|
312 |
+
logger=True,
|
313 |
+
on_step=True,
|
314 |
+
on_epoch=False,
|
315 |
+
)
|
316 |
+
|
317 |
+
if self.scheduler_config is not None:
|
318 |
+
lr = self.optimizers().param_groups[0]["lr"]
|
319 |
+
self.log(
|
320 |
+
"lr_abs", lr, prog_bar=True, logger=True, on_step=True, on_epoch=False
|
321 |
+
)
|
322 |
+
|
323 |
+
return loss
|
324 |
+
|
325 |
+
def on_train_start(self, *args, **kwargs):
|
326 |
+
if self.sampler is None or self.loss_fn is None:
|
327 |
+
raise ValueError("Sampler and loss function need to be set for training.")
|
328 |
+
|
329 |
+
def on_train_batch_end(self, *args, **kwargs):
|
330 |
+
if self.use_ema:
|
331 |
+
self.model_ema(self.model)
|
332 |
+
|
333 |
+
@contextmanager
|
334 |
+
def ema_scope(self, context=None):
|
335 |
+
if self.use_ema:
|
336 |
+
self.model_ema.store(self.model.parameters())
|
337 |
+
self.model_ema.copy_to(self.model)
|
338 |
+
if context is not None:
|
339 |
+
print(f"{context}: Switched to EMA weights")
|
340 |
+
try:
|
341 |
+
yield None
|
342 |
+
finally:
|
343 |
+
if self.use_ema:
|
344 |
+
self.model_ema.restore(self.model.parameters())
|
345 |
+
if context is not None:
|
346 |
+
print(f"{context}: Restored training weights")
|
347 |
+
|
348 |
+
def instantiate_optimizer_from_config(self, params, lr, cfg):
|
349 |
+
return get_obj_from_str(cfg["target"])(
|
350 |
+
params, lr=lr, **cfg.get("params", dict())
|
351 |
+
)
|
352 |
+
|
353 |
+
def configure_optimizers(self):
|
354 |
+
lr = self.learning_rate
|
355 |
+
params = list(self.model.parameters())
|
356 |
+
for embedder in self.conditioner.embedders:
|
357 |
+
if embedder.is_trainable:
|
358 |
+
params = params + list(embedder.parameters())
|
359 |
+
opt = self.instantiate_optimizer_from_config(params, lr, self.optimizer_config)
|
360 |
+
if self.scheduler_config is not None:
|
361 |
+
scheduler = instantiate_from_config(self.scheduler_config)
|
362 |
+
print("Setting up LambdaLR scheduler...")
|
363 |
+
scheduler = [
|
364 |
+
{
|
365 |
+
"scheduler": LambdaLR(opt, lr_lambda=scheduler.schedule),
|
366 |
+
"interval": "step",
|
367 |
+
"frequency": 1,
|
368 |
+
}
|
369 |
+
]
|
370 |
+
return [opt], scheduler
|
371 |
+
return opt
|
372 |
+
|
373 |
+
@torch.no_grad()
|
374 |
+
def sample(
|
375 |
+
self,
|
376 |
+
cond: Dict,
|
377 |
+
uc: Union[Dict, None] = None,
|
378 |
+
batch_size: int = 16,
|
379 |
+
shape: Union[None, Tuple, List] = None,
|
380 |
+
**kwargs,
|
381 |
+
):
|
382 |
+
randn = torch.randn(batch_size, *shape).to(self.device)
|
383 |
+
|
384 |
+
denoiser = lambda input, sigma, c: self.denoiser(
|
385 |
+
self.model, input, sigma, c, **kwargs
|
386 |
+
)
|
387 |
+
samples = self.sampler(denoiser, randn, cond, uc=uc)
|
388 |
+
return samples
|
389 |
+
|
390 |
+
@torch.no_grad()
|
391 |
+
def log_conditionings(self, batch: Dict, n: int) -> Dict:
|
392 |
+
"""
|
393 |
+
Defines heuristics to log different conditionings.
|
394 |
+
These can be lists of strings (text-to-image), tensors, ints, ...
|
395 |
+
"""
|
396 |
+
image_h, image_w = batch[self.input_key].shape[-2:]
|
397 |
+
log = dict()
|
398 |
+
|
399 |
+
for embedder in self.conditioner.embedders:
|
400 |
+
if (
|
401 |
+
(self.log_keys is None) or (embedder.input_key in self.log_keys)
|
402 |
+
) and not self.no_cond_log:
|
403 |
+
x = batch[embedder.input_key][:n]
|
404 |
+
if isinstance(x, torch.Tensor):
|
405 |
+
if x.dim() == 1:
|
406 |
+
# class-conditional, convert integer to string
|
407 |
+
x = [str(x[i].item()) for i in range(x.shape[0])]
|
408 |
+
xc = log_txt_as_img((image_h, image_w), x, size=image_h // 4)
|
409 |
+
elif x.dim() == 2:
|
410 |
+
# size and crop cond and the like
|
411 |
+
x = [
|
412 |
+
"x".join([str(xx) for xx in x[i].tolist()])
|
413 |
+
for i in range(x.shape[0])
|
414 |
+
]
|
415 |
+
xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20)
|
416 |
+
elif x.dim() == 4:
|
417 |
+
# image
|
418 |
+
xc = x
|
419 |
+
else:
|
420 |
+
raise NotImplementedError()
|
421 |
+
elif isinstance(x, (List, ListConfig)):
|
422 |
+
if isinstance(x[0], str):
|
423 |
+
# strings
|
424 |
+
xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20)
|
425 |
+
else:
|
426 |
+
raise NotImplementedError()
|
427 |
+
else:
|
428 |
+
raise NotImplementedError()
|
429 |
+
log[embedder.input_key] = xc
|
430 |
+
|
431 |
+
return log
|
432 |
+
|
433 |
+
# for video diffusions will be logging frames of a video
|
434 |
+
@torch.no_grad()
|
435 |
+
def log_images(
|
436 |
+
self,
|
437 |
+
batch: Dict,
|
438 |
+
N: int = 1,
|
439 |
+
sample: bool = True,
|
440 |
+
ucg_keys: List[str] = None,
|
441 |
+
**kwargs,
|
442 |
+
) -> Dict:
|
443 |
+
# # debug
|
444 |
+
# return {}
|
445 |
+
# # debug
|
446 |
+
assert "num_video_frames" in batch, "num_video_frames must be in batch"
|
447 |
+
num_video_frames = batch["num_video_frames"]
|
448 |
+
conditioner_input_keys = [e.input_key for e in self.conditioner.embedders]
|
449 |
+
conditioner_input_keys = []
|
450 |
+
for e in self.conditioner.embedders:
|
451 |
+
if e.input_key is not None:
|
452 |
+
conditioner_input_keys.append(e.input_key)
|
453 |
+
else:
|
454 |
+
conditioner_input_keys.extend(e.input_keys)
|
455 |
+
if ucg_keys:
|
456 |
+
assert all(map(lambda x: x in conditioner_input_keys, ucg_keys)), (
|
457 |
+
"Each defined ucg key for sampling must be in the provided conditioner input keys,"
|
458 |
+
f"but we have {ucg_keys} vs. {conditioner_input_keys}"
|
459 |
+
)
|
460 |
+
else:
|
461 |
+
ucg_keys = conditioner_input_keys
|
462 |
+
log = dict()
|
463 |
+
|
464 |
+
x = self.get_input(batch)
|
465 |
+
|
466 |
+
c, uc = self.conditioner.get_unconditional_conditioning(
|
467 |
+
batch,
|
468 |
+
force_uc_zero_embeddings=ucg_keys
|
469 |
+
if len(self.conditioner.embedders) > 0
|
470 |
+
else [],
|
471 |
+
)
|
472 |
+
|
473 |
+
sampling_kwargs = {"num_video_frames": num_video_frames}
|
474 |
+
n = min(x.shape[0] // num_video_frames, N)
|
475 |
+
sampling_kwargs["image_only_indicator"] = torch.cat(
|
476 |
+
[batch["image_only_indicator"][:n]] * 2
|
477 |
+
)
|
478 |
+
|
479 |
+
N = min(x.shape[0] // num_video_frames, N) * num_video_frames
|
480 |
+
x = x.to(self.device)[:N]
|
481 |
+
# log["inputs"] = rearrange(x, "(b t) c h w -> b c h (t w)", t=num_video_frames)
|
482 |
+
log["inputs"] = x
|
483 |
+
z = self.encode_first_stage(x)
|
484 |
+
recon = self.decode_first_stage(z)
|
485 |
+
# log["reconstructions"] = rearrange(
|
486 |
+
# recon, "(b t) c h w -> b c h (t w)", t=num_video_frames
|
487 |
+
# )
|
488 |
+
log["reconstructions"] = recon
|
489 |
+
log.update(self.log_conditionings(batch, N))
|
490 |
+
log["pixelnerf_rgb"] = c["rgb"]
|
491 |
+
|
492 |
+
for k in ["crossattn", "concat", "vector"]:
|
493 |
+
if k in c:
|
494 |
+
c[k] = c[k][:N]
|
495 |
+
uc[k] = uc[k][:N]
|
496 |
+
|
497 |
+
# for k in c:
|
498 |
+
# if isinstance(c[k], torch.Tensor):
|
499 |
+
# if k == "vector":
|
500 |
+
# end = N
|
501 |
+
# else:
|
502 |
+
# end = n
|
503 |
+
# c[k], uc[k] = map(lambda y: y[k][:end].to(self.device), (c, uc))
|
504 |
+
|
505 |
+
# # for k in c:
|
506 |
+
# # print(c[k].shape)
|
507 |
+
|
508 |
+
# breakpoint()
|
509 |
+
# for k in ["crossattn", "concat"]:
|
510 |
+
# c[k] = repeat(c[k], "b ... -> b t ...", t=num_video_frames)
|
511 |
+
# c[k] = rearrange(c[k], "b t ... -> (b t) ...", t=num_video_frames)
|
512 |
+
# uc[k] = repeat(uc[k], "b ... -> b t ...", t=num_video_frames)
|
513 |
+
# uc[k] = rearrange(uc[k], "b t ... -> (b t) ...", t=num_video_frames)
|
514 |
+
|
515 |
+
# for k in c:
|
516 |
+
# print(c[k].shape)
|
517 |
+
if sample:
|
518 |
+
with self.ema_scope("Plotting"):
|
519 |
+
samples = self.sample(
|
520 |
+
c, shape=z.shape[1:], uc=uc, batch_size=N, **sampling_kwargs
|
521 |
+
)
|
522 |
+
samples = self.decode_first_stage(samples)
|
523 |
+
log["samples"] = samples
|
524 |
+
return log
|
sgm/models/video_diffusion.py
ADDED
@@ -0,0 +1,503 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
import math
|
3 |
+
from contextlib import contextmanager
|
4 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
5 |
+
|
6 |
+
import pytorch_lightning as pl
|
7 |
+
from pytorch_lightning.loggers import WandbLogger
|
8 |
+
import torch
|
9 |
+
from omegaconf import ListConfig, OmegaConf
|
10 |
+
from safetensors.torch import load_file as load_safetensors
|
11 |
+
from torch.optim.lr_scheduler import LambdaLR
|
12 |
+
from torchvision.utils import make_grid
|
13 |
+
from einops import rearrange, repeat
|
14 |
+
|
15 |
+
from ..modules import UNCONDITIONAL_CONFIG
|
16 |
+
from ..modules.autoencoding.temporal_ae import VideoDecoder
|
17 |
+
from ..modules.diffusionmodules.wrappers import OPENAIUNETWRAPPER
|
18 |
+
from ..modules.ema import LitEma
|
19 |
+
from ..modules.encoders.modules import VideoPredictionEmbedderWithEncoder
|
20 |
+
from ..util import (
|
21 |
+
default,
|
22 |
+
disabled_train,
|
23 |
+
get_obj_from_str,
|
24 |
+
instantiate_from_config,
|
25 |
+
log_txt_as_img,
|
26 |
+
video_frames_as_grid,
|
27 |
+
)
|
28 |
+
|
29 |
+
|
30 |
+
def flatten_for_video(input):
|
31 |
+
return input.flatten()
|
32 |
+
|
33 |
+
|
34 |
+
class DiffusionEngine(pl.LightningModule):
|
35 |
+
def __init__(
|
36 |
+
self,
|
37 |
+
network_config,
|
38 |
+
denoiser_config,
|
39 |
+
first_stage_config,
|
40 |
+
conditioner_config: Union[None, Dict, ListConfig, OmegaConf] = None,
|
41 |
+
sampler_config: Union[None, Dict, ListConfig, OmegaConf] = None,
|
42 |
+
optimizer_config: Union[None, Dict, ListConfig, OmegaConf] = None,
|
43 |
+
scheduler_config: Union[None, Dict, ListConfig, OmegaConf] = None,
|
44 |
+
loss_fn_config: Union[None, Dict, ListConfig, OmegaConf] = None,
|
45 |
+
network_wrapper: Union[None, str] = None,
|
46 |
+
ckpt_path: Union[None, str] = None,
|
47 |
+
use_ema: bool = False,
|
48 |
+
ema_decay_rate: float = 0.9999,
|
49 |
+
scale_factor: float = 1.0,
|
50 |
+
disable_first_stage_autocast=False,
|
51 |
+
input_key: str = "frames", # for video inputs
|
52 |
+
log_keys: Union[List, None] = None,
|
53 |
+
no_cond_log: bool = False,
|
54 |
+
compile_model: bool = False,
|
55 |
+
en_and_decode_n_samples_a_time: Optional[int] = None,
|
56 |
+
load_last_embedder: bool = False,
|
57 |
+
from_scratch: bool = False,
|
58 |
+
):
|
59 |
+
super().__init__()
|
60 |
+
self.log_keys = log_keys
|
61 |
+
self.input_key = input_key
|
62 |
+
self.optimizer_config = default(
|
63 |
+
optimizer_config, {"target": "torch.optim.AdamW"}
|
64 |
+
)
|
65 |
+
model = instantiate_from_config(network_config)
|
66 |
+
self.model = get_obj_from_str(default(network_wrapper, OPENAIUNETWRAPPER))(
|
67 |
+
model, compile_model=compile_model
|
68 |
+
)
|
69 |
+
|
70 |
+
self.denoiser = instantiate_from_config(denoiser_config)
|
71 |
+
self.sampler = (
|
72 |
+
instantiate_from_config(sampler_config)
|
73 |
+
if sampler_config is not None
|
74 |
+
else None
|
75 |
+
)
|
76 |
+
self.conditioner = instantiate_from_config(
|
77 |
+
default(conditioner_config, UNCONDITIONAL_CONFIG)
|
78 |
+
)
|
79 |
+
self.scheduler_config = scheduler_config
|
80 |
+
self._init_first_stage(first_stage_config)
|
81 |
+
|
82 |
+
self.loss_fn = (
|
83 |
+
instantiate_from_config(loss_fn_config)
|
84 |
+
if loss_fn_config is not None
|
85 |
+
else None
|
86 |
+
)
|
87 |
+
|
88 |
+
self.use_ema = use_ema
|
89 |
+
if self.use_ema:
|
90 |
+
self.model_ema = LitEma(self.model, decay=ema_decay_rate)
|
91 |
+
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
|
92 |
+
|
93 |
+
self.scale_factor = scale_factor
|
94 |
+
self.disable_first_stage_autocast = disable_first_stage_autocast
|
95 |
+
self.no_cond_log = no_cond_log
|
96 |
+
|
97 |
+
self.load_last_embedder = load_last_embedder
|
98 |
+
if ckpt_path is not None:
|
99 |
+
self.init_from_ckpt(ckpt_path, from_scratch)
|
100 |
+
|
101 |
+
self.en_and_decode_n_samples_a_time = en_and_decode_n_samples_a_time
|
102 |
+
|
103 |
+
def _load_last_embedder(self, original_state_dict):
|
104 |
+
original_module_name = "conditioner.embedders.3"
|
105 |
+
state_dict = dict()
|
106 |
+
for k, v in original_state_dict.items():
|
107 |
+
m = re.match(rf"^{original_module_name}\.(.*)$", k)
|
108 |
+
if m is None:
|
109 |
+
continue
|
110 |
+
state_dict[m.group(1)] = v
|
111 |
+
|
112 |
+
idx = -1
|
113 |
+
for i in range(len(self.conditioner.embedders)):
|
114 |
+
if isinstance(
|
115 |
+
self.conditioner.embedders[i], VideoPredictionEmbedderWithEncoder
|
116 |
+
):
|
117 |
+
idx = i
|
118 |
+
|
119 |
+
print(f"Embedder [{idx}] is the frame encoder, make sure this is expected")
|
120 |
+
|
121 |
+
self.conditioner.embedders[idx].load_state_dict(state_dict)
|
122 |
+
|
123 |
+
def init_from_ckpt(
|
124 |
+
self,
|
125 |
+
path: str,
|
126 |
+
from_scratch: bool = False,
|
127 |
+
) -> None:
|
128 |
+
if path.endswith("ckpt"):
|
129 |
+
sd = torch.load(path, map_location="cpu")["state_dict"]
|
130 |
+
elif path.endswith("safetensors"):
|
131 |
+
sd = load_safetensors(path)
|
132 |
+
else:
|
133 |
+
raise NotImplementedError
|
134 |
+
|
135 |
+
deleted_keys = []
|
136 |
+
for k, v in self.state_dict().items():
|
137 |
+
# resolve shape dismatch
|
138 |
+
if k in sd:
|
139 |
+
if v.shape != sd[k].shape:
|
140 |
+
del sd[k]
|
141 |
+
deleted_keys.append(k)
|
142 |
+
|
143 |
+
if from_scratch:
|
144 |
+
new_sd = {}
|
145 |
+
for k in sd:
|
146 |
+
if "first_stage_model" in k:
|
147 |
+
new_sd[k] = sd[k]
|
148 |
+
sd = new_sd
|
149 |
+
print(sd.keys())
|
150 |
+
|
151 |
+
if len(deleted_keys) > 0:
|
152 |
+
print(f"Deleted Keys: {deleted_keys}")
|
153 |
+
|
154 |
+
missing, unexpected = self.load_state_dict(sd, strict=False)
|
155 |
+
print(
|
156 |
+
f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
|
157 |
+
)
|
158 |
+
if len(missing) > 0:
|
159 |
+
print(f"Missing Keys: {missing}")
|
160 |
+
if len(unexpected) > 0:
|
161 |
+
print(f"Unexpected Keys: {unexpected}")
|
162 |
+
if len(deleted_keys) > 0:
|
163 |
+
print(f"Deleted Keys: {deleted_keys}")
|
164 |
+
|
165 |
+
if (len(missing) > 0 or len(unexpected) > 0) and self.load_last_embedder:
|
166 |
+
# means we are loading from a checkpoint that has the old embedder (motion bucket id and fps id)
|
167 |
+
print("Modified embedder to support 3d spiral video inputs")
|
168 |
+
self._load_last_embedder(sd)
|
169 |
+
|
170 |
+
def _init_first_stage(self, config):
|
171 |
+
model = instantiate_from_config(config).eval()
|
172 |
+
model.train = disabled_train
|
173 |
+
for param in model.parameters():
|
174 |
+
param.requires_grad = False
|
175 |
+
self.first_stage_model = model
|
176 |
+
|
177 |
+
def get_input(self, batch):
|
178 |
+
# assuming unified data format, dataloader returns a dict.
|
179 |
+
# image tensors should be scaled to -1 ... 1 and in bchw format
|
180 |
+
return batch[self.input_key]
|
181 |
+
|
182 |
+
@torch.no_grad()
|
183 |
+
def decode_first_stage(self, z):
|
184 |
+
z = 1.0 / self.scale_factor * z
|
185 |
+
is_video_input = False
|
186 |
+
bs = z.shape[0]
|
187 |
+
if z.dim() == 5:
|
188 |
+
is_video_input = True
|
189 |
+
# for video diffusion
|
190 |
+
z = rearrange(z, "b t c h w -> (b t) c h w")
|
191 |
+
n_samples = default(self.en_and_decode_n_samples_a_time, z.shape[0])
|
192 |
+
|
193 |
+
n_rounds = math.ceil(z.shape[0] / n_samples)
|
194 |
+
all_out = []
|
195 |
+
with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
|
196 |
+
for n in range(n_rounds):
|
197 |
+
if isinstance(self.first_stage_model.decoder, VideoDecoder):
|
198 |
+
kwargs = {"timesteps": len(z[n * n_samples : (n + 1) * n_samples])}
|
199 |
+
else:
|
200 |
+
kwargs = {}
|
201 |
+
out = self.first_stage_model.decode(
|
202 |
+
z[n * n_samples : (n + 1) * n_samples], **kwargs
|
203 |
+
)
|
204 |
+
all_out.append(out)
|
205 |
+
out = torch.cat(all_out, dim=0)
|
206 |
+
|
207 |
+
if is_video_input:
|
208 |
+
out = rearrange(out, "(b t) c h w -> b t c h w", b=bs)
|
209 |
+
|
210 |
+
return out
|
211 |
+
|
212 |
+
@torch.no_grad()
|
213 |
+
def encode_first_stage(self, x):
|
214 |
+
if self.input_key == "latents":
|
215 |
+
return x * self.scale_factor
|
216 |
+
|
217 |
+
bs = x.shape[0]
|
218 |
+
is_video_input = False
|
219 |
+
if x.dim() == 5:
|
220 |
+
is_video_input = True
|
221 |
+
# for video diffusion
|
222 |
+
x = rearrange(x, "b t c h w -> (b t) c h w")
|
223 |
+
n_samples = default(self.en_and_decode_n_samples_a_time, x.shape[0])
|
224 |
+
n_rounds = math.ceil(x.shape[0] / n_samples)
|
225 |
+
all_out = []
|
226 |
+
with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
|
227 |
+
for n in range(n_rounds):
|
228 |
+
out = self.first_stage_model.encode(
|
229 |
+
x[n * n_samples : (n + 1) * n_samples]
|
230 |
+
)
|
231 |
+
all_out.append(out)
|
232 |
+
z = torch.cat(all_out, dim=0)
|
233 |
+
z = self.scale_factor * z
|
234 |
+
|
235 |
+
# if is_video_input:
|
236 |
+
# z = rearrange(z, "(b t) c h w -> b t c h w", b=bs)
|
237 |
+
|
238 |
+
return z
|
239 |
+
|
240 |
+
def forward(self, x, batch):
|
241 |
+
loss, model_output = self.loss_fn(
|
242 |
+
self.model,
|
243 |
+
self.denoiser,
|
244 |
+
self.conditioner,
|
245 |
+
x,
|
246 |
+
batch,
|
247 |
+
return_model_output=True,
|
248 |
+
)
|
249 |
+
loss_mean = loss.mean()
|
250 |
+
loss_dict = {"loss": loss_mean, "model_output": model_output}
|
251 |
+
return loss_mean, loss_dict
|
252 |
+
|
253 |
+
def shared_step(self, batch: Dict) -> Any:
|
254 |
+
# TODO: move this shit to collate_fn in dataloader
|
255 |
+
# if "fps_id" in batch:
|
256 |
+
# batch["fps_id"] = flatten_for_video(batch["fps_id"])
|
257 |
+
# if "motion_bucket_id" in batch:
|
258 |
+
# batch["motion_bucket_id"] = flatten_for_video(batch["motion_bucket_id"])
|
259 |
+
# if "cond_aug" in batch:
|
260 |
+
# batch["cond_aug"] = flatten_for_video(batch["cond_aug"])
|
261 |
+
x = self.get_input(batch)
|
262 |
+
x = self.encode_first_stage(x)
|
263 |
+
# ## debug
|
264 |
+
# x_recon = self.decode_first_stage(x)
|
265 |
+
# video_frames_as_grid((batch["frames"][0] + 1.0) / 2.0, "./tmp/origin.jpg")
|
266 |
+
# video_frames_as_grid((x_recon[0] + 1.0) / 2.0, "./tmp/recon.jpg")
|
267 |
+
# ## debug
|
268 |
+
batch["global_step"] = self.global_step
|
269 |
+
# breakpoint()
|
270 |
+
loss, loss_dict = self(x, batch)
|
271 |
+
return loss, loss_dict
|
272 |
+
|
273 |
+
def training_step(self, batch, batch_idx):
|
274 |
+
loss, loss_dict = self.shared_step(batch)
|
275 |
+
|
276 |
+
with torch.no_grad():
|
277 |
+
if "model_output" in loss_dict:
|
278 |
+
if batch_idx % 100 == 0:
|
279 |
+
if isinstance(self.logger, WandbLogger):
|
280 |
+
model_output = loss_dict["model_output"].detach()[
|
281 |
+
: batch["num_video_frames"]
|
282 |
+
]
|
283 |
+
recons = (
|
284 |
+
(self.decode_first_stage(model_output) + 1.0) / 2.0
|
285 |
+
).clamp(0.0, 1.0)
|
286 |
+
recon_grid = make_grid(recons, nrow=4)
|
287 |
+
self.logger.log_image(
|
288 |
+
key=f"train/model_output_recon",
|
289 |
+
images=[recon_grid],
|
290 |
+
step=self.global_step,
|
291 |
+
)
|
292 |
+
del loss_dict["model_output"]
|
293 |
+
|
294 |
+
self.log_dict(
|
295 |
+
loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=False
|
296 |
+
)
|
297 |
+
|
298 |
+
self.log(
|
299 |
+
"global_step",
|
300 |
+
self.global_step,
|
301 |
+
prog_bar=True,
|
302 |
+
logger=True,
|
303 |
+
on_step=True,
|
304 |
+
on_epoch=False,
|
305 |
+
)
|
306 |
+
|
307 |
+
if self.scheduler_config is not None:
|
308 |
+
lr = self.optimizers().param_groups[0]["lr"]
|
309 |
+
self.log(
|
310 |
+
"lr_abs", lr, prog_bar=True, logger=True, on_step=True, on_epoch=False
|
311 |
+
)
|
312 |
+
|
313 |
+
return loss
|
314 |
+
|
315 |
+
def on_train_start(self, *args, **kwargs):
|
316 |
+
if self.sampler is None or self.loss_fn is None:
|
317 |
+
raise ValueError("Sampler and loss function need to be set for training.")
|
318 |
+
|
319 |
+
def on_train_batch_end(self, *args, **kwargs):
|
320 |
+
if self.use_ema:
|
321 |
+
self.model_ema(self.model)
|
322 |
+
|
323 |
+
@contextmanager
|
324 |
+
def ema_scope(self, context=None):
|
325 |
+
if self.use_ema:
|
326 |
+
self.model_ema.store(self.model.parameters())
|
327 |
+
self.model_ema.copy_to(self.model)
|
328 |
+
if context is not None:
|
329 |
+
print(f"{context}: Switched to EMA weights")
|
330 |
+
try:
|
331 |
+
yield None
|
332 |
+
finally:
|
333 |
+
if self.use_ema:
|
334 |
+
self.model_ema.restore(self.model.parameters())
|
335 |
+
if context is not None:
|
336 |
+
print(f"{context}: Restored training weights")
|
337 |
+
|
338 |
+
def instantiate_optimizer_from_config(self, params, lr, cfg):
|
339 |
+
return get_obj_from_str(cfg["target"])(
|
340 |
+
params, lr=lr, **cfg.get("params", dict())
|
341 |
+
)
|
342 |
+
|
343 |
+
def configure_optimizers(self):
|
344 |
+
lr = self.learning_rate
|
345 |
+
params = list(self.model.parameters())
|
346 |
+
for embedder in self.conditioner.embedders:
|
347 |
+
if embedder.is_trainable:
|
348 |
+
params = params + list(embedder.parameters())
|
349 |
+
opt = self.instantiate_optimizer_from_config(params, lr, self.optimizer_config)
|
350 |
+
if self.scheduler_config is not None:
|
351 |
+
scheduler = instantiate_from_config(self.scheduler_config)
|
352 |
+
print("Setting up LambdaLR scheduler...")
|
353 |
+
scheduler = [
|
354 |
+
{
|
355 |
+
"scheduler": LambdaLR(opt, lr_lambda=scheduler.schedule),
|
356 |
+
"interval": "step",
|
357 |
+
"frequency": 1,
|
358 |
+
}
|
359 |
+
]
|
360 |
+
return [opt], scheduler
|
361 |
+
return opt
|
362 |
+
|
363 |
+
@torch.no_grad()
|
364 |
+
def sample(
|
365 |
+
self,
|
366 |
+
cond: Dict,
|
367 |
+
uc: Union[Dict, None] = None,
|
368 |
+
batch_size: int = 16,
|
369 |
+
shape: Union[None, Tuple, List] = None,
|
370 |
+
**kwargs,
|
371 |
+
):
|
372 |
+
randn = torch.randn(batch_size, *shape).to(self.device)
|
373 |
+
|
374 |
+
denoiser = lambda input, sigma, c: self.denoiser(
|
375 |
+
self.model, input, sigma, c, **kwargs
|
376 |
+
)
|
377 |
+
samples = self.sampler(denoiser, randn, cond, uc=uc)
|
378 |
+
return samples
|
379 |
+
|
380 |
+
@torch.no_grad()
|
381 |
+
def log_conditionings(self, batch: Dict, n: int) -> Dict:
|
382 |
+
"""
|
383 |
+
Defines heuristics to log different conditionings.
|
384 |
+
These can be lists of strings (text-to-image), tensors, ints, ...
|
385 |
+
"""
|
386 |
+
image_h, image_w = batch[self.input_key].shape[-2:]
|
387 |
+
log = dict()
|
388 |
+
|
389 |
+
for embedder in self.conditioner.embedders:
|
390 |
+
if (
|
391 |
+
(self.log_keys is None) or (embedder.input_key in self.log_keys)
|
392 |
+
) and not self.no_cond_log:
|
393 |
+
x = batch[embedder.input_key][:n]
|
394 |
+
if isinstance(x, torch.Tensor):
|
395 |
+
if x.dim() == 1:
|
396 |
+
# class-conditional, convert integer to string
|
397 |
+
x = [str(x[i].item()) for i in range(x.shape[0])]
|
398 |
+
xc = log_txt_as_img((image_h, image_w), x, size=image_h // 4)
|
399 |
+
elif x.dim() == 2:
|
400 |
+
# size and crop cond and the like
|
401 |
+
x = [
|
402 |
+
"x".join([str(xx) for xx in x[i].tolist()])
|
403 |
+
for i in range(x.shape[0])
|
404 |
+
]
|
405 |
+
xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20)
|
406 |
+
elif x.dim() == 4:
|
407 |
+
# image
|
408 |
+
xc = x
|
409 |
+
else:
|
410 |
+
pass
|
411 |
+
# breakpoint()
|
412 |
+
# raise NotImplementedError()
|
413 |
+
elif isinstance(x, (List, ListConfig)):
|
414 |
+
if isinstance(x[0], str):
|
415 |
+
# strings
|
416 |
+
xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20)
|
417 |
+
else:
|
418 |
+
raise NotImplementedError()
|
419 |
+
else:
|
420 |
+
raise NotImplementedError()
|
421 |
+
log[embedder.input_key] = xc
|
422 |
+
return log
|
423 |
+
|
424 |
+
# for video diffusions will be logging frames of a video
|
425 |
+
@torch.no_grad()
|
426 |
+
def log_images(
|
427 |
+
self,
|
428 |
+
batch: Dict,
|
429 |
+
N: int = 1,
|
430 |
+
sample: bool = True,
|
431 |
+
ucg_keys: List[str] = None,
|
432 |
+
**kwargs,
|
433 |
+
) -> Dict:
|
434 |
+
# # debug
|
435 |
+
# return {}
|
436 |
+
# # debug
|
437 |
+
assert "num_video_frames" in batch, "num_video_frames must be in batch"
|
438 |
+
num_video_frames = batch["num_video_frames"]
|
439 |
+
conditioner_input_keys = [e.input_key for e in self.conditioner.embedders]
|
440 |
+
if ucg_keys:
|
441 |
+
assert all(map(lambda x: x in conditioner_input_keys, ucg_keys)), (
|
442 |
+
"Each defined ucg key for sampling must be in the provided conditioner input keys,"
|
443 |
+
f"but we have {ucg_keys} vs. {conditioner_input_keys}"
|
444 |
+
)
|
445 |
+
else:
|
446 |
+
ucg_keys = conditioner_input_keys
|
447 |
+
log = dict()
|
448 |
+
|
449 |
+
x = self.get_input(batch)
|
450 |
+
|
451 |
+
c, uc = self.conditioner.get_unconditional_conditioning(
|
452 |
+
batch,
|
453 |
+
force_uc_zero_embeddings=ucg_keys
|
454 |
+
if len(self.conditioner.embedders) > 0
|
455 |
+
else [],
|
456 |
+
)
|
457 |
+
|
458 |
+
sampling_kwargs = {"num_video_frames": num_video_frames}
|
459 |
+
n = min(x.shape[0] // num_video_frames, N)
|
460 |
+
sampling_kwargs["image_only_indicator"] = torch.cat(
|
461 |
+
[batch["image_only_indicator"][:n]] * 2
|
462 |
+
)
|
463 |
+
|
464 |
+
N = min(x.shape[0] // num_video_frames, N) * num_video_frames
|
465 |
+
x = x.to(self.device)[:N]
|
466 |
+
# log["inputs"] = rearrange(x, "(b t) c h w -> b c h (t w)", t=num_video_frames)
|
467 |
+
if self.input_key != "latents":
|
468 |
+
log["inputs"] = x
|
469 |
+
z = self.encode_first_stage(x)
|
470 |
+
recon = self.decode_first_stage(z)
|
471 |
+
# log["reconstructions"] = rearrange(
|
472 |
+
# recon, "(b t) c h w -> b c h (t w)", t=num_video_frames
|
473 |
+
# )
|
474 |
+
log["reconstructions"] = recon
|
475 |
+
log.update(self.log_conditionings(batch, N))
|
476 |
+
|
477 |
+
for k in c:
|
478 |
+
if isinstance(c[k], torch.Tensor):
|
479 |
+
if k == "vector":
|
480 |
+
end = N
|
481 |
+
else:
|
482 |
+
end = n
|
483 |
+
c[k], uc[k] = map(lambda y: y[k][:end].to(self.device), (c, uc))
|
484 |
+
|
485 |
+
# for k in c:
|
486 |
+
# print(c[k].shape)
|
487 |
+
|
488 |
+
for k in ["crossattn", "concat"]:
|
489 |
+
c[k] = repeat(c[k], "b ... -> b t ...", t=num_video_frames)
|
490 |
+
c[k] = rearrange(c[k], "b t ... -> (b t) ...", t=num_video_frames)
|
491 |
+
uc[k] = repeat(uc[k], "b ... -> b t ...", t=num_video_frames)
|
492 |
+
uc[k] = rearrange(uc[k], "b t ... -> (b t) ...", t=num_video_frames)
|
493 |
+
|
494 |
+
# for k in c:
|
495 |
+
# print(c[k].shape)
|
496 |
+
if sample:
|
497 |
+
with self.ema_scope("Plotting"):
|
498 |
+
samples = self.sample(
|
499 |
+
c, shape=z.shape[1:], uc=uc, batch_size=N, **sampling_kwargs
|
500 |
+
)
|
501 |
+
samples = self.decode_first_stage(samples)
|
502 |
+
log["samples"] = samples
|
503 |
+
return log
|
sgm/modules/__init__.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .encoders.modules import GeneralConditioner, ExtraConditioner
|
2 |
+
|
3 |
+
UNCONDITIONAL_CONFIG = {
|
4 |
+
"target": "sgm.modules.GeneralConditioner",
|
5 |
+
"params": {"emb_models": []},
|
6 |
+
}
|
sgm/modules/attention.py
ADDED
@@ -0,0 +1,764 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import math
|
3 |
+
from inspect import isfunction
|
4 |
+
from typing import Any, Optional
|
5 |
+
from functools import partial
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import torch.nn.functional as F
|
9 |
+
from einops import rearrange, repeat
|
10 |
+
from packaging import version
|
11 |
+
from torch import nn
|
12 |
+
|
13 |
+
# from torch.utils.checkpoint import checkpoint
|
14 |
+
|
15 |
+
checkpoint = partial(torch.utils.checkpoint.checkpoint, use_reentrant=False)
|
16 |
+
|
17 |
+
|
18 |
+
logpy = logging.getLogger(__name__)
|
19 |
+
|
20 |
+
if version.parse(torch.__version__) >= version.parse("2.0.0"):
|
21 |
+
SDP_IS_AVAILABLE = True
|
22 |
+
from torch.backends.cuda import SDPBackend, sdp_kernel
|
23 |
+
|
24 |
+
BACKEND_MAP = {
|
25 |
+
SDPBackend.MATH: {
|
26 |
+
"enable_math": True,
|
27 |
+
"enable_flash": False,
|
28 |
+
"enable_mem_efficient": False,
|
29 |
+
},
|
30 |
+
SDPBackend.FLASH_ATTENTION: {
|
31 |
+
"enable_math": False,
|
32 |
+
"enable_flash": True,
|
33 |
+
"enable_mem_efficient": False,
|
34 |
+
},
|
35 |
+
SDPBackend.EFFICIENT_ATTENTION: {
|
36 |
+
"enable_math": False,
|
37 |
+
"enable_flash": False,
|
38 |
+
"enable_mem_efficient": True,
|
39 |
+
},
|
40 |
+
None: {"enable_math": True, "enable_flash": True, "enable_mem_efficient": True},
|
41 |
+
}
|
42 |
+
else:
|
43 |
+
from contextlib import nullcontext
|
44 |
+
|
45 |
+
SDP_IS_AVAILABLE = False
|
46 |
+
sdp_kernel = nullcontext
|
47 |
+
BACKEND_MAP = {}
|
48 |
+
logpy.warn(
|
49 |
+
f"No SDP backend available, likely because you are running in pytorch "
|
50 |
+
f"versions < 2.0. In fact, you are using PyTorch {torch.__version__}. "
|
51 |
+
f"You might want to consider upgrading."
|
52 |
+
)
|
53 |
+
|
54 |
+
try:
|
55 |
+
import xformers
|
56 |
+
import xformers.ops
|
57 |
+
|
58 |
+
XFORMERS_IS_AVAILABLE = True
|
59 |
+
except:
|
60 |
+
XFORMERS_IS_AVAILABLE = False
|
61 |
+
logpy.warn("no module 'xformers'. Processing without...")
|
62 |
+
|
63 |
+
# from .diffusionmodules.util import mixed_checkpoint as checkpoint
|
64 |
+
|
65 |
+
|
66 |
+
def exists(val):
|
67 |
+
return val is not None
|
68 |
+
|
69 |
+
|
70 |
+
def uniq(arr):
|
71 |
+
return {el: True for el in arr}.keys()
|
72 |
+
|
73 |
+
|
74 |
+
def default(val, d):
|
75 |
+
if exists(val):
|
76 |
+
return val
|
77 |
+
return d() if isfunction(d) else d
|
78 |
+
|
79 |
+
|
80 |
+
def max_neg_value(t):
|
81 |
+
return -torch.finfo(t.dtype).max
|
82 |
+
|
83 |
+
|
84 |
+
def init_(tensor):
|
85 |
+
dim = tensor.shape[-1]
|
86 |
+
std = 1 / math.sqrt(dim)
|
87 |
+
tensor.uniform_(-std, std)
|
88 |
+
return tensor
|
89 |
+
|
90 |
+
|
91 |
+
# feedforward
|
92 |
+
class GEGLU(nn.Module):
|
93 |
+
def __init__(self, dim_in, dim_out):
|
94 |
+
super().__init__()
|
95 |
+
self.proj = nn.Linear(dim_in, dim_out * 2)
|
96 |
+
|
97 |
+
def forward(self, x):
|
98 |
+
x, gate = self.proj(x).chunk(2, dim=-1)
|
99 |
+
return x * F.gelu(gate)
|
100 |
+
|
101 |
+
|
102 |
+
class FeedForward(nn.Module):
|
103 |
+
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
|
104 |
+
super().__init__()
|
105 |
+
inner_dim = int(dim * mult)
|
106 |
+
dim_out = default(dim_out, dim)
|
107 |
+
project_in = (
|
108 |
+
nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
|
109 |
+
if not glu
|
110 |
+
else GEGLU(dim, inner_dim)
|
111 |
+
)
|
112 |
+
|
113 |
+
self.net = nn.Sequential(
|
114 |
+
project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
|
115 |
+
)
|
116 |
+
|
117 |
+
def forward(self, x):
|
118 |
+
return self.net(x)
|
119 |
+
|
120 |
+
|
121 |
+
def zero_module(module):
|
122 |
+
"""
|
123 |
+
Zero out the parameters of a module and return it.
|
124 |
+
"""
|
125 |
+
for p in module.parameters():
|
126 |
+
p.detach().zero_()
|
127 |
+
return module
|
128 |
+
|
129 |
+
|
130 |
+
def Normalize(in_channels):
|
131 |
+
return torch.nn.GroupNorm(
|
132 |
+
num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
|
133 |
+
)
|
134 |
+
|
135 |
+
|
136 |
+
class LinearAttention(nn.Module):
|
137 |
+
def __init__(self, dim, heads=4, dim_head=32):
|
138 |
+
super().__init__()
|
139 |
+
self.heads = heads
|
140 |
+
hidden_dim = dim_head * heads
|
141 |
+
self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
|
142 |
+
self.to_out = nn.Conv2d(hidden_dim, dim, 1)
|
143 |
+
|
144 |
+
def forward(self, x):
|
145 |
+
b, c, h, w = x.shape
|
146 |
+
qkv = self.to_qkv(x)
|
147 |
+
q, k, v = rearrange(
|
148 |
+
qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3
|
149 |
+
)
|
150 |
+
k = k.softmax(dim=-1)
|
151 |
+
context = torch.einsum("bhdn,bhen->bhde", k, v)
|
152 |
+
out = torch.einsum("bhde,bhdn->bhen", context, q)
|
153 |
+
out = rearrange(
|
154 |
+
out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w
|
155 |
+
)
|
156 |
+
return self.to_out(out)
|
157 |
+
|
158 |
+
|
159 |
+
class SelfAttention(nn.Module):
|
160 |
+
ATTENTION_MODES = ("xformers", "torch", "math")
|
161 |
+
|
162 |
+
def __init__(
|
163 |
+
self,
|
164 |
+
dim: int,
|
165 |
+
num_heads: int = 8,
|
166 |
+
qkv_bias: bool = False,
|
167 |
+
qk_scale: Optional[float] = None,
|
168 |
+
attn_drop: float = 0.0,
|
169 |
+
proj_drop: float = 0.0,
|
170 |
+
attn_mode: str = "xformers",
|
171 |
+
):
|
172 |
+
super().__init__()
|
173 |
+
self.num_heads = num_heads
|
174 |
+
head_dim = dim // num_heads
|
175 |
+
self.scale = qk_scale or head_dim**-0.5
|
176 |
+
|
177 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
178 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
179 |
+
self.proj = nn.Linear(dim, dim)
|
180 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
181 |
+
assert attn_mode in self.ATTENTION_MODES
|
182 |
+
self.attn_mode = attn_mode
|
183 |
+
|
184 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
185 |
+
B, L, C = x.shape
|
186 |
+
|
187 |
+
qkv = self.qkv(x)
|
188 |
+
if self.attn_mode == "torch":
|
189 |
+
qkv = rearrange(
|
190 |
+
qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads
|
191 |
+
).float()
|
192 |
+
q, k, v = qkv[0], qkv[1], qkv[2] # B H L D
|
193 |
+
x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
|
194 |
+
x = rearrange(x, "B H L D -> B L (H D)")
|
195 |
+
elif self.attn_mode == "xformers":
|
196 |
+
qkv = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.num_heads)
|
197 |
+
q, k, v = qkv[0], qkv[1], qkv[2] # B L H D
|
198 |
+
x = xformers.ops.memory_efficient_attention(q, k, v)
|
199 |
+
x = rearrange(x, "B L H D -> B L (H D)", H=self.num_heads)
|
200 |
+
elif self.attn_mode == "math":
|
201 |
+
qkv = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
202 |
+
q, k, v = qkv[0], qkv[1], qkv[2] # B H L D
|
203 |
+
attn = (q @ k.transpose(-2, -1)) * self.scale
|
204 |
+
attn = attn.softmax(dim=-1)
|
205 |
+
attn = self.attn_drop(attn)
|
206 |
+
x = (attn @ v).transpose(1, 2).reshape(B, L, C)
|
207 |
+
else:
|
208 |
+
raise NotImplemented
|
209 |
+
|
210 |
+
x = self.proj(x)
|
211 |
+
x = self.proj_drop(x)
|
212 |
+
return x
|
213 |
+
|
214 |
+
|
215 |
+
class SpatialSelfAttention(nn.Module):
|
216 |
+
def __init__(self, in_channels):
|
217 |
+
super().__init__()
|
218 |
+
self.in_channels = in_channels
|
219 |
+
|
220 |
+
self.norm = Normalize(in_channels)
|
221 |
+
self.q = torch.nn.Conv2d(
|
222 |
+
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
223 |
+
)
|
224 |
+
self.k = torch.nn.Conv2d(
|
225 |
+
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
226 |
+
)
|
227 |
+
self.v = torch.nn.Conv2d(
|
228 |
+
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
229 |
+
)
|
230 |
+
self.proj_out = torch.nn.Conv2d(
|
231 |
+
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
232 |
+
)
|
233 |
+
|
234 |
+
def forward(self, x):
|
235 |
+
h_ = x
|
236 |
+
h_ = self.norm(h_)
|
237 |
+
q = self.q(h_)
|
238 |
+
k = self.k(h_)
|
239 |
+
v = self.v(h_)
|
240 |
+
|
241 |
+
# compute attention
|
242 |
+
b, c, h, w = q.shape
|
243 |
+
q = rearrange(q, "b c h w -> b (h w) c")
|
244 |
+
k = rearrange(k, "b c h w -> b c (h w)")
|
245 |
+
w_ = torch.einsum("bij,bjk->bik", q, k)
|
246 |
+
|
247 |
+
w_ = w_ * (int(c) ** (-0.5))
|
248 |
+
w_ = torch.nn.functional.softmax(w_, dim=2)
|
249 |
+
|
250 |
+
# attend to values
|
251 |
+
v = rearrange(v, "b c h w -> b c (h w)")
|
252 |
+
w_ = rearrange(w_, "b i j -> b j i")
|
253 |
+
h_ = torch.einsum("bij,bjk->bik", v, w_)
|
254 |
+
h_ = rearrange(h_, "b c (h w) -> b c h w", h=h)
|
255 |
+
h_ = self.proj_out(h_)
|
256 |
+
|
257 |
+
return x + h_
|
258 |
+
|
259 |
+
|
260 |
+
class CrossAttention(nn.Module):
|
261 |
+
def __init__(
|
262 |
+
self,
|
263 |
+
query_dim,
|
264 |
+
context_dim=None,
|
265 |
+
heads=8,
|
266 |
+
dim_head=64,
|
267 |
+
dropout=0.0,
|
268 |
+
backend=None,
|
269 |
+
):
|
270 |
+
super().__init__()
|
271 |
+
inner_dim = dim_head * heads
|
272 |
+
context_dim = default(context_dim, query_dim)
|
273 |
+
|
274 |
+
self.scale = dim_head**-0.5
|
275 |
+
self.heads = heads
|
276 |
+
|
277 |
+
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
|
278 |
+
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
|
279 |
+
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
|
280 |
+
|
281 |
+
self.to_out = nn.Sequential(
|
282 |
+
nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
|
283 |
+
)
|
284 |
+
self.backend = backend
|
285 |
+
|
286 |
+
def forward(
|
287 |
+
self,
|
288 |
+
x,
|
289 |
+
context=None,
|
290 |
+
mask=None,
|
291 |
+
additional_tokens=None,
|
292 |
+
n_times_crossframe_attn_in_self=0,
|
293 |
+
):
|
294 |
+
h = self.heads
|
295 |
+
|
296 |
+
if additional_tokens is not None:
|
297 |
+
# get the number of masked tokens at the beginning of the output sequence
|
298 |
+
n_tokens_to_mask = additional_tokens.shape[1]
|
299 |
+
# add additional token
|
300 |
+
x = torch.cat([additional_tokens, x], dim=1)
|
301 |
+
|
302 |
+
q = self.to_q(x)
|
303 |
+
context = default(context, x)
|
304 |
+
k = self.to_k(context)
|
305 |
+
v = self.to_v(context)
|
306 |
+
|
307 |
+
if n_times_crossframe_attn_in_self:
|
308 |
+
# reprogramming cross-frame attention as in https://arxiv.org/abs/2303.13439
|
309 |
+
assert x.shape[0] % n_times_crossframe_attn_in_self == 0
|
310 |
+
n_cp = x.shape[0] // n_times_crossframe_attn_in_self
|
311 |
+
k = repeat(
|
312 |
+
k[::n_times_crossframe_attn_in_self], "b ... -> (b n) ...", n=n_cp
|
313 |
+
)
|
314 |
+
v = repeat(
|
315 |
+
v[::n_times_crossframe_attn_in_self], "b ... -> (b n) ...", n=n_cp
|
316 |
+
)
|
317 |
+
|
318 |
+
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
|
319 |
+
|
320 |
+
## old
|
321 |
+
"""
|
322 |
+
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
|
323 |
+
del q, k
|
324 |
+
|
325 |
+
if exists(mask):
|
326 |
+
mask = rearrange(mask, 'b ... -> b (...)')
|
327 |
+
max_neg_value = -torch.finfo(sim.dtype).max
|
328 |
+
mask = repeat(mask, 'b j -> (b h) () j', h=h)
|
329 |
+
sim.masked_fill_(~mask, max_neg_value)
|
330 |
+
|
331 |
+
# attention, what we cannot get enough of
|
332 |
+
sim = sim.softmax(dim=-1)
|
333 |
+
|
334 |
+
out = einsum('b i j, b j d -> b i d', sim, v)
|
335 |
+
"""
|
336 |
+
## new
|
337 |
+
with sdp_kernel(**BACKEND_MAP[self.backend]):
|
338 |
+
# print("dispatching into backend", self.backend, "q/k/v shape: ", q.shape, k.shape, v.shape)
|
339 |
+
out = F.scaled_dot_product_attention(
|
340 |
+
q, k, v, attn_mask=mask
|
341 |
+
) # scale is dim_head ** -0.5 per default
|
342 |
+
|
343 |
+
del q, k, v
|
344 |
+
out = rearrange(out, "b h n d -> b n (h d)", h=h)
|
345 |
+
|
346 |
+
if additional_tokens is not None:
|
347 |
+
# remove additional token
|
348 |
+
out = out[:, n_tokens_to_mask:]
|
349 |
+
return self.to_out(out)
|
350 |
+
|
351 |
+
|
352 |
+
class MemoryEfficientCrossAttention(nn.Module):
|
353 |
+
# https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
|
354 |
+
def __init__(
|
355 |
+
self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, **kwargs
|
356 |
+
):
|
357 |
+
super().__init__()
|
358 |
+
logpy.debug(
|
359 |
+
f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, "
|
360 |
+
f"context_dim is {context_dim} and using {heads} heads with a "
|
361 |
+
f"dimension of {dim_head}."
|
362 |
+
)
|
363 |
+
inner_dim = dim_head * heads
|
364 |
+
context_dim = default(context_dim, query_dim)
|
365 |
+
|
366 |
+
self.heads = heads
|
367 |
+
self.dim_head = dim_head
|
368 |
+
|
369 |
+
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
|
370 |
+
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
|
371 |
+
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
|
372 |
+
|
373 |
+
self.to_out = nn.Sequential(
|
374 |
+
nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
|
375 |
+
)
|
376 |
+
self.attention_op: Optional[Any] = None
|
377 |
+
|
378 |
+
def forward(
|
379 |
+
self,
|
380 |
+
x,
|
381 |
+
context=None,
|
382 |
+
mask=None,
|
383 |
+
additional_tokens=None,
|
384 |
+
n_times_crossframe_attn_in_self=0,
|
385 |
+
):
|
386 |
+
if additional_tokens is not None:
|
387 |
+
# get the number of masked tokens at the beginning of the output sequence
|
388 |
+
n_tokens_to_mask = additional_tokens.shape[1]
|
389 |
+
# add additional token
|
390 |
+
x = torch.cat([additional_tokens, x], dim=1)
|
391 |
+
q = self.to_q(x)
|
392 |
+
context = default(context, x)
|
393 |
+
k = self.to_k(context)
|
394 |
+
v = self.to_v(context)
|
395 |
+
|
396 |
+
if n_times_crossframe_attn_in_self:
|
397 |
+
# reprogramming cross-frame attention as in https://arxiv.org/abs/2303.13439
|
398 |
+
assert x.shape[0] % n_times_crossframe_attn_in_self == 0
|
399 |
+
# n_cp = x.shape[0]//n_times_crossframe_attn_in_self
|
400 |
+
k = repeat(
|
401 |
+
k[::n_times_crossframe_attn_in_self],
|
402 |
+
"b ... -> (b n) ...",
|
403 |
+
n=n_times_crossframe_attn_in_self,
|
404 |
+
)
|
405 |
+
v = repeat(
|
406 |
+
v[::n_times_crossframe_attn_in_self],
|
407 |
+
"b ... -> (b n) ...",
|
408 |
+
n=n_times_crossframe_attn_in_self,
|
409 |
+
)
|
410 |
+
|
411 |
+
b, _, _ = q.shape
|
412 |
+
q, k, v = map(
|
413 |
+
lambda t: t.unsqueeze(3)
|
414 |
+
.reshape(b, t.shape[1], self.heads, self.dim_head)
|
415 |
+
.permute(0, 2, 1, 3)
|
416 |
+
.reshape(b * self.heads, t.shape[1], self.dim_head)
|
417 |
+
.contiguous(),
|
418 |
+
(q, k, v),
|
419 |
+
)
|
420 |
+
|
421 |
+
# actually compute the attention, what we cannot get enough of
|
422 |
+
if version.parse(xformers.__version__) >= version.parse("0.0.21"):
|
423 |
+
# NOTE: workaround for
|
424 |
+
# https://github.com/facebookresearch/xformers/issues/845
|
425 |
+
max_bs = 32768
|
426 |
+
N = q.shape[0]
|
427 |
+
n_batches = math.ceil(N / max_bs)
|
428 |
+
out = list()
|
429 |
+
for i_batch in range(n_batches):
|
430 |
+
batch = slice(i_batch * max_bs, (i_batch + 1) * max_bs)
|
431 |
+
out.append(
|
432 |
+
xformers.ops.memory_efficient_attention(
|
433 |
+
q[batch],
|
434 |
+
k[batch],
|
435 |
+
v[batch],
|
436 |
+
attn_bias=None,
|
437 |
+
op=self.attention_op,
|
438 |
+
)
|
439 |
+
)
|
440 |
+
out = torch.cat(out, 0)
|
441 |
+
else:
|
442 |
+
out = xformers.ops.memory_efficient_attention(
|
443 |
+
q, k, v, attn_bias=None, op=self.attention_op
|
444 |
+
)
|
445 |
+
|
446 |
+
# TODO: Use this directly in the attention operation, as a bias
|
447 |
+
if exists(mask):
|
448 |
+
raise NotImplementedError
|
449 |
+
out = (
|
450 |
+
out.unsqueeze(0)
|
451 |
+
.reshape(b, self.heads, out.shape[1], self.dim_head)
|
452 |
+
.permute(0, 2, 1, 3)
|
453 |
+
.reshape(b, out.shape[1], self.heads * self.dim_head)
|
454 |
+
)
|
455 |
+
if additional_tokens is not None:
|
456 |
+
# remove additional token
|
457 |
+
out = out[:, n_tokens_to_mask:]
|
458 |
+
return self.to_out(out)
|
459 |
+
|
460 |
+
|
461 |
+
class BasicTransformerBlock(nn.Module):
|
462 |
+
ATTENTION_MODES = {
|
463 |
+
"softmax": CrossAttention, # vanilla attention
|
464 |
+
"softmax-xformers": MemoryEfficientCrossAttention, # ampere
|
465 |
+
}
|
466 |
+
|
467 |
+
def __init__(
|
468 |
+
self,
|
469 |
+
dim,
|
470 |
+
n_heads,
|
471 |
+
d_head,
|
472 |
+
dropout=0.0,
|
473 |
+
context_dim=None,
|
474 |
+
gated_ff=True,
|
475 |
+
checkpoint=True,
|
476 |
+
disable_self_attn=False,
|
477 |
+
attn_mode="softmax",
|
478 |
+
sdp_backend=None,
|
479 |
+
):
|
480 |
+
super().__init__()
|
481 |
+
assert attn_mode in self.ATTENTION_MODES
|
482 |
+
if attn_mode != "softmax" and not XFORMERS_IS_AVAILABLE:
|
483 |
+
logpy.warn(
|
484 |
+
f"Attention mode '{attn_mode}' is not available. Falling "
|
485 |
+
f"back to native attention. This is not a problem in "
|
486 |
+
f"Pytorch >= 2.0. FYI, you are running with PyTorch "
|
487 |
+
f"version {torch.__version__}."
|
488 |
+
)
|
489 |
+
attn_mode = "softmax"
|
490 |
+
elif attn_mode == "softmax" and not SDP_IS_AVAILABLE:
|
491 |
+
logpy.warn(
|
492 |
+
"We do not support vanilla attention anymore, as it is too "
|
493 |
+
"expensive. Sorry."
|
494 |
+
)
|
495 |
+
if not XFORMERS_IS_AVAILABLE:
|
496 |
+
assert (
|
497 |
+
False
|
498 |
+
), "Please install xformers via e.g. 'pip install xformers==0.0.16'"
|
499 |
+
else:
|
500 |
+
logpy.info("Falling back to xformers efficient attention.")
|
501 |
+
attn_mode = "softmax-xformers"
|
502 |
+
attn_cls = self.ATTENTION_MODES[attn_mode]
|
503 |
+
if version.parse(torch.__version__) >= version.parse("2.0.0"):
|
504 |
+
assert sdp_backend is None or isinstance(sdp_backend, SDPBackend)
|
505 |
+
else:
|
506 |
+
assert sdp_backend is None
|
507 |
+
self.disable_self_attn = disable_self_attn
|
508 |
+
self.attn1 = attn_cls(
|
509 |
+
query_dim=dim,
|
510 |
+
heads=n_heads,
|
511 |
+
dim_head=d_head,
|
512 |
+
dropout=dropout,
|
513 |
+
context_dim=context_dim if self.disable_self_attn else None,
|
514 |
+
backend=sdp_backend,
|
515 |
+
) # is a self-attention if not self.disable_self_attn
|
516 |
+
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
|
517 |
+
self.attn2 = attn_cls(
|
518 |
+
query_dim=dim,
|
519 |
+
context_dim=context_dim,
|
520 |
+
heads=n_heads,
|
521 |
+
dim_head=d_head,
|
522 |
+
dropout=dropout,
|
523 |
+
backend=sdp_backend,
|
524 |
+
) # is self-attn if context is none
|
525 |
+
self.norm1 = nn.LayerNorm(dim)
|
526 |
+
self.norm2 = nn.LayerNorm(dim)
|
527 |
+
self.norm3 = nn.LayerNorm(dim)
|
528 |
+
self.checkpoint = checkpoint
|
529 |
+
if self.checkpoint:
|
530 |
+
logpy.debug(f"{self.__class__.__name__} is using checkpointing")
|
531 |
+
|
532 |
+
def forward(
|
533 |
+
self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0
|
534 |
+
):
|
535 |
+
kwargs = {"x": x}
|
536 |
+
|
537 |
+
if context is not None:
|
538 |
+
kwargs.update({"context": context})
|
539 |
+
|
540 |
+
if additional_tokens is not None:
|
541 |
+
kwargs.update({"additional_tokens": additional_tokens})
|
542 |
+
|
543 |
+
if n_times_crossframe_attn_in_self:
|
544 |
+
kwargs.update(
|
545 |
+
{"n_times_crossframe_attn_in_self": n_times_crossframe_attn_in_self}
|
546 |
+
)
|
547 |
+
|
548 |
+
# return mixed_checkpoint(self._forward, kwargs, self.parameters(), self.checkpoint)
|
549 |
+
if self.checkpoint:
|
550 |
+
# inputs = {"x": x, "context": context}
|
551 |
+
return checkpoint(self._forward, x, context)
|
552 |
+
# return checkpoint(self._forward, inputs, self.parameters(), self.checkpoint)
|
553 |
+
else:
|
554 |
+
return self._forward(**kwargs)
|
555 |
+
|
556 |
+
def _forward(
|
557 |
+
self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0
|
558 |
+
):
|
559 |
+
x = (
|
560 |
+
self.attn1(
|
561 |
+
self.norm1(x),
|
562 |
+
context=context if self.disable_self_attn else None,
|
563 |
+
additional_tokens=additional_tokens,
|
564 |
+
n_times_crossframe_attn_in_self=n_times_crossframe_attn_in_self
|
565 |
+
if not self.disable_self_attn
|
566 |
+
else 0,
|
567 |
+
)
|
568 |
+
+ x
|
569 |
+
)
|
570 |
+
x = (
|
571 |
+
self.attn2(
|
572 |
+
self.norm2(x), context=context, additional_tokens=additional_tokens
|
573 |
+
)
|
574 |
+
+ x
|
575 |
+
)
|
576 |
+
x = self.ff(self.norm3(x)) + x
|
577 |
+
return x
|
578 |
+
|
579 |
+
|
580 |
+
class BasicTransformerSingleLayerBlock(nn.Module):
|
581 |
+
ATTENTION_MODES = {
|
582 |
+
"softmax": CrossAttention, # vanilla attention
|
583 |
+
"softmax-xformers": MemoryEfficientCrossAttention # on the A100s not quite as fast as the above version
|
584 |
+
# (todo might depend on head_dim, check, falls back to semi-optimized kernels for dim!=[16,32,64,128])
|
585 |
+
}
|
586 |
+
|
587 |
+
def __init__(
|
588 |
+
self,
|
589 |
+
dim,
|
590 |
+
n_heads,
|
591 |
+
d_head,
|
592 |
+
dropout=0.0,
|
593 |
+
context_dim=None,
|
594 |
+
gated_ff=True,
|
595 |
+
checkpoint=True,
|
596 |
+
attn_mode="softmax",
|
597 |
+
):
|
598 |
+
super().__init__()
|
599 |
+
assert attn_mode in self.ATTENTION_MODES
|
600 |
+
attn_cls = self.ATTENTION_MODES[attn_mode]
|
601 |
+
self.attn1 = attn_cls(
|
602 |
+
query_dim=dim,
|
603 |
+
heads=n_heads,
|
604 |
+
dim_head=d_head,
|
605 |
+
dropout=dropout,
|
606 |
+
context_dim=context_dim,
|
607 |
+
)
|
608 |
+
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
|
609 |
+
self.norm1 = nn.LayerNorm(dim)
|
610 |
+
self.norm2 = nn.LayerNorm(dim)
|
611 |
+
self.checkpoint = checkpoint
|
612 |
+
|
613 |
+
def forward(self, x, context=None):
|
614 |
+
# inputs = {"x": x, "context": context}
|
615 |
+
# return checkpoint(self._forward, inputs, self.parameters(), self.checkpoint)
|
616 |
+
return checkpoint(self._forward, x, context)
|
617 |
+
|
618 |
+
def _forward(self, x, context=None):
|
619 |
+
x = self.attn1(self.norm1(x), context=context) + x
|
620 |
+
x = self.ff(self.norm2(x)) + x
|
621 |
+
return x
|
622 |
+
|
623 |
+
|
624 |
+
class SpatialTransformer(nn.Module):
|
625 |
+
"""
|
626 |
+
Transformer block for image-like data.
|
627 |
+
First, project the input (aka embedding)
|
628 |
+
and reshape to b, t, d.
|
629 |
+
Then apply standard transformer action.
|
630 |
+
Finally, reshape to image
|
631 |
+
NEW: use_linear for more efficiency instead of the 1x1 convs
|
632 |
+
"""
|
633 |
+
|
634 |
+
def __init__(
|
635 |
+
self,
|
636 |
+
in_channels,
|
637 |
+
n_heads,
|
638 |
+
d_head,
|
639 |
+
depth=1,
|
640 |
+
dropout=0.0,
|
641 |
+
context_dim=None,
|
642 |
+
disable_self_attn=False,
|
643 |
+
use_linear=False,
|
644 |
+
attn_type="softmax",
|
645 |
+
use_checkpoint=True,
|
646 |
+
# sdp_backend=SDPBackend.FLASH_ATTENTION
|
647 |
+
sdp_backend=None,
|
648 |
+
):
|
649 |
+
super().__init__()
|
650 |
+
logpy.debug(
|
651 |
+
f"constructing {self.__class__.__name__} of depth {depth} w/ "
|
652 |
+
f"{in_channels} channels and {n_heads} heads."
|
653 |
+
)
|
654 |
+
|
655 |
+
if exists(context_dim) and not isinstance(context_dim, list):
|
656 |
+
context_dim = [context_dim]
|
657 |
+
if exists(context_dim) and isinstance(context_dim, list):
|
658 |
+
if depth != len(context_dim):
|
659 |
+
logpy.warn(
|
660 |
+
f"{self.__class__.__name__}: Found context dims "
|
661 |
+
f"{context_dim} of depth {len(context_dim)}, which does not "
|
662 |
+
f"match the specified 'depth' of {depth}. Setting context_dim "
|
663 |
+
f"to {depth * [context_dim[0]]} now."
|
664 |
+
)
|
665 |
+
# depth does not match context dims.
|
666 |
+
assert all(
|
667 |
+
map(lambda x: x == context_dim[0], context_dim)
|
668 |
+
), "need homogenous context_dim to match depth automatically"
|
669 |
+
context_dim = depth * [context_dim[0]]
|
670 |
+
elif context_dim is None:
|
671 |
+
context_dim = [None] * depth
|
672 |
+
self.in_channels = in_channels
|
673 |
+
inner_dim = n_heads * d_head
|
674 |
+
self.norm = Normalize(in_channels)
|
675 |
+
if not use_linear:
|
676 |
+
self.proj_in = nn.Conv2d(
|
677 |
+
in_channels, inner_dim, kernel_size=1, stride=1, padding=0
|
678 |
+
)
|
679 |
+
else:
|
680 |
+
self.proj_in = nn.Linear(in_channels, inner_dim)
|
681 |
+
|
682 |
+
self.transformer_blocks = nn.ModuleList(
|
683 |
+
[
|
684 |
+
BasicTransformerBlock(
|
685 |
+
inner_dim,
|
686 |
+
n_heads,
|
687 |
+
d_head,
|
688 |
+
dropout=dropout,
|
689 |
+
context_dim=context_dim[d],
|
690 |
+
disable_self_attn=disable_self_attn,
|
691 |
+
attn_mode=attn_type,
|
692 |
+
checkpoint=use_checkpoint,
|
693 |
+
sdp_backend=sdp_backend,
|
694 |
+
)
|
695 |
+
for d in range(depth)
|
696 |
+
]
|
697 |
+
)
|
698 |
+
if not use_linear:
|
699 |
+
self.proj_out = zero_module(
|
700 |
+
nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
|
701 |
+
)
|
702 |
+
else:
|
703 |
+
# self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
|
704 |
+
self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
|
705 |
+
self.use_linear = use_linear
|
706 |
+
|
707 |
+
def forward(self, x, context=None):
|
708 |
+
# note: if no context is given, cross-attention defaults to self-attention
|
709 |
+
if not isinstance(context, list):
|
710 |
+
context = [context]
|
711 |
+
b, c, h, w = x.shape
|
712 |
+
x_in = x
|
713 |
+
x = self.norm(x)
|
714 |
+
if not self.use_linear:
|
715 |
+
x = self.proj_in(x)
|
716 |
+
x = rearrange(x, "b c h w -> b (h w) c").contiguous()
|
717 |
+
if self.use_linear:
|
718 |
+
x = self.proj_in(x)
|
719 |
+
for i, block in enumerate(self.transformer_blocks):
|
720 |
+
if i > 0 and len(context) == 1:
|
721 |
+
i = 0 # use same context for each block
|
722 |
+
x = block(x, context=context[i])
|
723 |
+
if self.use_linear:
|
724 |
+
x = self.proj_out(x)
|
725 |
+
x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
|
726 |
+
if not self.use_linear:
|
727 |
+
x = self.proj_out(x)
|
728 |
+
return x + x_in
|
729 |
+
|
730 |
+
|
731 |
+
class SimpleTransformer(nn.Module):
|
732 |
+
def __init__(
|
733 |
+
self,
|
734 |
+
dim: int,
|
735 |
+
depth: int,
|
736 |
+
heads: int,
|
737 |
+
dim_head: int,
|
738 |
+
context_dim: Optional[int] = None,
|
739 |
+
dropout: float = 0.0,
|
740 |
+
checkpoint: bool = True,
|
741 |
+
):
|
742 |
+
super().__init__()
|
743 |
+
self.layers = nn.ModuleList([])
|
744 |
+
for _ in range(depth):
|
745 |
+
self.layers.append(
|
746 |
+
BasicTransformerBlock(
|
747 |
+
dim,
|
748 |
+
heads,
|
749 |
+
dim_head,
|
750 |
+
dropout=dropout,
|
751 |
+
context_dim=context_dim,
|
752 |
+
attn_mode="softmax-xformers",
|
753 |
+
checkpoint=checkpoint,
|
754 |
+
)
|
755 |
+
)
|
756 |
+
|
757 |
+
def forward(
|
758 |
+
self,
|
759 |
+
x: torch.Tensor,
|
760 |
+
context: Optional[torch.Tensor] = None,
|
761 |
+
) -> torch.Tensor:
|
762 |
+
for layer in self.layers:
|
763 |
+
x = layer(x, context)
|
764 |
+
return x
|
sgm/modules/autoencoding/__init__.py
ADDED
File without changes
|
sgm/modules/autoencoding/losses/__init__.py
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
__all__ = [
|
2 |
+
"GeneralLPIPSWithDiscriminator",
|
3 |
+
"LatentLPIPS",
|
4 |
+
]
|
5 |
+
|
6 |
+
from .discriminator_loss import GeneralLPIPSWithDiscriminator
|
7 |
+
from .lpips import LatentLPIPS
|
sgm/modules/autoencoding/losses/discriminator_loss.py
ADDED
@@ -0,0 +1,306 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Dict, Iterator, List, Optional, Tuple, Union
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import torchvision
|
7 |
+
from einops import rearrange
|
8 |
+
from matplotlib import colormaps
|
9 |
+
from matplotlib import pyplot as plt
|
10 |
+
|
11 |
+
from ....util import default, instantiate_from_config
|
12 |
+
from ..lpips.loss.lpips import LPIPS
|
13 |
+
from ..lpips.model.model import weights_init
|
14 |
+
from ..lpips.vqperceptual import hinge_d_loss, vanilla_d_loss
|
15 |
+
|
16 |
+
|
17 |
+
class GeneralLPIPSWithDiscriminator(nn.Module):
|
18 |
+
def __init__(
|
19 |
+
self,
|
20 |
+
disc_start: int,
|
21 |
+
logvar_init: float = 0.0,
|
22 |
+
disc_num_layers: int = 3,
|
23 |
+
disc_in_channels: int = 3,
|
24 |
+
disc_factor: float = 1.0,
|
25 |
+
disc_weight: float = 1.0,
|
26 |
+
perceptual_weight: float = 1.0,
|
27 |
+
disc_loss: str = "hinge",
|
28 |
+
scale_input_to_tgt_size: bool = False,
|
29 |
+
dims: int = 2,
|
30 |
+
learn_logvar: bool = False,
|
31 |
+
regularization_weights: Union[None, Dict[str, float]] = None,
|
32 |
+
additional_log_keys: Optional[List[str]] = None,
|
33 |
+
discriminator_config: Optional[Dict] = None,
|
34 |
+
):
|
35 |
+
super().__init__()
|
36 |
+
self.dims = dims
|
37 |
+
if self.dims > 2:
|
38 |
+
print(
|
39 |
+
f"running with dims={dims}. This means that for perceptual loss "
|
40 |
+
f"calculation, the LPIPS loss will be applied to each frame "
|
41 |
+
f"independently."
|
42 |
+
)
|
43 |
+
self.scale_input_to_tgt_size = scale_input_to_tgt_size
|
44 |
+
assert disc_loss in ["hinge", "vanilla"]
|
45 |
+
self.perceptual_loss = LPIPS().eval()
|
46 |
+
self.perceptual_weight = perceptual_weight
|
47 |
+
# output log variance
|
48 |
+
self.logvar = nn.Parameter(
|
49 |
+
torch.full((), logvar_init), requires_grad=learn_logvar
|
50 |
+
)
|
51 |
+
self.learn_logvar = learn_logvar
|
52 |
+
|
53 |
+
discriminator_config = default(
|
54 |
+
discriminator_config,
|
55 |
+
{
|
56 |
+
"target": "sgm.modules.autoencoding.lpips.model.model.NLayerDiscriminator",
|
57 |
+
"params": {
|
58 |
+
"input_nc": disc_in_channels,
|
59 |
+
"n_layers": disc_num_layers,
|
60 |
+
"use_actnorm": False,
|
61 |
+
},
|
62 |
+
},
|
63 |
+
)
|
64 |
+
|
65 |
+
self.discriminator = instantiate_from_config(discriminator_config).apply(
|
66 |
+
weights_init
|
67 |
+
)
|
68 |
+
self.discriminator_iter_start = disc_start
|
69 |
+
self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
|
70 |
+
self.disc_factor = disc_factor
|
71 |
+
self.discriminator_weight = disc_weight
|
72 |
+
self.regularization_weights = default(regularization_weights, {})
|
73 |
+
|
74 |
+
self.forward_keys = [
|
75 |
+
"optimizer_idx",
|
76 |
+
"global_step",
|
77 |
+
"last_layer",
|
78 |
+
"split",
|
79 |
+
"regularization_log",
|
80 |
+
]
|
81 |
+
|
82 |
+
self.additional_log_keys = set(default(additional_log_keys, []))
|
83 |
+
self.additional_log_keys.update(set(self.regularization_weights.keys()))
|
84 |
+
|
85 |
+
def get_trainable_parameters(self) -> Iterator[nn.Parameter]:
|
86 |
+
return self.discriminator.parameters()
|
87 |
+
|
88 |
+
def get_trainable_autoencoder_parameters(self) -> Iterator[nn.Parameter]:
|
89 |
+
if self.learn_logvar:
|
90 |
+
yield self.logvar
|
91 |
+
yield from ()
|
92 |
+
|
93 |
+
@torch.no_grad()
|
94 |
+
def log_images(
|
95 |
+
self, inputs: torch.Tensor, reconstructions: torch.Tensor
|
96 |
+
) -> Dict[str, torch.Tensor]:
|
97 |
+
# calc logits of real/fake
|
98 |
+
logits_real = self.discriminator(inputs.contiguous().detach())
|
99 |
+
if len(logits_real.shape) < 4:
|
100 |
+
# Non patch-discriminator
|
101 |
+
return dict()
|
102 |
+
logits_fake = self.discriminator(reconstructions.contiguous().detach())
|
103 |
+
# -> (b, 1, h, w)
|
104 |
+
|
105 |
+
# parameters for colormapping
|
106 |
+
high = max(logits_fake.abs().max(), logits_real.abs().max()).item()
|
107 |
+
cmap = colormaps["PiYG"] # diverging colormap
|
108 |
+
|
109 |
+
def to_colormap(logits: torch.Tensor) -> torch.Tensor:
|
110 |
+
"""(b, 1, ...) -> (b, 3, ...)"""
|
111 |
+
logits = (logits + high) / (2 * high)
|
112 |
+
logits_np = cmap(logits.cpu().numpy())[..., :3] # truncate alpha channel
|
113 |
+
# -> (b, 1, ..., 3)
|
114 |
+
logits = torch.from_numpy(logits_np).to(logits.device)
|
115 |
+
return rearrange(logits, "b 1 ... c -> b c ...")
|
116 |
+
|
117 |
+
logits_real = torch.nn.functional.interpolate(
|
118 |
+
logits_real,
|
119 |
+
size=inputs.shape[-2:],
|
120 |
+
mode="nearest",
|
121 |
+
antialias=False,
|
122 |
+
)
|
123 |
+
logits_fake = torch.nn.functional.interpolate(
|
124 |
+
logits_fake,
|
125 |
+
size=reconstructions.shape[-2:],
|
126 |
+
mode="nearest",
|
127 |
+
antialias=False,
|
128 |
+
)
|
129 |
+
|
130 |
+
# alpha value of logits for overlay
|
131 |
+
alpha_real = torch.abs(logits_real) / high
|
132 |
+
alpha_fake = torch.abs(logits_fake) / high
|
133 |
+
# -> (b, 1, h, w) in range [0, 0.5]
|
134 |
+
# alpha value of lines don't really matter, since the values are the same
|
135 |
+
# for both images and logits anyway
|
136 |
+
grid_alpha_real = torchvision.utils.make_grid(alpha_real, nrow=4)
|
137 |
+
grid_alpha_fake = torchvision.utils.make_grid(alpha_fake, nrow=4)
|
138 |
+
grid_alpha = 0.8 * torch.cat((grid_alpha_real, grid_alpha_fake), dim=1)
|
139 |
+
# -> (1, h, w)
|
140 |
+
# blend logits and images together
|
141 |
+
|
142 |
+
# prepare logits for plotting
|
143 |
+
logits_real = to_colormap(logits_real)
|
144 |
+
logits_fake = to_colormap(logits_fake)
|
145 |
+
# resize logits
|
146 |
+
# -> (b, 3, h, w)
|
147 |
+
|
148 |
+
# make some grids
|
149 |
+
# add all logits to one plot
|
150 |
+
logits_real = torchvision.utils.make_grid(logits_real, nrow=4)
|
151 |
+
logits_fake = torchvision.utils.make_grid(logits_fake, nrow=4)
|
152 |
+
# I just love how torchvision calls the number of columns `nrow`
|
153 |
+
grid_logits = torch.cat((logits_real, logits_fake), dim=1)
|
154 |
+
# -> (3, h, w)
|
155 |
+
|
156 |
+
grid_images_real = torchvision.utils.make_grid(0.5 * inputs + 0.5, nrow=4)
|
157 |
+
grid_images_fake = torchvision.utils.make_grid(
|
158 |
+
0.5 * reconstructions + 0.5, nrow=4
|
159 |
+
)
|
160 |
+
grid_images = torch.cat((grid_images_real, grid_images_fake), dim=1)
|
161 |
+
# -> (3, h, w) in range [0, 1]
|
162 |
+
|
163 |
+
grid_blend = grid_alpha * grid_logits + (1 - grid_alpha) * grid_images
|
164 |
+
|
165 |
+
# Create labeled colorbar
|
166 |
+
dpi = 100
|
167 |
+
height = 128 / dpi
|
168 |
+
width = grid_logits.shape[2] / dpi
|
169 |
+
fig, ax = plt.subplots(figsize=(width, height), dpi=dpi)
|
170 |
+
img = ax.imshow(np.array([[-high, high]]), cmap=cmap)
|
171 |
+
plt.colorbar(
|
172 |
+
img,
|
173 |
+
cax=ax,
|
174 |
+
orientation="horizontal",
|
175 |
+
fraction=0.9,
|
176 |
+
aspect=width / height,
|
177 |
+
pad=0.0,
|
178 |
+
)
|
179 |
+
img.set_visible(False)
|
180 |
+
fig.tight_layout()
|
181 |
+
fig.canvas.draw()
|
182 |
+
# manually convert figure to numpy
|
183 |
+
cbar_np = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
|
184 |
+
cbar_np = cbar_np.reshape(fig.canvas.get_width_height()[::-1] + (3,))
|
185 |
+
cbar = torch.from_numpy(cbar_np.copy()).to(grid_logits.dtype) / 255.0
|
186 |
+
cbar = rearrange(cbar, "h w c -> c h w").to(grid_logits.device)
|
187 |
+
|
188 |
+
# Add colorbar to plot
|
189 |
+
annotated_grid = torch.cat((grid_logits, cbar), dim=1)
|
190 |
+
blended_grid = torch.cat((grid_blend, cbar), dim=1)
|
191 |
+
return {
|
192 |
+
"vis_logits": 2 * annotated_grid[None, ...] - 1,
|
193 |
+
"vis_logits_blended": 2 * blended_grid[None, ...] - 1,
|
194 |
+
}
|
195 |
+
|
196 |
+
def calculate_adaptive_weight(
|
197 |
+
self, nll_loss: torch.Tensor, g_loss: torch.Tensor, last_layer: torch.Tensor
|
198 |
+
) -> torch.Tensor:
|
199 |
+
nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
|
200 |
+
g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
|
201 |
+
|
202 |
+
d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
|
203 |
+
d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
|
204 |
+
d_weight = d_weight * self.discriminator_weight
|
205 |
+
return d_weight
|
206 |
+
|
207 |
+
def forward(
|
208 |
+
self,
|
209 |
+
inputs: torch.Tensor,
|
210 |
+
reconstructions: torch.Tensor,
|
211 |
+
*, # added because I changed the order here
|
212 |
+
regularization_log: Dict[str, torch.Tensor],
|
213 |
+
optimizer_idx: int,
|
214 |
+
global_step: int,
|
215 |
+
last_layer: torch.Tensor,
|
216 |
+
split: str = "train",
|
217 |
+
weights: Union[None, float, torch.Tensor] = None,
|
218 |
+
) -> Tuple[torch.Tensor, dict]:
|
219 |
+
if self.scale_input_to_tgt_size:
|
220 |
+
inputs = torch.nn.functional.interpolate(
|
221 |
+
inputs, reconstructions.shape[2:], mode="bicubic", antialias=True
|
222 |
+
)
|
223 |
+
|
224 |
+
if self.dims > 2:
|
225 |
+
inputs, reconstructions = map(
|
226 |
+
lambda x: rearrange(x, "b c t h w -> (b t) c h w"),
|
227 |
+
(inputs, reconstructions),
|
228 |
+
)
|
229 |
+
|
230 |
+
rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
|
231 |
+
if self.perceptual_weight > 0:
|
232 |
+
p_loss = self.perceptual_loss(
|
233 |
+
inputs.contiguous(), reconstructions.contiguous()
|
234 |
+
)
|
235 |
+
rec_loss = rec_loss + self.perceptual_weight * p_loss
|
236 |
+
|
237 |
+
nll_loss, weighted_nll_loss = self.get_nll_loss(rec_loss, weights)
|
238 |
+
|
239 |
+
# now the GAN part
|
240 |
+
if optimizer_idx == 0:
|
241 |
+
# generator update
|
242 |
+
if global_step >= self.discriminator_iter_start or not self.training:
|
243 |
+
logits_fake = self.discriminator(reconstructions.contiguous())
|
244 |
+
g_loss = -torch.mean(logits_fake)
|
245 |
+
if self.training:
|
246 |
+
d_weight = self.calculate_adaptive_weight(
|
247 |
+
nll_loss, g_loss, last_layer=last_layer
|
248 |
+
)
|
249 |
+
else:
|
250 |
+
d_weight = torch.tensor(1.0)
|
251 |
+
else:
|
252 |
+
d_weight = torch.tensor(0.0)
|
253 |
+
g_loss = torch.tensor(0.0, requires_grad=True)
|
254 |
+
|
255 |
+
loss = weighted_nll_loss + d_weight * self.disc_factor * g_loss
|
256 |
+
log = dict()
|
257 |
+
for k in regularization_log:
|
258 |
+
if k in self.regularization_weights:
|
259 |
+
loss = loss + self.regularization_weights[k] * regularization_log[k]
|
260 |
+
if k in self.additional_log_keys:
|
261 |
+
log[f"{split}/{k}"] = regularization_log[k].detach().float().mean()
|
262 |
+
|
263 |
+
log.update(
|
264 |
+
{
|
265 |
+
f"{split}/loss/total": loss.clone().detach().mean(),
|
266 |
+
f"{split}/loss/nll": nll_loss.detach().mean(),
|
267 |
+
f"{split}/loss/rec": rec_loss.detach().mean(),
|
268 |
+
f"{split}/loss/g": g_loss.detach().mean(),
|
269 |
+
f"{split}/scalars/logvar": self.logvar.detach(),
|
270 |
+
f"{split}/scalars/d_weight": d_weight.detach(),
|
271 |
+
}
|
272 |
+
)
|
273 |
+
|
274 |
+
return loss, log
|
275 |
+
elif optimizer_idx == 1:
|
276 |
+
# second pass for discriminator update
|
277 |
+
logits_real = self.discriminator(inputs.contiguous().detach())
|
278 |
+
logits_fake = self.discriminator(reconstructions.contiguous().detach())
|
279 |
+
|
280 |
+
if global_step >= self.discriminator_iter_start or not self.training:
|
281 |
+
d_loss = self.disc_factor * self.disc_loss(logits_real, logits_fake)
|
282 |
+
else:
|
283 |
+
d_loss = torch.tensor(0.0, requires_grad=True)
|
284 |
+
|
285 |
+
log = {
|
286 |
+
f"{split}/loss/disc": d_loss.clone().detach().mean(),
|
287 |
+
f"{split}/logits/real": logits_real.detach().mean(),
|
288 |
+
f"{split}/logits/fake": logits_fake.detach().mean(),
|
289 |
+
}
|
290 |
+
return d_loss, log
|
291 |
+
else:
|
292 |
+
raise NotImplementedError(f"Unknown optimizer_idx {optimizer_idx}")
|
293 |
+
|
294 |
+
def get_nll_loss(
|
295 |
+
self,
|
296 |
+
rec_loss: torch.Tensor,
|
297 |
+
weights: Optional[Union[float, torch.Tensor]] = None,
|
298 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
299 |
+
nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
|
300 |
+
weighted_nll_loss = nll_loss
|
301 |
+
if weights is not None:
|
302 |
+
weighted_nll_loss = weights * nll_loss
|
303 |
+
weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
|
304 |
+
nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
|
305 |
+
|
306 |
+
return nll_loss, weighted_nll_loss
|
sgm/modules/autoencoding/losses/lpips.py
ADDED
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
from ....util import default, instantiate_from_config
|
5 |
+
from ..lpips.loss.lpips import LPIPS
|
6 |
+
|
7 |
+
|
8 |
+
class LatentLPIPS(nn.Module):
|
9 |
+
def __init__(
|
10 |
+
self,
|
11 |
+
decoder_config,
|
12 |
+
perceptual_weight=1.0,
|
13 |
+
latent_weight=1.0,
|
14 |
+
scale_input_to_tgt_size=False,
|
15 |
+
scale_tgt_to_input_size=False,
|
16 |
+
perceptual_weight_on_inputs=0.0,
|
17 |
+
):
|
18 |
+
super().__init__()
|
19 |
+
self.scale_input_to_tgt_size = scale_input_to_tgt_size
|
20 |
+
self.scale_tgt_to_input_size = scale_tgt_to_input_size
|
21 |
+
self.init_decoder(decoder_config)
|
22 |
+
self.perceptual_loss = LPIPS().eval()
|
23 |
+
self.perceptual_weight = perceptual_weight
|
24 |
+
self.latent_weight = latent_weight
|
25 |
+
self.perceptual_weight_on_inputs = perceptual_weight_on_inputs
|
26 |
+
|
27 |
+
def init_decoder(self, config):
|
28 |
+
self.decoder = instantiate_from_config(config)
|
29 |
+
if hasattr(self.decoder, "encoder"):
|
30 |
+
del self.decoder.encoder
|
31 |
+
|
32 |
+
def forward(self, latent_inputs, latent_predictions, image_inputs, split="train"):
|
33 |
+
log = dict()
|
34 |
+
loss = (latent_inputs - latent_predictions) ** 2
|
35 |
+
log[f"{split}/latent_l2_loss"] = loss.mean().detach()
|
36 |
+
image_reconstructions = None
|
37 |
+
if self.perceptual_weight > 0.0:
|
38 |
+
image_reconstructions = self.decoder.decode(latent_predictions)
|
39 |
+
image_targets = self.decoder.decode(latent_inputs)
|
40 |
+
perceptual_loss = self.perceptual_loss(
|
41 |
+
image_targets.contiguous(), image_reconstructions.contiguous()
|
42 |
+
)
|
43 |
+
loss = (
|
44 |
+
self.latent_weight * loss.mean()
|
45 |
+
+ self.perceptual_weight * perceptual_loss.mean()
|
46 |
+
)
|
47 |
+
log[f"{split}/perceptual_loss"] = perceptual_loss.mean().detach()
|
48 |
+
|
49 |
+
if self.perceptual_weight_on_inputs > 0.0:
|
50 |
+
image_reconstructions = default(
|
51 |
+
image_reconstructions, self.decoder.decode(latent_predictions)
|
52 |
+
)
|
53 |
+
if self.scale_input_to_tgt_size:
|
54 |
+
image_inputs = torch.nn.functional.interpolate(
|
55 |
+
image_inputs,
|
56 |
+
image_reconstructions.shape[2:],
|
57 |
+
mode="bicubic",
|
58 |
+
antialias=True,
|
59 |
+
)
|
60 |
+
elif self.scale_tgt_to_input_size:
|
61 |
+
image_reconstructions = torch.nn.functional.interpolate(
|
62 |
+
image_reconstructions,
|
63 |
+
image_inputs.shape[2:],
|
64 |
+
mode="bicubic",
|
65 |
+
antialias=True,
|
66 |
+
)
|
67 |
+
|
68 |
+
perceptual_loss2 = self.perceptual_loss(
|
69 |
+
image_inputs.contiguous(), image_reconstructions.contiguous()
|
70 |
+
)
|
71 |
+
loss = loss + self.perceptual_weight_on_inputs * perceptual_loss2.mean()
|
72 |
+
log[f"{split}/perceptual_loss_on_inputs"] = perceptual_loss2.mean().detach()
|
73 |
+
return loss, log
|
sgm/modules/autoencoding/lpips/__init__.py
ADDED
File without changes
|
sgm/modules/autoencoding/lpips/loss/.gitignore
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
vgg.pth
|
sgm/modules/autoencoding/lpips/loss/LICENSE
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Copyright (c) 2018, Richard Zhang, Phillip Isola, Alexei A. Efros, Eli Shechtman, Oliver Wang
|
2 |
+
All rights reserved.
|
3 |
+
|
4 |
+
Redistribution and use in source and binary forms, with or without
|
5 |
+
modification, are permitted provided that the following conditions are met:
|
6 |
+
|
7 |
+
* Redistributions of source code must retain the above copyright notice, this
|
8 |
+
list of conditions and the following disclaimer.
|
9 |
+
|
10 |
+
* Redistributions in binary form must reproduce the above copyright notice,
|
11 |
+
this list of conditions and the following disclaimer in the documentation
|
12 |
+
and/or other materials provided with the distribution.
|
13 |
+
|
14 |
+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
15 |
+
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
16 |
+
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
17 |
+
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
18 |
+
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
19 |
+
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
20 |
+
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
21 |
+
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
22 |
+
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
23 |
+
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
sgm/modules/autoencoding/lpips/loss/__init__.py
ADDED
File without changes
|
sgm/modules/autoencoding/lpips/loss/lpips.py
ADDED
@@ -0,0 +1,147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models"""
|
2 |
+
|
3 |
+
from collections import namedtuple
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
from torchvision import models
|
8 |
+
|
9 |
+
from ..util import get_ckpt_path
|
10 |
+
|
11 |
+
|
12 |
+
class LPIPS(nn.Module):
|
13 |
+
# Learned perceptual metric
|
14 |
+
def __init__(self, use_dropout=True):
|
15 |
+
super().__init__()
|
16 |
+
self.scaling_layer = ScalingLayer()
|
17 |
+
self.chns = [64, 128, 256, 512, 512] # vg16 features
|
18 |
+
self.net = vgg16(pretrained=True, requires_grad=False)
|
19 |
+
self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
|
20 |
+
self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
|
21 |
+
self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
|
22 |
+
self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
|
23 |
+
self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
|
24 |
+
self.load_from_pretrained()
|
25 |
+
for param in self.parameters():
|
26 |
+
param.requires_grad = False
|
27 |
+
|
28 |
+
def load_from_pretrained(self, name="vgg_lpips"):
|
29 |
+
ckpt = get_ckpt_path(name, "sgm/modules/autoencoding/lpips/loss")
|
30 |
+
self.load_state_dict(
|
31 |
+
torch.load(ckpt, map_location=torch.device("cpu")), strict=False
|
32 |
+
)
|
33 |
+
print("loaded pretrained LPIPS loss from {}".format(ckpt))
|
34 |
+
|
35 |
+
@classmethod
|
36 |
+
def from_pretrained(cls, name="vgg_lpips"):
|
37 |
+
if name != "vgg_lpips":
|
38 |
+
raise NotImplementedError
|
39 |
+
model = cls()
|
40 |
+
ckpt = get_ckpt_path(name)
|
41 |
+
model.load_state_dict(
|
42 |
+
torch.load(ckpt, map_location=torch.device("cpu")), strict=False
|
43 |
+
)
|
44 |
+
return model
|
45 |
+
|
46 |
+
def forward(self, input, target):
|
47 |
+
in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target))
|
48 |
+
outs0, outs1 = self.net(in0_input), self.net(in1_input)
|
49 |
+
feats0, feats1, diffs = {}, {}, {}
|
50 |
+
lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
|
51 |
+
for kk in range(len(self.chns)):
|
52 |
+
feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(
|
53 |
+
outs1[kk]
|
54 |
+
)
|
55 |
+
diffs[kk] = (feats0[kk] - feats1[kk]) ** 2
|
56 |
+
|
57 |
+
res = [
|
58 |
+
spatial_average(lins[kk].model(diffs[kk]), keepdim=True)
|
59 |
+
for kk in range(len(self.chns))
|
60 |
+
]
|
61 |
+
val = res[0]
|
62 |
+
for l in range(1, len(self.chns)):
|
63 |
+
val += res[l]
|
64 |
+
return val
|
65 |
+
|
66 |
+
|
67 |
+
class ScalingLayer(nn.Module):
|
68 |
+
def __init__(self):
|
69 |
+
super(ScalingLayer, self).__init__()
|
70 |
+
self.register_buffer(
|
71 |
+
"shift", torch.Tensor([-0.030, -0.088, -0.188])[None, :, None, None]
|
72 |
+
)
|
73 |
+
self.register_buffer(
|
74 |
+
"scale", torch.Tensor([0.458, 0.448, 0.450])[None, :, None, None]
|
75 |
+
)
|
76 |
+
|
77 |
+
def forward(self, inp):
|
78 |
+
return (inp - self.shift) / self.scale
|
79 |
+
|
80 |
+
|
81 |
+
class NetLinLayer(nn.Module):
|
82 |
+
"""A single linear layer which does a 1x1 conv"""
|
83 |
+
|
84 |
+
def __init__(self, chn_in, chn_out=1, use_dropout=False):
|
85 |
+
super(NetLinLayer, self).__init__()
|
86 |
+
layers = (
|
87 |
+
[
|
88 |
+
nn.Dropout(),
|
89 |
+
]
|
90 |
+
if (use_dropout)
|
91 |
+
else []
|
92 |
+
)
|
93 |
+
layers += [
|
94 |
+
nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False),
|
95 |
+
]
|
96 |
+
self.model = nn.Sequential(*layers)
|
97 |
+
|
98 |
+
|
99 |
+
class vgg16(torch.nn.Module):
|
100 |
+
def __init__(self, requires_grad=False, pretrained=True):
|
101 |
+
super(vgg16, self).__init__()
|
102 |
+
vgg_pretrained_features = models.vgg16(pretrained=pretrained).features
|
103 |
+
self.slice1 = torch.nn.Sequential()
|
104 |
+
self.slice2 = torch.nn.Sequential()
|
105 |
+
self.slice3 = torch.nn.Sequential()
|
106 |
+
self.slice4 = torch.nn.Sequential()
|
107 |
+
self.slice5 = torch.nn.Sequential()
|
108 |
+
self.N_slices = 5
|
109 |
+
for x in range(4):
|
110 |
+
self.slice1.add_module(str(x), vgg_pretrained_features[x])
|
111 |
+
for x in range(4, 9):
|
112 |
+
self.slice2.add_module(str(x), vgg_pretrained_features[x])
|
113 |
+
for x in range(9, 16):
|
114 |
+
self.slice3.add_module(str(x), vgg_pretrained_features[x])
|
115 |
+
for x in range(16, 23):
|
116 |
+
self.slice4.add_module(str(x), vgg_pretrained_features[x])
|
117 |
+
for x in range(23, 30):
|
118 |
+
self.slice5.add_module(str(x), vgg_pretrained_features[x])
|
119 |
+
if not requires_grad:
|
120 |
+
for param in self.parameters():
|
121 |
+
param.requires_grad = False
|
122 |
+
|
123 |
+
def forward(self, X):
|
124 |
+
h = self.slice1(X)
|
125 |
+
h_relu1_2 = h
|
126 |
+
h = self.slice2(h)
|
127 |
+
h_relu2_2 = h
|
128 |
+
h = self.slice3(h)
|
129 |
+
h_relu3_3 = h
|
130 |
+
h = self.slice4(h)
|
131 |
+
h_relu4_3 = h
|
132 |
+
h = self.slice5(h)
|
133 |
+
h_relu5_3 = h
|
134 |
+
vgg_outputs = namedtuple(
|
135 |
+
"VggOutputs", ["relu1_2", "relu2_2", "relu3_3", "relu4_3", "relu5_3"]
|
136 |
+
)
|
137 |
+
out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
|
138 |
+
return out
|
139 |
+
|
140 |
+
|
141 |
+
def normalize_tensor(x, eps=1e-10):
|
142 |
+
norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True))
|
143 |
+
return x / (norm_factor + eps)
|
144 |
+
|
145 |
+
|
146 |
+
def spatial_average(x, keepdim=True):
|
147 |
+
return x.mean([2, 3], keepdim=keepdim)
|
sgm/modules/autoencoding/lpips/model/LICENSE
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Copyright (c) 2017, Jun-Yan Zhu and Taesung Park
|
2 |
+
All rights reserved.
|
3 |
+
|
4 |
+
Redistribution and use in source and binary forms, with or without
|
5 |
+
modification, are permitted provided that the following conditions are met:
|
6 |
+
|
7 |
+
* Redistributions of source code must retain the above copyright notice, this
|
8 |
+
list of conditions and the following disclaimer.
|
9 |
+
|
10 |
+
* Redistributions in binary form must reproduce the above copyright notice,
|
11 |
+
this list of conditions and the following disclaimer in the documentation
|
12 |
+
and/or other materials provided with the distribution.
|
13 |
+
|
14 |
+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
15 |
+
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
16 |
+
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
17 |
+
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
18 |
+
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
19 |
+
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
20 |
+
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
21 |
+
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
22 |
+
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
23 |
+
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
24 |
+
|
25 |
+
|
26 |
+
--------------------------- LICENSE FOR pix2pix --------------------------------
|
27 |
+
BSD License
|
28 |
+
|
29 |
+
For pix2pix software
|
30 |
+
Copyright (c) 2016, Phillip Isola and Jun-Yan Zhu
|
31 |
+
All rights reserved.
|
32 |
+
|
33 |
+
Redistribution and use in source and binary forms, with or without
|
34 |
+
modification, are permitted provided that the following conditions are met:
|
35 |
+
|
36 |
+
* Redistributions of source code must retain the above copyright notice, this
|
37 |
+
list of conditions and the following disclaimer.
|
38 |
+
|
39 |
+
* Redistributions in binary form must reproduce the above copyright notice,
|
40 |
+
this list of conditions and the following disclaimer in the documentation
|
41 |
+
and/or other materials provided with the distribution.
|
42 |
+
|
43 |
+
----------------------------- LICENSE FOR DCGAN --------------------------------
|
44 |
+
BSD License
|
45 |
+
|
46 |
+
For dcgan.torch software
|
47 |
+
|
48 |
+
Copyright (c) 2015, Facebook, Inc. All rights reserved.
|
49 |
+
|
50 |
+
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
|
51 |
+
|
52 |
+
Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
|
53 |
+
|
54 |
+
Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
|
55 |
+
|
56 |
+
Neither the name Facebook nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
|
57 |
+
|
58 |
+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
sgm/modules/autoencoding/lpips/model/__init__.py
ADDED
File without changes
|
sgm/modules/autoencoding/lpips/model/model.py
ADDED
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import functools
|
2 |
+
|
3 |
+
import torch.nn as nn
|
4 |
+
|
5 |
+
from ..util import ActNorm
|
6 |
+
|
7 |
+
|
8 |
+
def weights_init(m):
|
9 |
+
classname = m.__class__.__name__
|
10 |
+
if classname.find("Conv") != -1:
|
11 |
+
nn.init.normal_(m.weight.data, 0.0, 0.02)
|
12 |
+
elif classname.find("BatchNorm") != -1:
|
13 |
+
nn.init.normal_(m.weight.data, 1.0, 0.02)
|
14 |
+
nn.init.constant_(m.bias.data, 0)
|
15 |
+
|
16 |
+
|
17 |
+
class NLayerDiscriminator(nn.Module):
|
18 |
+
"""Defines a PatchGAN discriminator as in Pix2Pix
|
19 |
+
--> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
|
20 |
+
"""
|
21 |
+
|
22 |
+
def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
|
23 |
+
"""Construct a PatchGAN discriminator
|
24 |
+
Parameters:
|
25 |
+
input_nc (int) -- the number of channels in input images
|
26 |
+
ndf (int) -- the number of filters in the last conv layer
|
27 |
+
n_layers (int) -- the number of conv layers in the discriminator
|
28 |
+
norm_layer -- normalization layer
|
29 |
+
"""
|
30 |
+
super(NLayerDiscriminator, self).__init__()
|
31 |
+
if not use_actnorm:
|
32 |
+
norm_layer = nn.BatchNorm2d
|
33 |
+
else:
|
34 |
+
norm_layer = ActNorm
|
35 |
+
if (
|
36 |
+
type(norm_layer) == functools.partial
|
37 |
+
): # no need to use bias as BatchNorm2d has affine parameters
|
38 |
+
use_bias = norm_layer.func != nn.BatchNorm2d
|
39 |
+
else:
|
40 |
+
use_bias = norm_layer != nn.BatchNorm2d
|
41 |
+
|
42 |
+
kw = 4
|
43 |
+
padw = 1
|
44 |
+
sequence = [
|
45 |
+
nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
|
46 |
+
nn.LeakyReLU(0.2, True),
|
47 |
+
]
|
48 |
+
nf_mult = 1
|
49 |
+
nf_mult_prev = 1
|
50 |
+
for n in range(1, n_layers): # gradually increase the number of filters
|
51 |
+
nf_mult_prev = nf_mult
|
52 |
+
nf_mult = min(2**n, 8)
|
53 |
+
sequence += [
|
54 |
+
nn.Conv2d(
|
55 |
+
ndf * nf_mult_prev,
|
56 |
+
ndf * nf_mult,
|
57 |
+
kernel_size=kw,
|
58 |
+
stride=2,
|
59 |
+
padding=padw,
|
60 |
+
bias=use_bias,
|
61 |
+
),
|
62 |
+
norm_layer(ndf * nf_mult),
|
63 |
+
nn.LeakyReLU(0.2, True),
|
64 |
+
]
|
65 |
+
|
66 |
+
nf_mult_prev = nf_mult
|
67 |
+
nf_mult = min(2**n_layers, 8)
|
68 |
+
sequence += [
|
69 |
+
nn.Conv2d(
|
70 |
+
ndf * nf_mult_prev,
|
71 |
+
ndf * nf_mult,
|
72 |
+
kernel_size=kw,
|
73 |
+
stride=1,
|
74 |
+
padding=padw,
|
75 |
+
bias=use_bias,
|
76 |
+
),
|
77 |
+
norm_layer(ndf * nf_mult),
|
78 |
+
nn.LeakyReLU(0.2, True),
|
79 |
+
]
|
80 |
+
|
81 |
+
sequence += [
|
82 |
+
nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)
|
83 |
+
] # output 1 channel prediction map
|
84 |
+
self.main = nn.Sequential(*sequence)
|
85 |
+
|
86 |
+
def forward(self, input):
|
87 |
+
"""Standard forward."""
|
88 |
+
return self.main(input)
|
sgm/modules/autoencoding/lpips/util.py
ADDED
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import hashlib
|
2 |
+
import os
|
3 |
+
|
4 |
+
import requests
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
from tqdm import tqdm
|
8 |
+
|
9 |
+
URL_MAP = {"vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"}
|
10 |
+
|
11 |
+
CKPT_MAP = {"vgg_lpips": "vgg.pth"}
|
12 |
+
|
13 |
+
MD5_MAP = {"vgg_lpips": "d507d7349b931f0638a25a48a722f98a"}
|
14 |
+
|
15 |
+
|
16 |
+
def download(url, local_path, chunk_size=1024):
|
17 |
+
os.makedirs(os.path.split(local_path)[0], exist_ok=True)
|
18 |
+
with requests.get(url, stream=True) as r:
|
19 |
+
total_size = int(r.headers.get("content-length", 0))
|
20 |
+
with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
|
21 |
+
with open(local_path, "wb") as f:
|
22 |
+
for data in r.iter_content(chunk_size=chunk_size):
|
23 |
+
if data:
|
24 |
+
f.write(data)
|
25 |
+
pbar.update(chunk_size)
|
26 |
+
|
27 |
+
|
28 |
+
def md5_hash(path):
|
29 |
+
with open(path, "rb") as f:
|
30 |
+
content = f.read()
|
31 |
+
return hashlib.md5(content).hexdigest()
|
32 |
+
|
33 |
+
|
34 |
+
def get_ckpt_path(name, root, check=False):
|
35 |
+
assert name in URL_MAP
|
36 |
+
path = os.path.join(root, CKPT_MAP[name])
|
37 |
+
if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]):
|
38 |
+
print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path))
|
39 |
+
download(URL_MAP[name], path)
|
40 |
+
md5 = md5_hash(path)
|
41 |
+
assert md5 == MD5_MAP[name], md5
|
42 |
+
return path
|
43 |
+
|
44 |
+
|
45 |
+
class ActNorm(nn.Module):
|
46 |
+
def __init__(
|
47 |
+
self, num_features, logdet=False, affine=True, allow_reverse_init=False
|
48 |
+
):
|
49 |
+
assert affine
|
50 |
+
super().__init__()
|
51 |
+
self.logdet = logdet
|
52 |
+
self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1))
|
53 |
+
self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1))
|
54 |
+
self.allow_reverse_init = allow_reverse_init
|
55 |
+
|
56 |
+
self.register_buffer("initialized", torch.tensor(0, dtype=torch.uint8))
|
57 |
+
|
58 |
+
def initialize(self, input):
|
59 |
+
with torch.no_grad():
|
60 |
+
flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1)
|
61 |
+
mean = (
|
62 |
+
flatten.mean(1)
|
63 |
+
.unsqueeze(1)
|
64 |
+
.unsqueeze(2)
|
65 |
+
.unsqueeze(3)
|
66 |
+
.permute(1, 0, 2, 3)
|
67 |
+
)
|
68 |
+
std = (
|
69 |
+
flatten.std(1)
|
70 |
+
.unsqueeze(1)
|
71 |
+
.unsqueeze(2)
|
72 |
+
.unsqueeze(3)
|
73 |
+
.permute(1, 0, 2, 3)
|
74 |
+
)
|
75 |
+
|
76 |
+
self.loc.data.copy_(-mean)
|
77 |
+
self.scale.data.copy_(1 / (std + 1e-6))
|
78 |
+
|
79 |
+
def forward(self, input, reverse=False):
|
80 |
+
if reverse:
|
81 |
+
return self.reverse(input)
|
82 |
+
if len(input.shape) == 2:
|
83 |
+
input = input[:, :, None, None]
|
84 |
+
squeeze = True
|
85 |
+
else:
|
86 |
+
squeeze = False
|
87 |
+
|
88 |
+
_, _, height, width = input.shape
|
89 |
+
|
90 |
+
if self.training and self.initialized.item() == 0:
|
91 |
+
self.initialize(input)
|
92 |
+
self.initialized.fill_(1)
|
93 |
+
|
94 |
+
h = self.scale * (input + self.loc)
|
95 |
+
|
96 |
+
if squeeze:
|
97 |
+
h = h.squeeze(-1).squeeze(-1)
|
98 |
+
|
99 |
+
if self.logdet:
|
100 |
+
log_abs = torch.log(torch.abs(self.scale))
|
101 |
+
logdet = height * width * torch.sum(log_abs)
|
102 |
+
logdet = logdet * torch.ones(input.shape[0]).to(input)
|
103 |
+
return h, logdet
|
104 |
+
|
105 |
+
return h
|
106 |
+
|
107 |
+
def reverse(self, output):
|
108 |
+
if self.training and self.initialized.item() == 0:
|
109 |
+
if not self.allow_reverse_init:
|
110 |
+
raise RuntimeError(
|
111 |
+
"Initializing ActNorm in reverse direction is "
|
112 |
+
"disabled by default. Use allow_reverse_init=True to enable."
|
113 |
+
)
|
114 |
+
else:
|
115 |
+
self.initialize(output)
|
116 |
+
self.initialized.fill_(1)
|
117 |
+
|
118 |
+
if len(output.shape) == 2:
|
119 |
+
output = output[:, :, None, None]
|
120 |
+
squeeze = True
|
121 |
+
else:
|
122 |
+
squeeze = False
|
123 |
+
|
124 |
+
h = output / self.scale - self.loc
|
125 |
+
|
126 |
+
if squeeze:
|
127 |
+
h = h.squeeze(-1).squeeze(-1)
|
128 |
+
return h
|
sgm/modules/autoencoding/lpips/vqperceptual.py
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn.functional as F
|
3 |
+
|
4 |
+
|
5 |
+
def hinge_d_loss(logits_real, logits_fake):
|
6 |
+
loss_real = torch.mean(F.relu(1.0 - logits_real))
|
7 |
+
loss_fake = torch.mean(F.relu(1.0 + logits_fake))
|
8 |
+
d_loss = 0.5 * (loss_real + loss_fake)
|
9 |
+
return d_loss
|
10 |
+
|
11 |
+
|
12 |
+
def vanilla_d_loss(logits_real, logits_fake):
|
13 |
+
d_loss = 0.5 * (
|
14 |
+
torch.mean(torch.nn.functional.softplus(-logits_real))
|
15 |
+
+ torch.mean(torch.nn.functional.softplus(logits_fake))
|
16 |
+
)
|
17 |
+
return d_loss
|
sgm/modules/autoencoding/regularizers/__init__.py
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from abc import abstractmethod
|
2 |
+
from typing import Any, Tuple
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import torch.nn.functional as F
|
7 |
+
|
8 |
+
from ....modules.distributions.distributions import \
|
9 |
+
DiagonalGaussianDistribution
|
10 |
+
from .base import AbstractRegularizer
|
11 |
+
|
12 |
+
|
13 |
+
class DiagonalGaussianRegularizer(AbstractRegularizer):
|
14 |
+
def __init__(self, sample: bool = True):
|
15 |
+
super().__init__()
|
16 |
+
self.sample = sample
|
17 |
+
|
18 |
+
def get_trainable_parameters(self) -> Any:
|
19 |
+
yield from ()
|
20 |
+
|
21 |
+
def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
|
22 |
+
log = dict()
|
23 |
+
posterior = DiagonalGaussianDistribution(z)
|
24 |
+
if self.sample:
|
25 |
+
z = posterior.sample()
|
26 |
+
else:
|
27 |
+
z = posterior.mode()
|
28 |
+
kl_loss = posterior.kl()
|
29 |
+
kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
|
30 |
+
log["kl_loss"] = kl_loss
|
31 |
+
return z, log
|
sgm/modules/autoencoding/regularizers/base.py
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from abc import abstractmethod
|
2 |
+
from typing import Any, Tuple
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn.functional as F
|
6 |
+
from torch import nn
|
7 |
+
|
8 |
+
|
9 |
+
class AbstractRegularizer(nn.Module):
|
10 |
+
def __init__(self):
|
11 |
+
super().__init__()
|
12 |
+
|
13 |
+
def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
|
14 |
+
raise NotImplementedError()
|
15 |
+
|
16 |
+
@abstractmethod
|
17 |
+
def get_trainable_parameters(self) -> Any:
|
18 |
+
raise NotImplementedError()
|
19 |
+
|
20 |
+
|
21 |
+
class IdentityRegularizer(AbstractRegularizer):
|
22 |
+
def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
|
23 |
+
return z, dict()
|
24 |
+
|
25 |
+
def get_trainable_parameters(self) -> Any:
|
26 |
+
yield from ()
|
27 |
+
|
28 |
+
|
29 |
+
def measure_perplexity(
|
30 |
+
predicted_indices: torch.Tensor, num_centroids: int
|
31 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
32 |
+
# src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py
|
33 |
+
# eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally
|
34 |
+
encodings = (
|
35 |
+
F.one_hot(predicted_indices, num_centroids).float().reshape(-1, num_centroids)
|
36 |
+
)
|
37 |
+
avg_probs = encodings.mean(0)
|
38 |
+
perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp()
|
39 |
+
cluster_use = torch.sum(avg_probs > 0)
|
40 |
+
return perplexity, cluster_use
|