slothfulxtx committed on
Commit
ca2145e
·
1 Parent(s): 774e213
.gitignore ADDED
@@ -0,0 +1,178 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # poetry
98
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102
+ #poetry.lock
103
+
104
+ # pdm
105
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106
+ #pdm.lock
107
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108
+ # in version control.
109
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
110
+ .pdm.toml
111
+ .pdm-python
112
+ .pdm-build/
113
+
114
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
115
+ __pypackages__/
116
+
117
+ # Celery stuff
118
+ celerybeat-schedule
119
+ celerybeat.pid
120
+
121
+ # SageMath parsed files
122
+ *.sage.py
123
+
124
+ # Environments
125
+ .env
126
+ .venv
127
+ env/
128
+ venv/
129
+ ENV/
130
+ env.bak/
131
+ venv.bak/
132
+
133
+ # Spyder project settings
134
+ .spyderproject
135
+ .spyproject
136
+
137
+ # Rope project settings
138
+ .ropeproject
139
+
140
+ # mkdocs documentation
141
+ /site
142
+
143
+ # mypy
144
+ .mypy_cache/
145
+ .dmypy.json
146
+ dmypy.json
147
+
148
+ # Pyre type checker
149
+ .pyre/
150
+
151
+ # pytype static type analyzer
152
+ .pytype/
153
+
154
+ # Cython debug symbols
155
+ cython_debug/
156
+
157
+ # PyCharm
158
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
159
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
160
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
161
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
162
+ #.idea/
163
+
164
+ # Vscode settings
165
+ .vscode/
166
+
167
+ # Temporal files
168
+ /tmp
169
+ tmp*
170
+
171
+ # Workspace
172
+ /workspace
173
+
174
+ # running scripts
175
+ /*.sh
176
+
177
+ # pretrained
178
+ /pretrained_models
.gitmodules ADDED
@@ -0,0 +1,3 @@
1
+ [submodule "third_party/moge"]
2
+ path = third_party/moge
3
+ url = https://github.com/microsoft/MoGe.git
app.py ADDED
@@ -0,0 +1,436 @@
1
+ import gc
2
+ import os
3
+ import uuid
4
+ from pathlib import Path
5
+
6
+ import numpy as np
7
+ import spaces
8
+ import gradio as gr
9
+ import torch
10
+ from decord import cpu, VideoReader
11
+ from diffusers.training_utils import set_seed
12
+ import torch.nn.functional as F
13
+ import imageio
14
+ from kornia.filters import canny
15
+ from kornia.morphology import dilation
16
+
17
+ from third_party import MoGe
18
+ from geometrycrafter import (
19
+ GeometryCrafterDiffPipeline,
20
+ GeometryCrafterDetermPipeline,
21
+ PMapAutoencoderKLTemporalDecoder,
22
+ UNetSpatioTemporalConditionModelVid2vid
23
+ )
24
+
25
+ from utils.glb_utils import pmap_to_glb
26
+ from utils.disp_utils import pmap_to_disp
27
+
28
+ examples = [
29
+ # process_length: int,
30
+ # max_res: int,
31
+ # num_inference_steps: int,
32
+ # guidance_scale: float,
33
+ # window_size: int,
34
+ # decode_chunk_size: int,
35
+ # overlap: int,
36
+ ["examples/video1.mp4", 60, 640, 5, 1.0, 110, 8, 25],
37
+ ["examples/video2.mp4", 60, 640, 5, 1.0, 110, 8, 25],
38
+ ["examples/video3.mp4", 60, 640, 5, 1.0, 110, 8, 25],
39
+ ["examples/video4.mp4", 60, 640, 5, 1.0, 110, 8, 25],
40
+ ]
41
+
42
+ model_type = 'diff'
43
+ cache_dir = 'workspace/cache'
44
+
45
+ unet = UNetSpatioTemporalConditionModelVid2vid.from_pretrained(
46
+ 'TencentARC/GeometryCrafter',
47
+ subfolder='unet_diff' if model_type == 'diff' else 'unet_determ',
48
+ low_cpu_mem_usage=True,
49
+ torch_dtype=torch.float16,
50
+ cache_dir=cache_dir
51
+ ).requires_grad_(False).to("cuda", dtype=torch.float16)
52
+ point_map_vae = PMapAutoencoderKLTemporalDecoder.from_pretrained(
53
+ 'TencentARC/GeometryCrafter',
54
+ subfolder='point_map_vae',
55
+ low_cpu_mem_usage=True,
56
+ torch_dtype=torch.float32,
57
+ cache_dir=cache_dir
58
+ ).requires_grad_(False).to("cuda", dtype=torch.float32)
59
+ prior_model = MoGe(
60
+ cache_dir=cache_dir,
61
+ ).requires_grad_(False).to('cuda', dtype=torch.float32)
62
+ if model_type == 'diff':
63
+ pipe = GeometryCrafterDiffPipeline.from_pretrained(
64
+ 'stabilityai/stable-video-diffusion-img2vid-xt',
65
+ unet=unet,
66
+ torch_dtype=torch.float16,
67
+ variant="fp16",
68
+ cache_dir=cache_dir
69
+ ).to("cuda")
70
+ else:
71
+ pipe = GeometryCrafterDetermPipeline.from_pretrained(
72
+ 'stabilityai/stable-video-diffusion-img2vid-xt',
73
+ unet=unet,
74
+ torch_dtype=torch.float16,
75
+ variant="fp16",
76
+ cache_dir=cache_dir
77
+ ).to("cuda")
78
+
79
+ try:
80
+ pipe.enable_xformers_memory_efficient_attention()
81
+ except Exception as e:
82
+ print(e)
83
+ print("Xformers is not enabled")
84
+ # bugs at https://github.com/continue-revolution/sd-webui-animatediff/issues/101
85
+ # pipe.enable_xformers_memory_efficient_attention()
86
+ pipe.enable_attention_slicing()
87
+
88
+ mesh_seqs = []
89
+ frame_seqs = []
90
+ cur_mesh_idx = None
91
+
92
+ def read_video_frames(video_path, process_length, max_res):
93
+ print("==> processing video: ", video_path)
94
+ vid = VideoReader(video_path, ctx=cpu(0))
95
+ fps = vid.get_avg_fps()
96
+ print("==> original video shape: ", (len(vid), *vid.get_batch([0]).shape[1:]))
97
+ original_height, original_width = vid.get_batch([0]).shape[1:3]
98
+ if max(original_height, original_width) > max_res:
99
+ scale = max_res / max(original_height, original_width)
100
+ original_height, original_width = round(original_height * scale), round(original_width * scale)
101
+ else:
102
+ scale = 1.0
103
+ height = round(original_height / 64) * 64 # original_height/width already carry the scale factor
104
+ width = round(original_width / 64) * 64
105
+ vid = VideoReader(video_path, ctx=cpu(0), width=original_width, height=original_height)
106
+ frames_idx = list(range(0, min(len(vid), process_length) if process_length != -1 else len(vid)))
107
+ print(
108
+ f"==> final processing shape: {len(frames_idx), *vid.get_batch([0]).shape[1:]}"
109
+ )
110
+ frames = vid.get_batch(frames_idx).asnumpy().astype("float32") / 255.0
111
+ return frames, height, width, fps
112
+
113
+
114
+ def compute_edge_mask(depth: torch.Tensor, edge_dilation_radius: int):
115
+ magnitude, edges = canny(depth[None, None, :, :], low_threshold=0.4, high_threshold=0.5)
116
+ magnitude = magnitude[0, 0]
117
+ edges = edges[0, 0]
118
+ mask = (edges > 0).float()
119
+ mask = dilation(mask[None, None, :, :], torch.ones((edge_dilation_radius,edge_dilation_radius), device=mask.device))
120
+ return mask[0, 0] > 0.5
121
+
122
+ @spaces.GPU(duration=120)
123
+ @torch.inference_mode()
124
+ def infer_geometry(
125
+ video: str,
126
+ process_length: int,
127
+ max_res: int,
128
+ num_inference_steps: int,
129
+ guidance_scale: float,
130
+ window_size: int,
131
+ decode_chunk_size: int,
132
+ overlap: int,
133
+ downsample_ratio: float = 1.0, # downsample pcd for visualization
134
+ num_sample_frames: int =8, # downsample frames for visualization
135
+ remove_edge: bool = True, # remove edge for visualization
136
+ save_folder: str = os.path.join('workspace', 'GeometryCrafterApp'),
137
+ ):
138
+ try:
139
+ global cur_mesh_idx, mesh_seqs, frame_seqs
140
+ run_id = str(uuid.uuid4())
141
+ set_seed(42)
142
+ pipe.enable_xformers_memory_efficient_attention()
143
+
144
+ frames, height, width, fps = read_video_frames(video, process_length, max_res)
145
+ aspect_ratio = width / height
146
+ assert 0.5 <= aspect_ratio <= 2.0
147
+ frames_tensor = torch.tensor(frames.astype("float32"), device='cuda').float().permute(0, 3, 1, 2)
148
+ window_size = min(window_size, len(frames))
149
+ if window_size == len(frames):
150
+ overlap = 0
151
+
152
+ point_maps, valid_masks = pipe(
153
+ frames_tensor,
154
+ point_map_vae,
155
+ prior_model,
156
+ height=height,
157
+ width=width,
158
+ num_inference_steps=num_inference_steps,
159
+ guidance_scale=guidance_scale,
160
+ window_size=window_size,
161
+ decode_chunk_size=decode_chunk_size,
162
+ overlap=overlap,
163
+ force_projection=True,
164
+ force_fixed_focal=True,
165
+ )
166
+ frames_tensor = frames_tensor.cpu()
167
+ point_maps = point_maps.cpu()
168
+ valid_masks = valid_masks.cpu()
169
+
170
+ gc.collect()
171
+ torch.cuda.empty_cache()
172
+ output_npz_path = Path(save_folder, run_id, 'point_maps.npz')
173
+ output_npz_path.parent.mkdir(parents=True, exist_ok=True)
174
+
175
+
176
+ np.savez_compressed(
177
+ output_npz_path,
178
+ point_map=point_maps.cpu().numpy().astype(np.float16),
179
+ valid_mask=valid_masks.cpu().numpy().astype(np.bool_)
180
+ )
181
+
182
+ output_disp_path = Path(save_folder, run_id, 'disp.mp4')
183
+ output_disp_path.parent.mkdir(exist_ok=True)
184
+
185
+ colored_disp = pmap_to_disp(point_maps, valid_masks)
186
+ imageio.mimsave(
187
+ output_disp_path, (colored_disp*255).cpu().numpy().astype(np.uint8), fps=fps, macro_block_size=1)
188
+
189
+
190
+ # downsample for visualization
191
+ if downsample_ratio > 1.0:
192
+ H, W = point_maps.shape[1:3]
193
+ H, W = round(H / downsample_ratio), round(W / downsample_ratio)
194
+ point_maps = F.interpolate(point_maps.permute(0,3,1,2), (H, W)).permute(0,2,3,1)
195
+ frames = F.interpolate(frames_tensor, (H, W)).permute(0,2,3,1)
196
+ valid_masks = F.interpolate(valid_masks.float()[:, None], (H, W))[:, 0] > 0.5
197
+ else:
198
+ H, W = point_maps.shape[1:3]
199
+ frames = frames_tensor.permute(0,2,3,1)
200
+
201
+
202
+ if remove_edge:
203
+ for i in range(len(valid_masks)):
204
+ edge_mask = compute_edge_mask(point_maps[i, :, :, 2], 3)
205
+ valid_masks[i] = valid_masks[i] & (~edge_mask)
206
+
207
+ indices = np.linspace(0, len(point_maps)-1, num_sample_frames)
208
+ indices = np.round(indices).astype(np.int32)
209
+
210
+ mesh_seqs.clear()
+ frame_seqs.clear()
211
+ cur_mesh_idx = None
212
+
213
+ for index in indices:
214
+
215
+ valid_mask = valid_masks[index].cpu().numpy()
216
+ point_map = point_maps[index].cpu().numpy()
217
+ frame = frames[index].cpu().numpy()
218
+ output_glb_path = Path(save_folder, run_id, f'{index:04}.glb')
219
+ output_glb_path.parent.mkdir(exist_ok=True)
220
+ glbscene = pmap_to_glb(point_map, valid_mask, frame)
221
+ glbscene.export(file_obj=output_glb_path)
222
+ mesh_seqs.append(output_glb_path)
223
+ frame_seqs.append(index)
224
+
225
+ cur_mesh_idx = 0
226
+
227
+ gc.collect()
228
+ torch.cuda.empty_cache()
229
+
230
+ return [
231
+ gr.Model3D(value=mesh_seqs[cur_mesh_idx], label=f"Frame: {frame_seqs[cur_mesh_idx]}"),
232
+ gr.Video(value=output_disp_path, label="Disparity", interactive=False),
233
+ gr.DownloadButton("Download Npz File", value=output_npz_path, visible=True)
234
+ ]
235
+ except Exception as e:
236
+ mesh_seqs.clear()
237
+ frame_seqs.clear()
238
+ cur_mesh_idx = None
239
+ gc.collect()
240
+ torch.cuda.empty_cache()
241
+ raise gr.Error(str(e))
242
+ # return [
243
+ # gr.Model3D(
244
+ # label="Point Map",
245
+ # clear_color=[1.0, 1.0, 1.0, 1.0],
246
+ # interactive=False
247
+ # ),
248
+ # gr.Video(label="Disparity", interactive=False),
249
+ # gr.DownloadButton("Download Npz File", visible=False)
250
+ # ]
251
+
252
+ def goto_prev_frame():
253
+ global cur_mesh_idx, mesh_seqs, frame_seqs
254
+ if cur_mesh_idx is not None and len(mesh_seqs) > 0:
255
+ if cur_mesh_idx > 0:
256
+ cur_mesh_idx -= 1
257
+ return gr.Model3D(value=mesh_seqs[cur_mesh_idx], label=f"Frame: {frame_seqs[cur_mesh_idx]}")
258
+
259
+
260
+ def goto_next_frame():
261
+ global cur_mesh_idx, mesh_seqs, frame_seqs
262
+ if cur_mesh_idx is not None and len(mesh_seqs) > 0:
263
+ if cur_mesh_idx < len(mesh_seqs)-1:
264
+ cur_mesh_idx += 1
265
+ return gr.Model3D(value=mesh_seqs[cur_mesh_idx], label=f"Frame: {frame_seqs[cur_mesh_idx]}")
266
+
267
+ def download_file():
268
+ return gr.DownloadButton(visible=False)
269
+
270
+ def build_demo():
271
+ with gr.Blocks(analytics_enabled=False) as gradio_demo:
272
+ gr.Markdown(
273
+ """
274
+ <div align='center'>
275
+ <h1> GeometryCrafter: Consistent Geometry Estimation for Open-world Videos with Diffusion Priors </h1> \
276
+ <h2 style='font-weight: 450; font-size: 1rem; margin: 0rem'>\
277
+ <a href='https://scholar.google.com/citations?user=zHp0rMIAAAAJ'>Tian-Xing Xu</a>, \
278
+ <a href='https://scholar.google.com/citations?user=qgdesEcAAAAJ'>Xiangjun Gao</a>, \
279
+ <a href='https://wbhu.github.io'>Wenbo Hu</a>, \
280
+ <a href='https://xiaoyu258.github.io/'>Xiaoyu Li</a>, \
281
+ <a href='https://scholar.google.com/citations?user=AWtV-EQAAAAJ'>Song-Hai Zhang</a>,\
282
+ <a href='https://scholar.google.com/citations?user=4oXBp9UAAAAJ'>Ying Shan</a>\
283
+ </h2> \
284
+ <span style='font-size:18px'>If you find GeometryCrafter useful, please help ⭐ the \
285
+ <a style='font-size:18px' href='https://github.com/TencentARC/GeometryCrafter/'>[Github Repo]</a>\
286
+ , which helps support open-source projects. Thanks!\
287
+ <a style='font-size:18px' href='https://arxiv.org'> [ArXivTODO] </a>\
288
+ <a style='font-size:18px' href='https://geometrycrafter.github.io'> [Project Page] </a>
289
+ </span>
290
+ </div>
291
+ """
292
+ )
293
+
294
+ with gr.Row(equal_height=True):
295
+ with gr.Column(scale=1):
296
+ input_video = gr.Video(
297
+ label="Input Video",
298
+ sources=['upload']
299
+ )
300
+ with gr.Row(equal_height=False):
301
+ with gr.Accordion("Advanced Settings", open=False):
302
+ process_length = gr.Slider(
303
+ label="process length",
304
+ minimum=-1,
305
+ maximum=280,
306
+ value=110,
307
+ step=1,
308
+ )
309
+ max_res = gr.Slider(
310
+ label="max resolution",
311
+ minimum=512,
312
+ maximum=2048,
313
+ value=1024,
314
+ step=64,
315
+ )
316
+ num_denoising_steps = gr.Slider(
317
+ label="num denoising steps",
318
+ minimum=1,
319
+ maximum=25,
320
+ value=5,
321
+ step=1,
322
+ )
323
+ guidance_scale = gr.Slider(
324
+ label="cfg scale",
325
+ minimum=1.0,
326
+ maximum=1.2,
327
+ value=1.0,
328
+ step=0.1,
329
+ )
330
+ window_size = gr.Slider(
331
+ label="shift window size",
332
+ minimum=10,
333
+ maximum=110,
334
+ value=110,
335
+ step=10,
336
+ )
337
+ decode_chunk_size = gr.Slider(
338
+ label="decode chunk size",
339
+ minimum=1,
340
+ maximum=16,
341
+ value=6,
342
+ step=1,
343
+ )
344
+ overlap = gr.Slider(
345
+ label="overlap",
346
+ minimum=1,
347
+ maximum=50,
348
+ value=25,
349
+ step=1,
350
+ )
351
+ generate_btn = gr.Button("Generate")
352
+
353
+ with gr.Column(scale=1):
354
+ output_point_maps = gr.Model3D(
355
+ label="Point Map",
356
+ clear_color=[1.0, 1.0, 1.0, 1.0],
357
+ # display_mode="solid"
358
+ interactive=False
359
+ )
360
+ with gr.Row():
361
+ prev_btn = gr.Button("Prev")
362
+ next_btn = gr.Button("Next")
363
+
364
+ with gr.Column(scale=1):
365
+ output_disp_video = gr.Video(
366
+ label="Disparity",
367
+ interactive=False
368
+ )
369
+ download_btn = gr.DownloadButton("Download Npz File", visible=False)
370
+
371
+ gr.Examples(
372
+ examples=examples,
373
+ fn=infer_geometry,
374
+ inputs=[
375
+ input_video,
376
+ process_length,
377
+ max_res,
378
+ num_denoising_steps,
379
+ guidance_scale,
380
+ window_size,
381
+ decode_chunk_size,
382
+ overlap,
383
+ ],
384
+ outputs=[output_point_maps, output_disp_video, download_btn],
385
+ # cache_examples="lazy",
386
+ )
387
+ gr.Markdown(
388
+ """
389
+ <span style='font-size:18px'>Note:
390
+ To stay within the GPU time quota, the default parameters here favor efficiency,
391
+ at the cost of shorter video length and slightly lower quality.
392
+ If you have enough quota, you may adjust the parameters following our
393
+ <a style='font-size:18px' href='https://github.com/TencentARC/GeometryCrafter/'>[Github Repo]</a>
394
+ for better results. We only provide a simplified visualization
395
+ script on this page, since point cloud sequences are not supported here. You can download
396
+ the npz file and open it with the Viser backend in our repo for better visualization.
397
+ </span>
398
+ """
399
+ )
400
+
401
+ generate_btn.click(
402
+ fn=infer_geometry,
403
+ inputs=[
404
+ input_video,
405
+ process_length,
406
+ max_res,
407
+ num_denoising_steps,
408
+ guidance_scale,
409
+ window_size,
410
+ decode_chunk_size,
411
+ overlap,
412
+ ],
413
+ outputs=[output_point_maps, output_disp_video, download_btn],
414
+ )
415
+
416
+ prev_btn.click(
417
+ fn=goto_prev_frame,
418
+ outputs=output_point_maps,
419
+ )
420
+ next_btn.click(
421
+ fn=goto_next_frame,
422
+ outputs=output_point_maps,
423
+ )
424
+ download_btn.click(
425
+ fn=download_file,
426
+ outputs=download_btn
427
+ )
428
+
429
+ return gradio_demo
430
+
431
+
432
+ if __name__ == "__main__":
433
+ demo = build_demo()
434
+ demo.queue()
435
+ demo.launch(server_name="0.0.0.0", server_port=12345, debug=True, share=False)
436
+ # demo.launch(share=True)
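Note on the saved outputs: infer_geometry above writes the results to point_maps.npz via np.savez_compressed, with keys 'point_map' (float16, shape T,H,W,3) and 'valid_mask' (bool, shape T,H,W), and suggests opening them with the Viser backend in the repo. A minimal sketch, not part of this commit, of loading that archive offline for inspection; the <run_id> placeholder is hypothetical and stands for the per-run UUID folder created by the app:

import numpy as np

# Keys follow the np.savez_compressed call in infer_geometry above.
data = np.load("workspace/GeometryCrafterApp/<run_id>/point_maps.npz")  # <run_id>: hypothetical placeholder
point_map = data["point_map"].astype(np.float32)  # (T, H, W, 3), stored as float16
valid_mask = data["valid_mask"]                   # (T, H, W), boolean
depth = point_map[..., 2]                         # z channel holds per-pixel depth
print(point_map.shape, valid_mask.shape, float(depth[valid_mask].mean()))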
geometrycrafter/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ from .pmap_vae import PMapAutoencoderKLTemporalDecoder
2
+ from .unet import UNetSpatioTemporalConditionModelVid2vid
3
+ from .diff_ppl import GeometryCrafterDiffPipeline
4
+ from .determ_ppl import GeometryCrafterDetermPipeline
geometrycrafter/determ_ppl.py ADDED
@@ -0,0 +1,453 @@
1
+ from typing import Optional, Union
2
+ import gc
3
+
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn.functional as F
7
+
8
+ from diffusers.pipelines.stable_video_diffusion.pipeline_stable_video_diffusion import (
9
+ _resize_with_antialiasing,
10
+ StableVideoDiffusionPipeline,
11
+ )
12
+ from diffusers.utils import logging
13
+ from kornia.utils import create_meshgrid
14
+ from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
15
+
16
+
17
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
18
+
19
+ @torch.no_grad()
20
+ def normalize_point_map(point_map, valid_mask):
21
+ # T,H,W,3 T,H,W
22
+ norm_factor = (point_map[..., 2] * valid_mask.float()).mean() / (valid_mask.float().mean() + 1e-8)
23
+ norm_factor = norm_factor.clip(min=1e-3)
24
+ return point_map / norm_factor
25
+
26
+ def point_map_xy2intrinsic_map(point_map_xy):
27
+ # *,h,w,2
28
+ height, width = point_map_xy.shape[-3], point_map_xy.shape[-2]
29
+ assert height % 2 == 0
30
+ assert width % 2 == 0
31
+ mesh_grid = create_meshgrid(
32
+ height=height,
33
+ width=width,
34
+ normalized_coordinates=True,
35
+ device=point_map_xy.device,
36
+ dtype=point_map_xy.dtype
37
+ )[0] # h,w,2
38
+ assert mesh_grid.abs().min() > 1e-4
39
+ # *,h,w,2
40
+ mesh_grid = mesh_grid.expand_as(point_map_xy)
41
+ nc = point_map_xy.mean(dim=-2).mean(dim=-2) # *, 2
42
+ nc_map = nc[..., None, None, :].expand_as(point_map_xy)
43
+ nf = ((point_map_xy - nc_map) / mesh_grid).mean(dim=-2).mean(dim=-2)
44
+ nf_map = nf[..., None, None, :].expand_as(point_map_xy)
45
+ # print((mesh_grid * nf_map + nc_map - point_map_xy).abs().max())
46
+
47
+ return torch.cat([nc_map, nf_map], dim=-1)
48
+
49
+ def robust_min_max(tensor, quantile=0.99):
50
+ T, H, W = tensor.shape
51
+ min_vals = []
52
+ max_vals = []
53
+ for i in range(T):
54
+ min_vals.append(torch.quantile(tensor[i], q=1-quantile, interpolation='nearest').item())
55
+ max_vals.append(torch.quantile(tensor[i], q=quantile, interpolation='nearest').item())
56
+ return min(min_vals), max(max_vals)
57
+
58
+ class GeometryCrafterDetermPipeline(StableVideoDiffusionPipeline):
59
+
60
+ @torch.inference_mode()
61
+ def encode_video(
62
+ self,
63
+ video: torch.Tensor,
64
+ chunk_size: int = 14,
65
+ ) -> torch.Tensor:
66
+ """
67
+ :param video: [b, c, h, w] in range [-1, 1]; b may span multiple videos or frames
68
+ :param chunk_size: the chunk size to encode video
69
+ :return: image_embeddings in shape of [b, 1024]
70
+ """
71
+
72
+ video_224 = _resize_with_antialiasing(video.float(), (224, 224))
73
+ video_224 = (video_224 + 1.0) / 2.0 # [-1, 1] -> [0, 1]
74
+ embeddings = []
75
+ for i in range(0, video_224.shape[0], chunk_size):
76
+ emb = self.feature_extractor(
77
+ images=video_224[i : i + chunk_size],
78
+ do_normalize=True,
79
+ do_center_crop=False,
80
+ do_resize=False,
81
+ do_rescale=False,
82
+ return_tensors="pt",
83
+ ).pixel_values.to(video.device, dtype=video.dtype)
84
+ embeddings.append(self.image_encoder(emb).image_embeds) # [b, 1024]
85
+
86
+ embeddings = torch.cat(embeddings, dim=0) # [t, 1024]
87
+ return embeddings
88
+
89
+ @torch.inference_mode()
90
+ def encode_vae_video(
91
+ self,
92
+ video: torch.Tensor,
93
+ chunk_size: int = 14,
94
+ ):
95
+ """
96
+ :param video: [b, c, h, w] in range [-1, 1]; b may span multiple videos or frames
97
+ :param chunk_size: the chunk size to encode video
98
+ :return: vae latents in shape of [b, c, h, w]
99
+ """
100
+ video_latents = []
101
+ for i in range(0, video.shape[0], chunk_size):
102
+ video_latents.append(
103
+ self.vae.encode(video[i : i + chunk_size]).latent_dist.mode()
104
+ )
105
+ video_latents = torch.cat(video_latents, dim=0)
106
+ return video_latents
107
+
108
+
109
+ @torch.inference_mode()
110
+ def produce_priors(self, prior_model, frame, chunk_size=8):
111
+ T, _, H, W = frame.shape
112
+ frame = (frame + 1) / 2
113
+ pred_point_maps = []
114
+ pred_masks = []
115
+ for i in range(0, len(frame), chunk_size):
116
+ pred_p, pred_m = prior_model.forward_image(frame[i:i+chunk_size])
117
+ pred_point_maps.append(pred_p)
118
+ pred_masks.append(pred_m)
119
+ pred_point_maps = torch.cat(pred_point_maps, dim=0)
120
+ pred_masks = torch.cat(pred_masks, dim=0)
121
+
122
+ pred_masks = pred_masks.float() * 2 - 1
123
+
124
+ # T,H,W,3 T,H,W
125
+ pred_point_maps = normalize_point_map(pred_point_maps, pred_masks > 0)
126
+
127
+ pred_disps = 1.0 / pred_point_maps[..., 2].clamp_min(1e-3)
128
+ pred_disps = pred_disps * (pred_masks > 0)
129
+ min_disparity, max_disparity = robust_min_max(pred_disps)
130
+ pred_disps = ((pred_disps - min_disparity) / (max_disparity - min_disparity+1e-4)).clamp(0, 1)
131
+ pred_disps = pred_disps * 2 - 1
132
+
133
+ pred_point_maps[..., :2] = pred_point_maps[..., :2] / (pred_point_maps[..., 2:3] + 1e-7)
134
+ pred_point_maps[..., 2] = torch.log(pred_point_maps[..., 2] + 1e-7) * (pred_masks > 0) # [x/z, y/z, log(z)]
135
+
136
+ pred_intr_maps = point_map_xy2intrinsic_map(pred_point_maps[..., :2]).permute(0,3,1,2) # T,H,W,2
137
+ pred_point_maps = pred_point_maps.permute(0,3,1,2)
138
+
139
+ return pred_disps, pred_masks, pred_point_maps, pred_intr_maps
140
+
141
+ @torch.inference_mode()
142
+ def encode_point_map(self, point_map_vae, disparity, valid_mask, point_map, intrinsic_map, chunk_size=8):
143
+ T, _, H, W = point_map.shape
144
+ latents = []
145
+
146
+ pseudo_image = disparity[:, None].repeat(1,3,1,1)
147
+ intrinsic_map = torch.norm(intrinsic_map[:, 2:4], p=2, dim=1, keepdim=False)
148
+
149
+ for i in range(0, T, chunk_size):
150
+ latent_dist = self.vae.encode(pseudo_image[i : i + chunk_size].to(self.vae.dtype)).latent_dist
151
+ latent_dist = point_map_vae.encode(
152
+ torch.cat([
153
+ intrinsic_map[i:i+chunk_size, None],
154
+ point_map[i:i+chunk_size, 2:3],
155
+ disparity[i:i+chunk_size, None],
156
+ valid_mask[i:i+chunk_size, None]], dim=1),
157
+ latent_dist
158
+ )
159
+ if isinstance(latent_dist, DiagonalGaussianDistribution):
160
+ latent = latent_dist.mode()
161
+ else:
162
+ latent = latent_dist
163
+
164
+ assert isinstance(latent, torch.Tensor)
165
+ latents.append(latent)
166
+ latents = torch.cat(latents, dim=0)
167
+ latents = latents * self.vae.config.scaling_factor
168
+ return latents
169
+
170
+ @torch.no_grad()
171
+ def decode_point_map(self, point_map_vae, latents, chunk_size=8, force_projection=True, force_fixed_focal=True, use_extract_interp=False, need_resize=False, height=None, width=None):
172
+ T = latents.shape[0]
173
+ rec_intrinsic_maps = []
174
+ rec_depth_maps = []
175
+ rec_valid_masks = []
176
+ for i in range(0, T, chunk_size):
177
+ lat = latents[i:i+chunk_size]
178
+ rec_imap, rec_dmap, rec_vmask = point_map_vae.decode(
179
+ lat,
180
+ num_frames=lat.shape[0],
181
+ )
182
+ rec_intrinsic_maps.append(rec_imap)
183
+ rec_depth_maps.append(rec_dmap)
184
+ rec_valid_masks.append(rec_vmask)
185
+
186
+ rec_intrinsic_maps = torch.cat(rec_intrinsic_maps, dim=0)
187
+ rec_depth_maps = torch.cat(rec_depth_maps, dim=0)
188
+ rec_valid_masks = torch.cat(rec_valid_masks, dim=0)
189
+
190
+ if need_resize:
191
+ rec_depth_maps = F.interpolate(rec_depth_maps, (height, width), mode='nearest-exact') if use_extract_interp else F.interpolate(rec_depth_maps, (height, width), mode='bilinear', align_corners=False)
192
+ rec_valid_masks = F.interpolate(rec_valid_masks, (height, width), mode='nearest-exact') if use_extract_interp else F.interpolate(rec_valid_masks, (height, width), mode='bilinear', align_corners=False)
193
+ rec_intrinsic_maps = F.interpolate(rec_intrinsic_maps, (height, width), mode='bilinear', align_corners=False)
194
+
195
+ H, W = rec_intrinsic_maps.shape[-2], rec_intrinsic_maps.shape[-1]
196
+ mesh_grid = create_meshgrid(
197
+ H, W,
198
+ normalized_coordinates=True
199
+ ).to(rec_intrinsic_maps.device, rec_intrinsic_maps.dtype, non_blocking=True)
200
+ # 1,h,w,2
201
+ rec_intrinsic_maps = torch.cat([rec_intrinsic_maps * W / np.sqrt(W**2+H**2), rec_intrinsic_maps * H / np.sqrt(W**2+H**2)], dim=1) # t,2,h,w
202
+ mesh_grid = mesh_grid.permute(0,3,1,2)
203
+ rec_valid_masks = rec_valid_masks.squeeze(1) > 0
204
+
205
+ if force_projection:
206
+ if force_fixed_focal:
207
+ nfx = (rec_intrinsic_maps[:, 0, :, :] * rec_valid_masks.float()).mean() / (rec_valid_masks.float().mean() + 1e-4)
208
+ nfy = (rec_intrinsic_maps[:, 1, :, :] * rec_valid_masks.float()).mean() / (rec_valid_masks.float().mean() + 1e-4)
209
+ rec_intrinsic_maps = torch.tensor([nfx, nfy], device=rec_intrinsic_maps.device)[None, :, None, None].repeat(T, 1, 1, 1)
210
+ else:
211
+ nfx = (rec_intrinsic_maps[:, 0, :, :] * rec_valid_masks.float()).mean(dim=[-1, -2]) / (rec_valid_masks.float().mean(dim=[-1, -2]) + 1e-4)
212
+ nfy = (rec_intrinsic_maps[:, 1, :, :] * rec_valid_masks.float()).mean(dim=[-1, -2]) / (rec_valid_masks.float().mean(dim=[-1, -2]) + 1e-4)
213
+ rec_intrinsic_maps = torch.stack([nfx, nfy], dim=-1)[:, :, None, None]
214
+ # t,2,1,1
215
+
216
+ rec_point_maps = torch.cat([rec_intrinsic_maps * mesh_grid, rec_depth_maps], dim=1).permute(0,2,3,1)
217
+ xy, z = rec_point_maps.split([2, 1], dim=-1)
218
+ z = torch.clamp_max(z, 10) # for numerical stability
219
+ z = torch.exp(z)
220
+ rec_point_maps = torch.cat([xy * z, z], dim=-1)
221
+
222
+ return rec_point_maps, rec_valid_masks
223
+
224
+
225
+ @torch.no_grad()
226
+ def __call__(
227
+ self,
228
+ video: Union[np.ndarray, torch.Tensor],
229
+ point_map_vae,
230
+ prior_model,
231
+ height: int = 576,
232
+ width: int = 1024,
233
+ window_size: Optional[int] = 14,
234
+ noise_aug_strength: float = 0.02,
235
+ decode_chunk_size: Optional[int] = None,
236
+ overlap: int = 4,
237
+ force_projection: bool = True,
238
+ force_fixed_focal: bool = True,
239
+ use_extract_interp: bool = False,
240
+ track_time: bool = False,
241
+ **kwargs
242
+ ):
243
+ # video: in shape [t, h, w, c] if np.ndarray or [t, c, h, w] if torch.Tensor, in range [0, 1]
244
+
245
+ # 0. Define height and width for preprocessing
246
+
247
+ if isinstance(video, np.ndarray):
248
+ video = torch.from_numpy(video.transpose(0, 3, 1, 2))
249
+ else:
250
+ assert isinstance(video, torch.Tensor)
251
+
252
+ height = height or video.shape[-2]
253
+ width = width or video.shape[-1]
254
+ original_height = video.shape[-2]
255
+ original_width = video.shape[-1]
256
+ num_frames = video.shape[0]
257
+ decode_chunk_size = decode_chunk_size if decode_chunk_size is not None else 8
258
+ if num_frames <= window_size:
259
+ window_size = num_frames
260
+ overlap = 0
261
+ stride = window_size - overlap
262
+
263
+ # 1. Check inputs. Raise error if not correct
264
+ assert height % 64 == 0 and width % 64 == 0
265
+ if original_height != height or original_width != width:
266
+ need_resize = True
267
+ else:
268
+ need_resize = False
269
+
270
+ # 2. Define call parameters
271
+ batch_size = 1
272
+ device = self._execution_device
273
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
274
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
275
+ # corresponds to doing no classifier free guidance.
276
+ self._guidance_scale = 1.0
277
+
278
+ if track_time:
279
+ start_event = torch.cuda.Event(enable_timing=True)
280
+ prior_event = torch.cuda.Event(enable_timing=True)
281
+ encode_event = torch.cuda.Event(enable_timing=True)
282
+ denoise_event = torch.cuda.Event(enable_timing=True)
283
+ decode_event = torch.cuda.Event(enable_timing=True)
284
+ start_event.record()
285
+
286
+ # 3. Compute prior latents under original resolutions
287
+ pred_disparity, pred_valid_mask, pred_point_map, pred_intrinsic_map = self.produce_priors(
288
+ prior_model,
289
+ video.to(device=device, dtype=torch.float32),
290
+ chunk_size=decode_chunk_size
291
+ ) # T,H,W T,H,W T,3,H,W T,2,H,W
292
+
293
+ if need_resize:
294
+ pred_disparity = F.interpolate(pred_disparity.unsqueeze(1), (height, width), mode='bilinear', align_corners=False).squeeze(1)
295
+ pred_valid_mask = F.interpolate(pred_valid_mask.unsqueeze(1), (height, width), mode='bilinear', align_corners=False).squeeze(1)
296
+ pred_point_map = F.interpolate(pred_point_map, (height, width), mode='bilinear', align_corners=False)
297
+ pred_intrinsic_map = F.interpolate(pred_intrinsic_map, (height, width), mode='bilinear', align_corners=False)
298
+
299
+ if track_time:
300
+ prior_event.record()
301
+ torch.cuda.synchronize()
302
+ elapsed_time_ms = start_event.elapsed_time(prior_event)
303
+ print(f"Elapsed time for computing per-frame prior: {elapsed_time_ms} ms")
304
+ else:
305
+ gc.collect()
306
+ torch.cuda.empty_cache()
307
+
308
+
309
+
310
+ # 3. Encode input video
311
+ if need_resize:
312
+ video = F.interpolate(video, (height, width), mode="bicubic", align_corners=False, antialias=True).clamp(0, 1)
313
+
314
+ video = video.to(device=device, dtype=self.dtype)
315
+ video = video * 2.0 - 1.0 # [0,1] -> [-1,1], in [t, c, h, w]
316
+
317
+
318
+ video_embeddings = self.encode_video(video, chunk_size=decode_chunk_size).unsqueeze(0)
319
+
320
+ prior_latents = self.encode_point_map(
321
+ point_map_vae,
322
+ pred_disparity,
323
+ pred_valid_mask,
324
+ pred_point_map,
325
+ pred_intrinsic_map,
326
+ chunk_size=decode_chunk_size
327
+ ).unsqueeze(0).to(video_embeddings.dtype) # 1,T,C,H,W
328
+
329
+
330
+ # 4. Encode input image using VAE
331
+
332
+ # pdb.set_trace()
333
+ needs_upcasting = (
334
+ self.vae.dtype == torch.float16 and self.vae.config.force_upcast
335
+ )
336
+ if needs_upcasting:
337
+ self.vae.to(dtype=torch.float32)
338
+
339
+ video_latents = self.encode_vae_video(
340
+ video.to(self.vae.dtype),
341
+ chunk_size=decode_chunk_size,
342
+ ).unsqueeze(0).to(video_embeddings.dtype) # [1, t, c, h, w]
343
+
344
+
345
+ if track_time:
346
+ encode_event.record()
347
+ torch.cuda.synchronize()
348
+ elapsed_time_ms = prior_event.elapsed_time(encode_event)
349
+ print(f"Elapsed time for encode prior and frames: {elapsed_time_ms} ms")
350
+ else:
351
+ gc.collect()
352
+ torch.cuda.empty_cache()
353
+
354
+ # cast back to fp16 if needed
355
+ if needs_upcasting:
356
+ self.vae.to(dtype=torch.float16)
357
+
358
+ # 5. Get Added Time IDs
359
+ added_time_ids = self._get_add_time_ids(
360
+ 7,
361
+ 127,
362
+ noise_aug_strength,
363
+ video_embeddings.dtype,
364
+ batch_size,
365
+ 1,
366
+ False,
367
+ ) # [1 or 2, 3]
368
+ added_time_ids = added_time_ids.to(device)
369
+
370
+ # 6. Prepare timesteps
371
+ timestep = 1.6378
372
+ self._num_timesteps = 1
373
+
374
+ # 7. Prepare latent variables
375
+ num_channels_latents = self.unet.config.in_channels
376
+ latents_init = prior_latents # [1, t, c, h, w]
377
+ latents_all = None
378
+
379
+ idx_start = 0
380
+ if overlap > 0:
381
+ weights = torch.linspace(0, 1, overlap, device=device)
382
+ weights = weights.view(1, overlap, 1, 1, 1)
383
+ else:
384
+ weights = None
385
+
386
+ while idx_start < num_frames - overlap:
387
+ idx_end = min(idx_start + window_size, num_frames)
388
+ # 9. Denoising loop
389
+ # latents_init = latents_init.flip(1)
390
+ latents = latents_init[:, idx_start:idx_end]
391
+ video_latents_current = video_latents[:, idx_start:idx_end]
392
+ video_embeddings_current = video_embeddings[:, idx_start:idx_end]
393
+
394
+ latent_model_input = torch.cat(
395
+ [latents, video_latents_current], dim=2
396
+ )
397
+
398
+ model_pred = self.unet(
399
+ latent_model_input,
400
+ timestep,
401
+ encoder_hidden_states=video_embeddings_current,
402
+ added_time_ids=added_time_ids,
403
+ return_dict=False,
404
+ )[0]
405
+
406
+ c_out = -1
407
+ latents = model_pred * c_out
408
+
409
+ if latents_all is None:
410
+ latents_all = latents.clone()
411
+ else:
412
+ if overlap > 0:
413
+ latents_all[:, -overlap:] = latents[
414
+ :, :overlap
415
+ ] * weights + latents_all[:, -overlap:] * (1 - weights)
416
+ latents_all = torch.cat([latents_all, latents[:, overlap:]], dim=1)
417
+
418
+ idx_start += stride
419
+
420
+ latents_all = 1 / self.vae.config.scaling_factor * latents_all.squeeze(0).to(torch.float32)
421
+
422
+ if track_time:
423
+ denoise_event.record()
424
+ torch.cuda.synchronize()
425
+ elapsed_time_ms = encode_event.elapsed_time(denoise_event)
426
+ print(f"Elapsed time for denoise latent: {elapsed_time_ms} ms")
427
+ else:
428
+ gc.collect()
429
+ torch.cuda.empty_cache()
430
+
431
+ point_map, valid_mask = self.decode_point_map(
432
+ point_map_vae,
433
+ latents_all,
434
+ chunk_size=decode_chunk_size,
435
+ force_projection=force_projection,
436
+ force_fixed_focal=force_fixed_focal,
437
+ use_extract_interp=use_extract_interp,
438
+ need_resize=need_resize,
439
+ height=original_height,
440
+ width=original_width)
441
+
442
+ if track_time:
443
+ decode_event.record()
444
+ torch.cuda.synchronize()
445
+ elapsed_time_ms = denoise_event.elapsed_time(decode_event)
446
+ print(f"Elapsed time for decode latent: {elapsed_time_ms} ms")
447
+ else:
448
+ gc.collect()
449
+ torch.cuda.empty_cache()
450
+
451
+ self.maybe_free_model_hooks()
452
+ # t,h,w,3 t,h,w
453
+ return point_map, valid_mask
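Both pipelines handle long videos with a sliding window: each window of latents is predicted separately, and the overlapping frames are linearly cross-faded using weights = torch.linspace(0, 1, overlap) (see __call__ above). A minimal standalone sketch of that blending rule, not part of this commit, using dummy per-window "latents" with toy shapes chosen only for illustration:

import torch

num_frames, window_size, overlap = 20, 8, 3
stride = window_size - overlap

def run_window(start, end):
    # Stand-in for a per-window model prediction: [1, t, c, h, w], toy sizes here.
    return torch.full((1, end - start, 1, 1, 1), float(start))

weights = torch.linspace(0, 1, overlap).view(1, overlap, 1, 1, 1)

latents_all = None
idx_start = 0
while idx_start < num_frames - overlap:
    idx_end = min(idx_start + window_size, num_frames)
    latents = run_window(idx_start, idx_end)
    if latents_all is None:
        latents_all = latents.clone()
    else:
        # Blend the trailing `overlap` frames of the running result with the
        # leading `overlap` frames of the new window, then append the rest.
        latents_all[:, -overlap:] = latents[:, :overlap] * weights + latents_all[:, -overlap:] * (1 - weights)
        latents_all = torch.cat([latents_all, latents[:, overlap:]], dim=1)
    idx_start += stride

print(latents_all.shape)  # torch.Size([1, 20, 1, 1, 1]) -- one latent per frame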
geometrycrafter/diff_ppl.py ADDED
@@ -0,0 +1,526 @@
1
+ from typing import Callable, Dict, List, Optional, Union
2
+ import gc
3
+
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn.functional as F
7
+
8
+ from diffusers.pipelines.stable_video_diffusion.pipeline_stable_video_diffusion import (
9
+ _resize_with_antialiasing,
10
+ StableVideoDiffusionPipeline,
11
+ retrieve_timesteps,
12
+ )
13
+ from diffusers.utils import logging
14
+ from kornia.utils import create_meshgrid
15
+ from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
16
+
17
+
18
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
19
+
20
+ @torch.no_grad()
21
+ def normalize_point_map(point_map, valid_mask):
22
+ # T,H,W,3 T,H,W
23
+ norm_factor = (point_map[..., 2] * valid_mask.float()).mean() / (valid_mask.float().mean() + 1e-8)
24
+ norm_factor = norm_factor.clip(min=1e-3)
25
+ return point_map / norm_factor
26
+
27
+ def point_map_xy2intrinsic_map(point_map_xy):
28
+ # *,h,w,2
29
+ height, width = point_map_xy.shape[-3], point_map_xy.shape[-2]
30
+ assert height % 2 == 0
31
+ assert width % 2 == 0
32
+ mesh_grid = create_meshgrid(
33
+ height=height,
34
+ width=width,
35
+ normalized_coordinates=True,
36
+ device=point_map_xy.device,
37
+ dtype=point_map_xy.dtype
38
+ )[0] # h,w,2
39
+ assert mesh_grid.abs().min() > 1e-4
40
+ # *,h,w,2
41
+ mesh_grid = mesh_grid.expand_as(point_map_xy)
42
+ nc = point_map_xy.mean(dim=-2).mean(dim=-2) # *, 2
43
+ nc_map = nc[..., None, None, :].expand_as(point_map_xy)
44
+ nf = ((point_map_xy - nc_map) / mesh_grid).mean(dim=-2).mean(dim=-2)
45
+ nf_map = nf[..., None, None, :].expand_as(point_map_xy)
46
+ # print((mesh_grid * nf_map + nc_map - point_map_xy).abs().max())
47
+
48
+ return torch.cat([nc_map, nf_map], dim=-1)
49
+
50
+ def robust_min_max(tensor, quantile=0.99):
51
+ T, H, W = tensor.shape
52
+ min_vals = []
53
+ max_vals = []
54
+ for i in range(T):
55
+ min_vals.append(torch.quantile(tensor[i], q=1-quantile, interpolation='nearest').item())
56
+ max_vals.append(torch.quantile(tensor[i], q=quantile, interpolation='nearest').item())
57
+ return min(min_vals), max(max_vals)
58
+
59
+ class GeometryCrafterDiffPipeline(StableVideoDiffusionPipeline):
60
+
61
+ @torch.inference_mode()
62
+ def encode_video(
63
+ self,
64
+ video: torch.Tensor,
65
+ chunk_size: int = 14,
66
+ ) -> torch.Tensor:
67
+ """
68
+ :param video: [b, c, h, w] in range [-1, 1]; b may span multiple videos or frames
69
+ :param chunk_size: the chunk size to encode video
70
+ :return: image_embeddings in shape of [b, 1024]
71
+ """
72
+
73
+ video_224 = _resize_with_antialiasing(video.float(), (224, 224))
74
+ video_224 = (video_224 + 1.0) / 2.0 # [-1, 1] -> [0, 1]
75
+ embeddings = []
76
+ for i in range(0, video_224.shape[0], chunk_size):
77
+ emb = self.feature_extractor(
78
+ images=video_224[i : i + chunk_size],
79
+ do_normalize=True,
80
+ do_center_crop=False,
81
+ do_resize=False,
82
+ do_rescale=False,
83
+ return_tensors="pt",
84
+ ).pixel_values.to(video.device, dtype=video.dtype)
85
+ embeddings.append(self.image_encoder(emb).image_embeds) # [b, 1024]
86
+
87
+ embeddings = torch.cat(embeddings, dim=0) # [t, 1024]
88
+ return embeddings
89
+
90
+ @torch.inference_mode()
91
+ def encode_vae_video(
92
+ self,
93
+ video: torch.Tensor,
94
+ chunk_size: int = 14,
95
+ ):
96
+ """
97
+ :param video: [b, c, h, w] in range [-1, 1]; b may span multiple videos or frames
98
+ :param chunk_size: the chunk size to encode video
99
+ :return: vae latents in shape of [b, c, h, w]
100
+ """
101
+ video_latents = []
102
+ for i in range(0, video.shape[0], chunk_size):
103
+ video_latents.append(
104
+ self.vae.encode(video[i : i + chunk_size]).latent_dist.mode()
105
+ )
106
+ video_latents = torch.cat(video_latents, dim=0)
107
+ return video_latents
108
+
109
+ @torch.inference_mode()
110
+ def produce_priors(self, prior_model, frame, chunk_size=8):
111
+ T, _, H, W = frame.shape
112
+ frame = (frame + 1) / 2
113
+ pred_point_maps = []
114
+ pred_masks = []
115
+ for i in range(0, len(frame), chunk_size):
116
+ pred_p, pred_m = prior_model.forward_image(frame[i:i+chunk_size])
117
+ pred_point_maps.append(pred_p)
118
+ pred_masks.append(pred_m)
119
+ pred_point_maps = torch.cat(pred_point_maps, dim=0)
120
+ pred_masks = torch.cat(pred_masks, dim=0)
121
+
122
+ pred_masks = pred_masks.float() * 2 - 1
123
+
124
+ # T,H,W,3 T,H,W
125
+ pred_point_maps = normalize_point_map(pred_point_maps, pred_masks > 0)
126
+
127
+ pred_disps = 1.0 / pred_point_maps[..., 2].clamp_min(1e-3)
128
+ pred_disps = pred_disps * (pred_masks > 0)
129
+ min_disparity, max_disparity = robust_min_max(pred_disps)
130
+ pred_disps = ((pred_disps - min_disparity) / (max_disparity - min_disparity+1e-4)).clamp(0, 1)
131
+ pred_disps = pred_disps * 2 - 1
132
+
133
+ pred_point_maps[..., :2] = pred_point_maps[..., :2] / (pred_point_maps[..., 2:3] + 1e-7)
134
+ pred_point_maps[..., 2] = torch.log(pred_point_maps[..., 2] + 1e-7) * (pred_masks > 0) # [x/z, y/z, log(z)]
135
+
136
+ pred_intr_maps = point_map_xy2intrinsic_map(pred_point_maps[..., :2]).permute(0,3,1,2) # T,H,W,2
137
+ pred_point_maps = pred_point_maps.permute(0,3,1,2)
138
+
139
+ return pred_disps, pred_masks, pred_point_maps, pred_intr_maps
140
+
141
+ @torch.inference_mode()
142
+ def encode_point_map(self, point_map_vae, disparity, valid_mask, point_map, intrinsic_map, chunk_size=8):
143
+ T, _, H, W = point_map.shape
144
+ latents = []
145
+
146
+ psedo_image = disparity[:, None].repeat(1,3,1,1)
147
+ intrinsic_map = torch.norm(intrinsic_map[:, 2:4], p=2, dim=1, keepdim=False)
148
+
149
+ for i in range(0, T, chunk_size):
150
+ latent_dist = self.vae.encode(pseudo_image[i : i + chunk_size].to(self.vae.dtype)).latent_dist
151
+ latent_dist = point_map_vae.encode(
152
+ torch.cat([
153
+ intrinsic_map[i:i+chunk_size, None],
154
+ point_map[i:i+chunk_size, 2:3],
155
+ disparity[i:i+chunk_size, None],
156
+ valid_mask[i:i+chunk_size, None]], dim=1),
157
+ latent_dist
158
+ )
159
+ if isinstance(latent_dist, DiagonalGaussianDistribution):
160
+ latent = latent_dist.mode()
161
+ else:
162
+ latent = latent_dist
163
+
164
+ assert isinstance(latent, torch.Tensor)
165
+ latents.append(latent)
166
+ latents = torch.cat(latents, dim=0)
167
+ latents = latents * self.vae.config.scaling_factor
168
+ return latents
169
+
170
+ @torch.no_grad()
171
+ def decode_point_map(self, point_map_vae, latents, chunk_size=8, force_projection=True, force_fixed_focal=True, use_extract_interp=False, need_resize=False, height=None, width=None):
172
+ T = latents.shape[0]
173
+ rec_intrinsic_maps = []
174
+ rec_depth_maps = []
175
+ rec_valid_masks = []
176
+ for i in range(0, T, chunk_size):
177
+ lat = latents[i:i+chunk_size]
178
+ rec_imap, rec_dmap, rec_vmask = point_map_vae.decode(
179
+ lat,
180
+ num_frames=lat.shape[0],
181
+ )
182
+ rec_intrinsic_maps.append(rec_imap)
183
+ rec_depth_maps.append(rec_dmap)
184
+ rec_valid_masks.append(rec_vmask)
185
+
186
+ rec_intrinsic_maps = torch.cat(rec_intrinsic_maps, dim=0)
187
+ rec_depth_maps = torch.cat(rec_depth_maps, dim=0)
188
+ rec_valid_masks = torch.cat(rec_valid_masks, dim=0)
189
+
190
+ if need_resize:
191
+ rec_depth_maps = F.interpolate(rec_depth_maps, (height, width), mode='nearest-exact') if use_extract_interp else F.interpolate(rec_depth_maps, (height, width), mode='bilinear', align_corners=False)
192
+ rec_valid_masks = F.interpolate(rec_valid_masks, (height, width), mode='nearest-exact') if use_extract_interp else F.interpolate(rec_valid_masks, (height, width), mode='bilinear', align_corners=False)
193
+ rec_intrinsic_maps = F.interpolate(rec_intrinsic_maps, (height, width), mode='bilinear', align_corners=False)
194
+
195
+ H, W = rec_intrinsic_maps.shape[-2], rec_intrinsic_maps.shape[-1]
196
+ mesh_grid = create_meshgrid(
197
+ H, W,
198
+ normalized_coordinates=True
199
+ ).to(rec_intrinsic_maps.device, rec_intrinsic_maps.dtype, non_blocking=True)
200
+ # 1,h,w,2
201
+ rec_intrinsic_maps = torch.cat([rec_intrinsic_maps * W / np.sqrt(W**2+H**2), rec_intrinsic_maps * H / np.sqrt(W**2+H**2)], dim=1) # t,2,h,w
202
+ mesh_grid = mesh_grid.permute(0,3,1,2)
203
+ rec_valid_masks = rec_valid_masks.squeeze(1) > 0
204
+
205
+ if force_projection:
206
+ if force_fixed_focal:
207
+ nfx = (rec_intrinsic_maps[:, 0, :, :] * rec_valid_masks.float()).mean() / (rec_valid_masks.float().mean() + 1e-4)
208
+ nfy = (rec_intrinsic_maps[:, 1, :, :] * rec_valid_masks.float()).mean() / (rec_valid_masks.float().mean() + 1e-4)
209
+ rec_intrinsic_maps = torch.tensor([nfx, nfy], device=rec_intrinsic_maps.device)[None, :, None, None].repeat(T, 1, 1, 1)
210
+ else:
211
+ nfx = (rec_intrinsic_maps[:, 0, :, :] * rec_valid_masks.float()).mean(dim=[-1, -2]) / (rec_valid_masks.float().mean(dim=[-1, -2]) + 1e-4)
212
+ nfy = (rec_intrinsic_maps[:, 1, :, :] * rec_valid_masks.float()).mean(dim=[-1, -2]) / (rec_valid_masks.float().mean(dim=[-1, -2]) + 1e-4)
213
+ rec_intrinsic_maps = torch.stack([nfx, nfy], dim=-1)[:, :, None, None]
214
+ # t,2,1,1
215
+
216
+ rec_point_maps = torch.cat([rec_intrinsic_maps * mesh_grid, rec_depth_maps], dim=1).permute(0,2,3,1)
217
+ xy, z = rec_point_maps.split([2, 1], dim=-1)
218
+ z = torch.clamp_max(z, 10) # for numerical stability
219
+ z = torch.exp(z)
220
+ rec_point_maps = torch.cat([xy * z, z], dim=-1)
221
+
222
+ return rec_point_maps, rec_valid_masks
223
+
224
+
225
+ @torch.no_grad()
226
+ def __call__(
227
+ self,
228
+ video: Union[np.ndarray, torch.Tensor],
229
+ point_map_vae,
230
+ prior_model,
231
+ height: int = 320,
232
+ width: int = 640,
233
+ num_inference_steps: int = 5,
234
+ guidance_scale: float = 1.0,
235
+ window_size: Optional[int] = 14,
236
+ noise_aug_strength: float = 0.02,
237
+ decode_chunk_size: Optional[int] = None,
238
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
239
+ latents: Optional[torch.FloatTensor] = None,
240
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
241
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
242
+ overlap: int = 4,
243
+ force_projection: bool = True,
244
+ force_fixed_focal: bool = True,
245
+ use_extract_interp: bool = False,
246
+ track_time: bool = False,
247
+ ):
248
+
249
+ # video: in shape [t, h, w, c] if np.ndarray or [t, c, h, w] if torch.Tensor, in range [0, 1]
250
+
251
+ # 0. Default height and width to unet
252
+ if isinstance(video, np.ndarray):
253
+ video = torch.from_numpy(video.transpose(0, 3, 1, 2))
254
+ else:
255
+ assert isinstance(video, torch.Tensor)
256
+ height = height or video.shape[-2]
257
+ width = width or video.shape[-1]
258
+ original_height = video.shape[-2]
259
+ original_width = video.shape[-1]
260
+ num_frames = video.shape[0]
261
+ decode_chunk_size = decode_chunk_size if decode_chunk_size is not None else 8
262
+ if num_frames <= window_size:
263
+ window_size = num_frames
264
+ overlap = 0
265
+ stride = window_size - overlap
266
+
267
+ # 1. Check inputs. Raise error if not correct
268
+ assert height % 64 == 0 and width % 64 == 0
269
+ if original_height != height or original_width != width:
270
+ need_resize = True
271
+ else:
272
+ need_resize = False
273
+
274
+ # 2. Define call parameters
275
+ batch_size = 1
276
+ device = self._execution_device
277
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
278
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
279
+ # corresponds to doing no classifier free guidance.
280
+ self._guidance_scale = guidance_scale
281
+
282
+ if track_time:
283
+ start_event = torch.cuda.Event(enable_timing=True)
284
+ prior_event = torch.cuda.Event(enable_timing=True)
285
+ encode_event = torch.cuda.Event(enable_timing=True)
286
+ denoise_event = torch.cuda.Event(enable_timing=True)
287
+ decode_event = torch.cuda.Event(enable_timing=True)
288
+ start_event.record()
289
+
290
+ # 3. Compute per-frame priors at the original resolution
291
+ pred_disparity, pred_valid_mask, pred_point_map, pred_intrinsic_map = self.produce_priors(
292
+ prior_model,
293
+ video.to(device=device, dtype=torch.float32),
294
+ chunk_size=decode_chunk_size
295
+ ) # T,H,W T,H,W T,3,H,W T,2,H,W
296
+
297
+ if need_resize:
298
+ pred_disparity = F.interpolate(pred_disparity.unsqueeze(1), (height, width), mode='bilinear', align_corners=False).squeeze(1)
299
+ pred_valid_mask = F.interpolate(pred_valid_mask.unsqueeze(1), (height, width), mode='bilinear', align_corners=False).squeeze(1)
300
+ pred_point_map = F.interpolate(pred_point_map, (height, width), mode='bilinear', align_corners=False)
301
+ pred_intrinsic_map = F.interpolate(pred_intrinsic_map, (height, width), mode='bilinear', align_corners=False)
302
+
303
+
304
+ if track_time:
305
+ prior_event.record()
306
+ torch.cuda.synchronize()
307
+ elapsed_time_ms = start_event.elapsed_time(prior_event)
308
+ print(f"Elapsed time for computing per-frame prior: {elapsed_time_ms} ms")
309
+ else:
310
+ gc.collect()
311
+ torch.cuda.empty_cache()
312
+
313
+
314
+ # 3. Encode input video
315
+ if need_resize:
316
+ video = F.interpolate(video, (height, width), mode="bicubic", align_corners=False, antialias=True).clamp(0, 1)
317
+ video = video.to(device=device, dtype=self.dtype)
318
+ video = video * 2.0 - 1.0 # [0,1] -> [-1,1], in [t, c, h, w]
319
+
320
+ video_embeddings = self.encode_video(video, chunk_size=decode_chunk_size).unsqueeze(0)
321
+ prior_latents = self.encode_point_map(
322
+ point_map_vae,
323
+ pred_disparity,
324
+ pred_valid_mask,
325
+ pred_point_map,
326
+ pred_intrinsic_map,
327
+ chunk_size=decode_chunk_size
328
+ ).unsqueeze(0).to(video_embeddings.dtype) # 1,T,C,H,W
329
+
330
+ # 4. Encode input image using VAE
331
+
332
+ # pdb.set_trace()
333
+ needs_upcasting = (
334
+ self.vae.dtype == torch.float16 and self.vae.config.force_upcast
335
+ )
336
+ if needs_upcasting:
337
+ self.vae.to(dtype=torch.float32)
338
+
339
+ video_latents = self.encode_vae_video(
340
+ video.to(self.vae.dtype),
341
+ chunk_size=decode_chunk_size,
342
+ ).unsqueeze(0).to(video_embeddings.dtype) # [1, t, c, h, w]
343
+
344
+ torch.cuda.empty_cache()
345
+
346
+ if track_time:
347
+ encode_event.record()
348
+ torch.cuda.synchronize()
349
+ elapsed_time_ms = prior_event.elapsed_time(encode_event)
350
+ print(f"Elapsed time for encode prior and frames: {elapsed_time_ms} ms")
351
+ else:
352
+ gc.collect()
353
+ torch.cuda.empty_cache()
354
+
355
+ # cast back to fp16 if needed
356
+ if needs_upcasting:
357
+ self.vae.to(dtype=torch.float16)
358
+
359
+ # 5. Get Added Time IDs
360
+ added_time_ids = self._get_add_time_ids(
361
+ 7,
362
+ 127,
363
+ noise_aug_strength,
364
+ video_embeddings.dtype,
365
+ batch_size,
366
+ 1,
367
+ False,
368
+ ) # [1 or 2, 3]
369
+ added_time_ids = added_time_ids.to(device)
370
+
371
+ # 6. Prepare timesteps
372
+ timesteps, num_inference_steps = retrieve_timesteps(
373
+ self.scheduler, num_inference_steps, device, None, None
374
+ )
375
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
376
+ self._num_timesteps = len(timesteps)
377
+
378
+ # 7. Prepare latent variables
379
+ # num_channels_latents = self.unet.config.in_channels - prior_latents.shape[1]
380
+ num_channels_latents = 8
381
+ latents_init = self.prepare_latents(
382
+ batch_size,
383
+ window_size,
384
+ num_channels_latents,
385
+ height,
386
+ width,
387
+ video_embeddings.dtype,
388
+ device,
389
+ generator,
390
+ latents,
391
+ ) # [1, t, c, h, w]
392
+ latents_all = None
393
+
394
+ idx_start = 0
395
+ if overlap > 0:
396
+ weights = torch.linspace(0, 1, overlap, device=device)
397
+ weights = weights.view(1, overlap, 1, 1, 1)
398
+ else:
399
+ weights = None
400
+
401
+ while idx_start < num_frames - overlap:
402
+ idx_end = min(idx_start + window_size, num_frames)
403
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
404
+ # 9. Denoising loop
405
+ # latents_init = latents_init.flip(1)
406
+ latents = latents_init[:, : idx_end - idx_start].clone()
407
+ latents_init = torch.cat(
408
+ [latents_init[:, -overlap:], latents_init[:, :stride]], dim=1
409
+ )
410
+
411
+ video_latents_current = video_latents[:, idx_start:idx_end]
412
+ prior_latents_current = prior_latents[:, idx_start:idx_end]
413
+ video_embeddings_current = video_embeddings[:, idx_start:idx_end]
414
+
415
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
416
+ for i, t in enumerate(timesteps):
417
+ if latents_all is not None and i == 0:
418
+ latents[:, :overlap] = (
419
+ latents_all[:, -overlap:]
420
+ + latents[:, :overlap]
421
+ / self.scheduler.init_noise_sigma
422
+ * self.scheduler.sigmas[i]
423
+ )
424
+
425
+ latent_model_input = latents
426
+
427
+ latent_model_input = self.scheduler.scale_model_input(
428
+ latent_model_input, t
429
+ ) # [1 or 2, t, c, h, w]
430
+ latent_model_input = torch.cat(
431
+ [latent_model_input, video_latents_current, prior_latents_current], dim=2
432
+ )
433
+ noise_pred = self.unet(
434
+ latent_model_input,
435
+ t,
436
+ encoder_hidden_states=video_embeddings_current,
437
+ added_time_ids=added_time_ids,
438
+ return_dict=False,
439
+ )[0]
440
+ # pdb.set_trace()
441
+ # perform guidance
442
+ if self.do_classifier_free_guidance:
443
+ latent_model_input = latents
444
+ latent_model_input = self.scheduler.scale_model_input(
445
+ latent_model_input, t
446
+ )
447
+ latent_model_input = torch.cat(
448
+ [latent_model_input, torch.zeros_like(latent_model_input), torch.zeros_like(latent_model_input)],
449
+ dim=2,
450
+ )
451
+ noise_pred_uncond = self.unet(
452
+ latent_model_input,
453
+ t,
454
+ encoder_hidden_states=torch.zeros_like(
455
+ video_embeddings_current
456
+ ),
457
+ added_time_ids=added_time_ids,
458
+ return_dict=False,
459
+ )[0]
460
+ noise_pred = noise_pred_uncond + self.guidance_scale * (
461
+ noise_pred - noise_pred_uncond
462
+ )
463
+ latents = self.scheduler.step(noise_pred, t, latents).prev_sample
464
+
465
+ if callback_on_step_end is not None:
466
+ callback_kwargs = {}
467
+ for k in callback_on_step_end_tensor_inputs:
468
+ callback_kwargs[k] = locals()[k]
469
+ callback_outputs = callback_on_step_end(
470
+ self, i, t, callback_kwargs
471
+ )
472
+
473
+ latents = callback_outputs.pop("latents", latents)
474
+
475
+ if i == len(timesteps) - 1 or (
476
+ (i + 1) > num_warmup_steps
477
+ and (i + 1) % self.scheduler.order == 0
478
+ ):
479
+ progress_bar.update()
480
+
481
+ if latents_all is None:
482
+ latents_all = latents.clone()
483
+ else:
484
+ if overlap > 0:
485
+ latents_all[:, -overlap:] = latents[
486
+ :, :overlap
487
+ ] * weights + latents_all[:, -overlap:] * (1 - weights)
488
+ latents_all = torch.cat([latents_all, latents[:, overlap:]], dim=1)
489
+
490
+ idx_start += stride
491
+
492
+ latents_all = 1 / self.vae.config.scaling_factor * latents_all.squeeze(0).to(torch.float32)
493
+
494
+ if track_time:
495
+ denoise_event.record()
496
+ torch.cuda.synchronize()
497
+ elapsed_time_ms = encode_event.elapsed_time(denoise_event)
498
+ print(f"Elapsed time for denoise latent: {elapsed_time_ms} ms")
499
+ else:
500
+ gc.collect()
501
+ torch.cuda.empty_cache()
502
+
503
+ point_map, valid_mask = self.decode_point_map(
504
+ point_map_vae,
505
+ latents_all,
506
+ chunk_size=decode_chunk_size,
507
+ force_projection=force_projection,
508
+ force_fixed_focal=force_fixed_focal,
509
+ use_extract_interp=use_extract_interp,
510
+ need_resize=need_resize,
511
+ height=original_height,
512
+ width=original_width)
513
+
514
+
515
+ if track_time:
516
+ decode_event.record()
517
+ torch.cuda.synchronize()
518
+ elapsed_time_ms = denoise_event.elapsed_time(decode_event)
519
+ print(f"Elapsed time for decode latent: {elapsed_time_ms} ms")
520
+ else:
521
+ gc.collect()
522
+ torch.cuda.empty_cache()
523
+
524
+ self.maybe_free_model_hooks()
525
+ # point_map: [t, h, w, 3], valid_mask: [t, h, w]
526
+ return point_map, valid_mask
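
The sliding-window loop above stitches consecutive denoised windows by linearly blending the overlapping frames before appending the non-overlapping remainder. Below is a minimal, self-contained sketch of that blending on dummy tensors; the window_size/overlap/stride values and the 4x8x8 latent shape are made up for illustration, the random tensors stand in for the denoised windows produced by the UNet loop, and it assumes stride == window_size - overlap, as the latent rolling above implies.

import torch

# illustrative values only; the real pipeline receives these as arguments
num_frames, window_size, overlap = 20, 8, 3
stride = window_size - overlap

weights = torch.linspace(0, 1, overlap).view(1, overlap, 1, 1, 1)
latents_all = None

idx_start = 0
while idx_start < num_frames - overlap:
    idx_end = min(idx_start + window_size, num_frames)
    # stand-in for the denoised latents of the current window: [1, t, c, h, w]
    latents = torch.randn(1, idx_end - idx_start, 4, 8, 8)
    if latents_all is None:
        latents_all = latents.clone()
    else:
        # ramp linearly from the previous window to the current one over the overlap ...
        latents_all[:, -overlap:] = (
            latents[:, :overlap] * weights + latents_all[:, -overlap:] * (1 - weights)
        )
        # ... then append the frames that only the current window covers
        latents_all = torch.cat([latents_all, latents[:, overlap:]], dim=1)
    idx_start += stride

print(latents_all.shape)  # torch.Size([1, 20, 4, 8, 8])
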
geometrycrafter/pmap_vae.py ADDED
@@ -0,0 +1,330 @@
1
+ from typing import Dict, Tuple, Union
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
6
+ from diffusers.utils.accelerate_utils import apply_forward_hook
7
+ from diffusers.models.attention_processor import CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnProcessor
8
+ from diffusers.models.modeling_utils import ModelMixin
9
+ from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution, Encoder
10
+ from diffusers.utils import is_torch_version
11
+ from diffusers.models.unets.unet_3d_blocks import UpBlockTemporalDecoder, MidBlockTemporalDecoder
12
+ from diffusers.models.resnet import SpatioTemporalResBlock
13
+
14
+ def zero_module(module):
15
+ """
16
+ Zero out the parameters of a module and return it.
17
+ """
18
+ for p in module.parameters():
19
+ p.detach().zero_()
20
+ return module
21
+
22
+ class PMapTemporalDecoder(nn.Module):
23
+ def __init__(
24
+ self,
25
+ in_channels: int = 4,
26
+ out_channels: Tuple[int] = (1, 1, 1),
27
+ block_out_channels: Tuple[int] = (128, 256, 512, 512),
28
+ layers_per_block: int = 2,
29
+ ):
30
+ super().__init__()
31
+
32
+ self.conv_in = nn.Conv2d(
33
+ in_channels,
34
+ block_out_channels[-1],
35
+ kernel_size=3,
36
+ stride=1,
37
+ padding=1
38
+ )
39
+ self.mid_block = MidBlockTemporalDecoder(
40
+ num_layers=layers_per_block,
41
+ in_channels=block_out_channels[-1],
42
+ out_channels=block_out_channels[-1],
43
+ attention_head_dim=block_out_channels[-1],
44
+ )
45
+
46
+ # up
47
+ self.up_blocks = nn.ModuleList([])
48
+ reversed_block_out_channels = list(reversed(block_out_channels))
49
+ output_channel = reversed_block_out_channels[0]
50
+ for i in range(len(block_out_channels)):
51
+ prev_output_channel = output_channel
52
+ output_channel = reversed_block_out_channels[i]
53
+ is_final_block = i == len(block_out_channels) - 1
54
+ up_block = UpBlockTemporalDecoder(
55
+ num_layers=layers_per_block + 1,
56
+ in_channels=prev_output_channel,
57
+ out_channels=output_channel,
58
+ add_upsample=not is_final_block,
59
+ )
60
+ self.up_blocks.append(up_block)
61
+ prev_output_channel = output_channel
62
+
63
+ self.out_blocks = nn.ModuleList([])
64
+ self.time_conv_outs = nn.ModuleList([])
65
+ for out_channel in out_channels:
66
+ self.out_blocks.append(
67
+ nn.ModuleList([
68
+ nn.GroupNorm(num_channels=block_out_channels[0], num_groups=32, eps=1e-6),
69
+ nn.ReLU(inplace=True),
70
+ nn.Conv2d(
71
+ block_out_channels[0],
72
+ block_out_channels[0] // 2,
73
+ kernel_size=3,
74
+ padding=1
75
+ ),
76
+ SpatioTemporalResBlock(
77
+ in_channels=block_out_channels[0] // 2,
78
+ out_channels=block_out_channels[0] // 2,
79
+ temb_channels=None,
80
+ eps=1e-6,
81
+ temporal_eps=1e-5,
82
+ merge_factor=0.0,
83
+ merge_strategy="learned",
84
+ switch_spatial_to_temporal_mix=True
85
+ ),
86
+ nn.ReLU(inplace=True),
87
+ nn.Conv2d(
88
+ block_out_channels[0] // 2,
89
+ out_channel,
90
+ kernel_size=1,
91
+ )
92
+ ])
93
+ )
94
+
95
+ conv_out_kernel_size = (3, 1, 1)
96
+ padding = [int(k // 2) for k in conv_out_kernel_size]
97
+ self.time_conv_outs.append(nn.Conv3d(
98
+ in_channels=out_channel,
99
+ out_channels=out_channel,
100
+ kernel_size=conv_out_kernel_size,
101
+ padding=padding,
102
+ ))
103
+
104
+ self.gradient_checkpointing = False
105
+
106
+ def forward(
107
+ self,
108
+ sample: torch.Tensor,
109
+ image_only_indicator: torch.Tensor,
110
+ num_frames: int = 1,
111
+ ):
112
+ sample = self.conv_in(sample)
113
+
114
+ upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
115
+
116
+ if self.training and self.gradient_checkpointing:
117
+ def create_custom_forward(module):
118
+ def custom_forward(*inputs):
119
+ return module(*inputs)
120
+
121
+ return custom_forward
122
+
123
+ if is_torch_version(">=", "1.11.0"):
124
+ # middle
125
+ sample = torch.utils.checkpoint.checkpoint(
126
+ create_custom_forward(self.mid_block),
127
+ sample,
128
+ image_only_indicator,
129
+ use_reentrant=False,
130
+ )
131
+ sample = sample.to(upscale_dtype)
132
+
133
+ # up
134
+ for up_block in self.up_blocks:
135
+ sample = torch.utils.checkpoint.checkpoint(
136
+ create_custom_forward(up_block),
137
+ sample,
138
+ image_only_indicator,
139
+ use_reentrant=False,
140
+ )
141
+ else:
142
+ # middle
143
+ sample = torch.utils.checkpoint.checkpoint(
144
+ create_custom_forward(self.mid_block),
145
+ sample,
146
+ image_only_indicator,
147
+ )
148
+ sample = sample.to(upscale_dtype)
149
+
150
+ # up
151
+ for up_block in self.up_blocks:
152
+ sample = torch.utils.checkpoint.checkpoint(
153
+ create_custom_forward(up_block),
154
+ sample,
155
+ image_only_indicator,
156
+ )
157
+ else:
158
+ # middle
159
+ sample = self.mid_block(sample, image_only_indicator=image_only_indicator)
160
+ sample = sample.to(upscale_dtype)
161
+
162
+ # up
163
+ for up_block in self.up_blocks:
164
+ sample = up_block(sample, image_only_indicator=image_only_indicator)
165
+
166
+ # post-process
167
+
168
+ output = []
169
+
170
+ for out_block, time_conv_out in zip(self.out_blocks, self.time_conv_outs):
171
+ x = sample
172
+ for layer in out_block:
173
+ if isinstance(layer, SpatioTemporalResBlock):
174
+ x = layer(x, None, image_only_indicator)
175
+ else:
176
+ x = layer(x)
177
+
178
+
179
+ batch_frames, channels, height, width = x.shape
180
+ batch_size = batch_frames // num_frames
181
+ x = x[None, :].reshape(batch_size, num_frames, channels, height, width).permute(0, 2, 1, 3, 4)
182
+ x = time_conv_out(x)
183
+ x = x.permute(0, 2, 1, 3, 4).reshape(batch_frames, channels, height, width)
184
+ output.append(x)
185
+
186
+ return output
187
+
188
+ class PMapAutoencoderKLTemporalDecoder(ModelMixin, ConfigMixin):
189
+
190
+ _supports_gradient_checkpointing = True
191
+
192
+ @register_to_config
193
+ def __init__(
194
+ self,
195
+ in_channels: int = 4,
196
+ latent_channels: int = 4,
197
+ enc_down_block_types: Tuple[str] = (
198
+ "DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D"
199
+ ),
200
+ enc_block_out_channels: Tuple[int] = (128, 256, 512, 512),
201
+ enc_layers_per_block: int = 2,
202
+ dec_block_out_channels: Tuple[int] = (128, 256, 512, 512),
203
+ dec_layers_per_block: int = 2,
204
+ out_channels: Tuple[int] = (1, 1, 1),
205
+ mid_block_add_attention: bool = True,
206
+ offset_scale_factor: float = 0.1,
207
+ **kwargs
208
+ ):
209
+ super().__init__()
210
+
211
+ self.encoder = Encoder(
212
+ in_channels=in_channels,
213
+ out_channels=latent_channels,
214
+ down_block_types=enc_down_block_types,
215
+ block_out_channels=enc_block_out_channels,
216
+ layers_per_block=enc_layers_per_block,
217
+ double_z=False,
218
+ mid_block_add_attention=mid_block_add_attention
219
+ )
220
+ zero_module(self.encoder.conv_out)
221
+
222
+ self.offset_scale_factor = offset_scale_factor
223
+
224
+ self.decoder = PMapTemporalDecoder(
225
+ in_channels=latent_channels,
226
+ block_out_channels=dec_block_out_channels,
227
+ layers_per_block=dec_layers_per_block,
228
+ out_channels=out_channels
229
+ )
230
+
231
+ def _set_gradient_checkpointing(self, module, value=False):
232
+ if isinstance(module, (Encoder, PMapTemporalDecoder)):
233
+ module.gradient_checkpointing = value
234
+
235
+ @property
236
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
237
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
238
+ r"""
239
+ Returns:
240
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
241
+ indexed by its weight name.
242
+ """
243
+ # set recursively
244
+ processors = {}
245
+
246
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
247
+ if hasattr(module, "get_processor"):
248
+ processors[f"{name}.processor"] = module.get_processor()
249
+
250
+ for sub_name, child in module.named_children():
251
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
252
+
253
+ return processors
254
+
255
+ for name, module in self.named_children():
256
+ fn_recursive_add_processors(name, module, processors)
257
+
258
+ return processors
259
+
260
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
261
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
262
+ r"""
263
+ Sets the attention processor to use to compute attention.
264
+
265
+ Parameters:
266
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
267
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
268
+ for **all** `Attention` layers.
269
+
270
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
271
+ processor. This is strongly recommended when setting trainable attention processors.
272
+
273
+ """
274
+ count = len(self.attn_processors.keys())
275
+
276
+ if isinstance(processor, dict) and len(processor) != count:
277
+ raise ValueError(
278
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
279
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
280
+ )
281
+
282
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
283
+ if hasattr(module, "set_processor"):
284
+ if not isinstance(processor, dict):
285
+ module.set_processor(processor)
286
+ else:
287
+ module.set_processor(processor.pop(f"{name}.processor"))
288
+
289
+ for sub_name, child in module.named_children():
290
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
291
+
292
+ for name, module in self.named_children():
293
+ fn_recursive_attn_processor(name, module, processor)
294
+
295
+ def set_default_attn_processor(self):
296
+ """
297
+ Disables custom attention processors and sets the default attention implementation.
298
+ """
299
+ if all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
300
+ processor = AttnProcessor()
301
+ else:
302
+ raise ValueError(
303
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
304
+ )
305
+
306
+ self.set_attn_processor(processor)
307
+
308
+ @apply_forward_hook
309
+ def encode(
310
+ self,
311
+ x: torch.Tensor,
312
+ latent_dist: DiagonalGaussianDistribution
313
+ ) -> DiagonalGaussianDistribution:
314
+ h = self.encoder(x)
315
+ offset = h * self.offset_scale_factor
316
+ param = latent_dist.parameters.to(h.dtype)
317
+ mean, logvar = torch.chunk(param, 2, dim=1)
318
+ posterior = DiagonalGaussianDistribution(torch.cat([mean + offset, logvar], dim=1))
319
+ return posterior
320
+
321
+ @apply_forward_hook
322
+ def decode(
323
+ self,
324
+ z: torch.Tensor,
325
+ num_frames: int
326
+ ) -> torch.Tensor:
327
+ batch_size = z.shape[0] // num_frames
328
+ image_only_indicator = torch.zeros(batch_size, num_frames, dtype=z.dtype, device=z.device)
329
+ decoded = self.decoder(z, num_frames=num_frames, image_only_indicator=image_only_indicator)
330
+ return decoded
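
For intuition, here is a minimal sketch of what the `encode` method above does: the point-map encoder (whose final conv is zero-initialized in `__init__`) predicts an offset that shifts the mean of an existing latent distribution while leaving its log-variance untouched. The 4-channel latent size, the 32x32 resolution, and the random stand-ins are illustrative only.

import torch
from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution

latent_channels, h, w = 4, 32, 32

# stand-in for the posterior from the frozen video VAE: mean and logvar stacked on dim=1
parameters = torch.randn(1, 2 * latent_channels, h, w)
latent_dist = DiagonalGaussianDistribution(parameters)

# stand-in for the encoder output, scaled by offset_scale_factor (0.1 by default)
offset = 0.1 * torch.randn(1, latent_channels, h, w)

mean, logvar = torch.chunk(latent_dist.parameters, 2, dim=1)
posterior = DiagonalGaussianDistribution(torch.cat([mean + offset, logvar], dim=1))
z = posterior.sample()
print(z.shape)  # torch.Size([1, 4, 32, 32])
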
geometrycrafter/unet.py ADDED
@@ -0,0 +1,281 @@
1
+ from typing import Union, Tuple
2
+
3
+ import torch
4
+ from diffusers import UNetSpatioTemporalConditionModel
5
+ from diffusers.models.unets.unet_spatio_temporal_condition import UNetSpatioTemporalConditionOutput
6
+ from diffusers.utils import is_torch_version
7
+
8
+
9
+ class UNetSpatioTemporalConditionModelVid2vid(
10
+ UNetSpatioTemporalConditionModel
11
+ ):
12
+ def enable_gradient_checkpointing(self):
13
+ self.gradient_checkpointing = True
14
+
15
+ def disable_gradient_checkpointing(self):
16
+ self.gradient_checkpointing = False
17
+
18
+ def forward(
19
+ self,
20
+ sample: torch.Tensor,
21
+ timestep: Union[torch.Tensor, float, int],
22
+ encoder_hidden_states: torch.Tensor,
23
+ added_time_ids: torch.Tensor,
24
+ return_dict: bool = True,
25
+ ) -> Union[UNetSpatioTemporalConditionOutput, Tuple]:
26
+
27
+ # 1. time
28
+ timesteps = timestep
29
+ if not torch.is_tensor(timesteps):
30
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
31
+ # This would be a good case for the `match` statement (Python 3.10+)
32
+ is_mps = sample.device.type == "mps"
33
+ if isinstance(timestep, float):
34
+ dtype = torch.float32 if is_mps else torch.float64
35
+ else:
36
+ dtype = torch.int32 if is_mps else torch.int64
37
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
38
+ elif len(timesteps.shape) == 0:
39
+ timesteps = timesteps[None].to(sample.device)
40
+
41
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
42
+ batch_size, num_frames = sample.shape[:2]
43
+ timesteps = timesteps.expand(batch_size)
44
+
45
+ t_emb = self.time_proj(timesteps)
46
+
47
+ # `Timesteps` does not contain any weights and will always return f32 tensors
48
+ # but time_embedding might actually be running in fp16. so we need to cast here.
49
+ # there might be better ways to encapsulate this.
50
+ t_emb = t_emb.to(dtype=self.conv_in.weight.dtype)
51
+
52
+ emb = self.time_embedding(t_emb) # [batch_size * num_frames, channels]
53
+
54
+ time_embeds = self.add_time_proj(added_time_ids.flatten())
55
+ time_embeds = time_embeds.reshape((batch_size, -1))
56
+ time_embeds = time_embeds.to(emb.dtype)
57
+ aug_emb = self.add_embedding(time_embeds)
58
+ emb = emb + aug_emb
59
+
60
+ # Flatten the batch and frames dimensions
61
+ # sample: [batch, frames, channels, height, width] -> [batch * frames, channels, height, width]
62
+ sample = sample.flatten(0, 1)
63
+ # Repeat the embeddings num_video_frames times
64
+ # emb: [batch, channels] -> [batch * frames, channels]
65
+ emb = emb.repeat_interleave(num_frames, dim=0)
66
+ # encoder_hidden_states: [batch, frames, channels] -> [batch * frames, 1, channels]
67
+ encoder_hidden_states = encoder_hidden_states.flatten(0, 1).unsqueeze(1)
68
+
69
+ # 2. pre-process
70
+ sample = sample.to(dtype=self.conv_in.weight.dtype)
71
+ assert sample.dtype == self.conv_in.weight.dtype, (
72
+ f"sample.dtype: {sample.dtype}, "
73
+ f"self.conv_in.weight.dtype: {self.conv_in.weight.dtype}"
74
+ )
75
+ sample = self.conv_in(sample)
76
+
77
+ image_only_indicator = torch.zeros(
78
+ batch_size, num_frames, dtype=sample.dtype, device=sample.device
79
+ )
80
+
81
+ down_block_res_samples = (sample,)
82
+
83
+ if self.training and self.gradient_checkpointing:
84
+ def create_custom_forward(module):
85
+ def custom_forward(*inputs):
86
+ return module(*inputs)
87
+
88
+ return custom_forward
89
+
90
+ if is_torch_version(">=", "1.11.0"):
91
+
92
+ for downsample_block in self.down_blocks:
93
+ if (
94
+ hasattr(downsample_block, "has_cross_attention")
95
+ and downsample_block.has_cross_attention
96
+ ):
97
+ sample, res_samples = torch.utils.checkpoint.checkpoint(
98
+ create_custom_forward(downsample_block),
99
+ sample,
100
+ emb,
101
+ encoder_hidden_states,
102
+ image_only_indicator,
103
+ use_reentrant=False,
104
+ )
105
+ else:
106
+ sample, res_samples = torch.utils.checkpoint.checkpoint(
107
+ create_custom_forward(downsample_block),
108
+ sample,
109
+ emb,
110
+ image_only_indicator,
111
+ use_reentrant=False,
112
+ )
113
+ down_block_res_samples += res_samples
114
+
115
+ # 4. mid
116
+ sample = torch.utils.checkpoint.checkpoint(
117
+ create_custom_forward(self.mid_block),
118
+ sample,
119
+ emb,
120
+ encoder_hidden_states,
121
+ image_only_indicator,
122
+ use_reentrant=False,
123
+ )
124
+
125
+ # 5. up
126
+ for i, upsample_block in enumerate(self.up_blocks):
127
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
128
+ down_block_res_samples = down_block_res_samples[
129
+ : -len(upsample_block.resnets)
130
+ ]
131
+
132
+ if (
133
+ hasattr(upsample_block, "has_cross_attention")
134
+ and upsample_block.has_cross_attention
135
+ ):
136
+ sample = torch.utils.checkpoint.checkpoint(
137
+ create_custom_forward(upsample_block),
138
+ sample,
139
+ res_samples,
140
+ emb,
141
+ encoder_hidden_states,
142
+ image_only_indicator,
143
+ use_reentrant=False,
144
+ )
145
+ else:
146
+ sample = torch.utils.checkpoint.checkpoint(
147
+ create_custom_forward(upsample_block),
148
+ sample,
149
+ res_samples,
150
+ emb,
151
+ image_only_indicator,
152
+ use_reentrant=False,
153
+ )
154
+ else:
155
+
156
+ for downsample_block in self.down_blocks:
157
+ if (
158
+ hasattr(downsample_block, "has_cross_attention")
159
+ and downsample_block.has_cross_attention
160
+ ):
161
+ sample, res_samples = torch.utils.checkpoint.checkpoint(
162
+ create_custom_forward(downsample_block),
163
+ sample,
164
+ emb,
165
+ encoder_hidden_states,
166
+ image_only_indicator,
167
+ )
168
+ else:
169
+ sample, res_samples = torch.utils.checkpoint.checkpoint(
170
+ create_custom_forward(downsample_block),
171
+ sample,
172
+ emb,
173
+ image_only_indicator,
174
+ )
175
+ down_block_res_samples += res_samples
176
+
177
+ # 4. mid
178
+ sample = torch.utils.checkpoint.checkpoint(
179
+ create_custom_forward(self.mid_block),
180
+ sample,
181
+ emb,
182
+ encoder_hidden_states,
183
+ image_only_indicator,
184
+ )
185
+
186
+ # 5. up
187
+ for i, upsample_block in enumerate(self.up_blocks):
188
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
189
+ down_block_res_samples = down_block_res_samples[
190
+ : -len(upsample_block.resnets)
191
+ ]
192
+
193
+ if (
194
+ hasattr(upsample_block, "has_cross_attention")
195
+ and upsample_block.has_cross_attention
196
+ ):
197
+ sample = torch.utils.checkpoint.checkpoint(
198
+ create_custom_forward(upsample_block),
199
+ sample,
200
+ res_samples,
201
+ emb,
202
+ encoder_hidden_states,
203
+ image_only_indicator,
204
+ )
205
+ else:
206
+ sample = torch.utils.checkpoint.checkpoint(
207
+ create_custom_forward(upsample_block),
208
+ sample,
209
+ res_samples,
210
+ emb,
211
+ image_only_indicator,
212
+ )
213
+
214
+ else:
215
+ for downsample_block in self.down_blocks:
216
+ if (
217
+ hasattr(downsample_block, "has_cross_attention")
218
+ and downsample_block.has_cross_attention
219
+ ):
220
+ sample, res_samples = downsample_block(
221
+ hidden_states=sample,
222
+ temb=emb,
223
+ encoder_hidden_states=encoder_hidden_states,
224
+ image_only_indicator=image_only_indicator,
225
+ )
226
+
227
+ else:
228
+ sample, res_samples = downsample_block(
229
+ hidden_states=sample,
230
+ temb=emb,
231
+ image_only_indicator=image_only_indicator,
232
+ )
233
+
234
+ down_block_res_samples += res_samples
235
+
236
+ # 4. mid
237
+ sample = self.mid_block(
238
+ hidden_states=sample,
239
+ temb=emb,
240
+ encoder_hidden_states=encoder_hidden_states,
241
+ image_only_indicator=image_only_indicator,
242
+ )
243
+
244
+ # 5. up
245
+ for i, upsample_block in enumerate(self.up_blocks):
246
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
247
+ down_block_res_samples = down_block_res_samples[
248
+ : -len(upsample_block.resnets)
249
+ ]
250
+
251
+ if (
252
+ hasattr(upsample_block, "has_cross_attention")
253
+ and upsample_block.has_cross_attention
254
+ ):
255
+ sample = upsample_block(
256
+ hidden_states=sample,
257
+ res_hidden_states_tuple=res_samples,
258
+ temb=emb,
259
+ encoder_hidden_states=encoder_hidden_states,
260
+ image_only_indicator=image_only_indicator,
261
+ )
262
+ else:
263
+ sample = upsample_block(
264
+ hidden_states=sample,
265
+ res_hidden_states_tuple=res_samples,
266
+ temb=emb,
267
+ image_only_indicator=image_only_indicator,
268
+ )
269
+
270
+ # 6. post-process
271
+ sample = self.conv_norm_out(sample)
272
+ sample = self.conv_act(sample)
273
+ sample = self.conv_out(sample)
274
+
275
+ # 7. Reshape back to original shape
276
+ sample = sample.reshape(batch_size, num_frames, *sample.shape[1:])
277
+
278
+ if not return_dict:
279
+ return (sample,)
280
+
281
+ return UNetSpatioTemporalConditionOutput(sample=sample)
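
The forward pass above treats the video as a flat batch of frames for the 2D parts of the UNet and restores the frame dimension only at the end. The sketch below shows just that reshaping convention on dummy tensors; the channel, embedding, and spatial sizes are made up and no UNet blocks are actually run.

import torch

batch_size, num_frames, channels, height, width = 2, 6, 8, 32, 32
sample = torch.randn(batch_size, num_frames, channels, height, width)
emb = torch.randn(batch_size, 320)                                  # per-video time embedding
encoder_hidden_states = torch.randn(batch_size, num_frames, 1024)   # per-frame conditioning

# [B, T, C, H, W] -> [B*T, C, H, W]: frames become one large batch for the spatial blocks
sample = sample.flatten(0, 1)
# repeat the per-video embedding so every frame gets its own copy: [B*T, 320]
emb = emb.repeat_interleave(num_frames, dim=0)
# [B, T, C] -> [B*T, 1, C]: one cross-attention token per frame
encoder_hidden_states = encoder_hidden_states.flatten(0, 1).unsqueeze(1)

# ... down / mid / up blocks would run here ...

# restore the frame dimension: [B*T, C, H, W] -> [B, T, C, H, W]
sample = sample.reshape(batch_size, num_frames, *sample.shape[1:])
print(sample.shape)  # torch.Size([2, 6, 8, 32, 32])
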
requirements.txt ADDED
@@ -0,0 +1,16 @@
1
+ torch==2.3.1
2
+ diffusers==0.31.0
3
+ numpy==2.0.1
4
+ matplotlib==3.9.2
5
+ transformers==4.48.0
6
+ accelerate==1.1.1
7
+ xformers==0.0.27
8
+ mediapy==1.2.2
9
+ fire==0.7.0
10
+ decord==0.6.0
11
+ OpenEXR==3.3.2
12
+ kornia==0.7.4
13
+ opencv-python==4.10.0.84
14
+ h5py==3.12.1
15
+ moderngl==5.12.0
16
+ piqp==0.4.2
third_party/__init__.py ADDED
@@ -0,0 +1,22 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import sys
4
+
5
+ sys.path.append('third_party/moge')
6
+ from .moge.moge.model.moge_model import MoGeModel
7
+
8
+ class MoGe(nn.Module):
9
+
10
+ def __init__(self, cache_dir):
11
+ super().__init__()
12
+ self.model = MoGeModel.from_pretrained(
13
+ 'Ruicheng/moge-vitl', cache_dir=cache_dir).eval()
14
+
15
+
16
+ @torch.no_grad()
17
+ def forward_image(self, image: torch.Tensor, **kwargs):
18
+ # image: b, 3, h, w 0,1
19
+ output = self.model.infer(image, resolution_level=9, apply_mask=False, **kwargs)
20
+ points = output['points'] # b,h,w,3
21
+ masks = output['mask'] # b,h,w
22
+ return points, masks
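
A hypothetical way to call this wrapper, assuming the moge submodule has been checked out, the repository root is on PYTHONPATH, and a CUDA device is available; the cache directory, batch size, and input resolution are placeholders, and the pretrained weights are downloaded from the Hugging Face Hub on first use.

import torch
from third_party import MoGe

model = MoGe(cache_dir='./checkpoints').to('cuda')
frames = torch.rand(2, 3, 384, 640, device='cuda')   # RGB frames in [0, 1], shape [b, 3, h, w]
points, masks = model.forward_image(frames)
print(points.shape, masks.shape)                      # [2, 384, 640, 3] and [2, 384, 640]
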
third_party/moge ADDED
@@ -0,0 +1 @@
1
+ Subproject commit dd158c05461f2353287a182afb2adf0fda46436f
utils/__init__.py ADDED
File without changes
utils/disp_utils.py ADDED
@@ -0,0 +1,43 @@
1
+ import torch
2
+ import matplotlib
3
+
4
+ def robust_min_max(tensor, quantile=0.99):
5
+ T, H, W = tensor.shape
6
+ min_vals = []
7
+ max_vals = []
8
+ for i in range(T):
9
+ min_vals.append(torch.quantile(tensor[i], q=1-quantile, interpolation='nearest').item())
10
+ max_vals.append(torch.quantile(tensor[i], q=quantile, interpolation='nearest').item())
11
+ return min(min_vals), max(max_vals)
12
+
13
+
14
+ class ColorMapper:
15
+ def __init__(self, colormap: str = "inferno"):
16
+ self.colormap = torch.tensor(matplotlib.colormaps[colormap].colors)
17
+
18
+ def apply(self, image: torch.Tensor, v_min=None, v_max=None):
19
+ # assert len(image.shape) == 2
20
+ if v_min is None:
21
+ v_min = image.min()
22
+ if v_max is None:
23
+ v_max = image.max()
24
+ image = (image - v_min) / (v_max - v_min)
25
+ image = (image * 255).long()
26
+ colormap = self.colormap.to(image.device)
27
+ image = colormap[image]
28
+ return image
29
+
30
+ def color_video_disp(disp):
31
+ visualizer = ColorMapper()
32
+ disp_img = visualizer.apply(disp, v_min=0, v_max=1)
33
+ return disp_img
34
+
35
+ def pmap_to_disp(point_maps, valid_masks):
36
+ disp_map = 1.0 / (point_maps[..., 2] + 1e-4)
37
+ min_disparity, max_disparity = robust_min_max(disp_map)
38
+ disp_map = torch.clamp((disp_map - min_disparity) / (max_disparity - min_disparity+1e-4), 0, 1)
39
+
40
+ disp_map = color_video_disp(disp_map)
41
+ disp_map[~valid_masks] = 0
42
+ return disp_map
43
+ # imageio.mimsave(os.path.join(args.save_dir, os.path.basename(args.data[:-4])+'_disp.mp4'), disp, fps=24, quality=9, macro_block_size=1)
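
A hypothetical end-to-end use of `pmap_to_disp`: color a point-map sequence as a disparity video and write it with mediapy (pinned in requirements.txt). The tensor sizes, random inputs, and output path are placeholders.

import numpy as np
import torch
import mediapy
from utils.disp_utils import pmap_to_disp

point_maps = torch.rand(16, 240, 320, 3) + 0.1               # [t, h, w, 3] with positive depth
valid_masks = torch.ones(16, 240, 320, dtype=torch.bool)     # [t, h, w]

disp = pmap_to_disp(point_maps, valid_masks)                 # [t, h, w, 3] colors in [0, 1]
mediapy.write_video('disp_vis.mp4', (disp.numpy() * 255).astype(np.uint8), fps=24)
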
utils/glb_utils.py ADDED
@@ -0,0 +1,19 @@
1
+ import trimesh
2
+ import numpy as np
3
+
4
+ def pmap_to_glb(point_map, valid_mask, frame) -> trimesh.Scene:
5
+
6
+
7
+ pts_3d = point_map[valid_mask] * np.array([-1, -1, 1])
8
+ pts_rgb = frame[valid_mask]
9
+
10
+ # Initialize a 3D scene
11
+ scene_3d = trimesh.Scene()
12
+
13
+ # Add point cloud data to the scene
14
+ point_cloud_data = trimesh.PointCloud(
15
+ vertices=pts_3d, colors=pts_rgb
16
+ )
17
+
18
+ scene_3d.add_geometry(point_cloud_data)
19
+ return scene_3d
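
A hypothetical use of `pmap_to_glb`: turn one frame's point map into a point-cloud scene and export it as a .glb with trimesh. The random arrays and output path stand in for a real prediction and its RGB frame.

import numpy as np
from utils.glb_utils import pmap_to_glb

h, w = 240, 320
point_map = np.random.rand(h, w, 3).astype(np.float32)      # per-pixel xyz
valid_mask = np.random.rand(h, w) > 0.2                      # boolean validity mask
frame = (np.random.rand(h, w, 3) * 255).astype(np.uint8)     # RGB colors

scene = pmap_to_glb(point_map, valid_mask, frame)
scene.export('pointcloud.glb')
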