Spaces: Running on Zero
slothfulxtx committed on
Commit · ca2145e
1 Parent(s): 774e213
init
Browse files
- .gitignore +178 -0
- .gitmodules +3 -0
- app.py +436 -0
- geometrycrafter/__init__.py +4 -0
- geometrycrafter/determ_ppl.py +453 -0
- geometrycrafter/diff_ppl.py +526 -0
- geometrycrafter/pmap_vae.py +330 -0
- geometrycrafter/unet.py +281 -0
- requirements.txt +16 -0
- third_party/__init__.py +22 -0
- third_party/moge +1 -0
- utils/__init__.py +0 -0
- utils/disp_utils.py +43 -0
- utils/glb_utils.py +19 -0
.gitignore
ADDED
@@ -0,0 +1,178 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

# Vscode settings
.vscode/

# Temporal files
/tmp
tmp*

# Workspace
/workspace

# running scripts
/*.sh

# pretrained
/pretrained_models
.gitmodules
ADDED
@@ -0,0 +1,3 @@
[submodule "third_party/moge"]
    path = third_party/moge
    url = https://github.com/microsoft/MoGe.git
app.py
ADDED
@@ -0,0 +1,436 @@
import gc
import os
import uuid
from pathlib import Path

import numpy as np
import spaces
import gradio as gr
import torch
from decord import cpu, VideoReader
from diffusers.training_utils import set_seed
import torch.nn.functional as F
import imageio
from kornia.filters import canny
from kornia.morphology import dilation

from third_party import MoGe
from geometrycrafter import (
    GeometryCrafterDiffPipeline,
    GeometryCrafterDetermPipeline,
    PMapAutoencoderKLTemporalDecoder,
    UNetSpatioTemporalConditionModelVid2vid
)

from utils.glb_utils import pmap_to_glb
from utils.disp_utils import pmap_to_disp

examples = [
    # process_length: int,
    # max_res: int,
    # num_inference_steps: int,
    # guidance_scale: float,
    # window_size: int,
    # decode_chunk_size: int,
    # overlap: int,
    ["examples/video1.mp4", 60, 640, 5, 1.0, 110, 8, 25],
    ["examples/video2.mp4", 60, 640, 5, 1.0, 110, 8, 25],
    ["examples/video3.mp4", 60, 640, 5, 1.0, 110, 8, 25],
    ["examples/video4.mp4", 60, 640, 5, 1.0, 110, 8, 25],
]

model_type = 'diff'
cache_dir = 'workspace/cache'

unet = UNetSpatioTemporalConditionModelVid2vid.from_pretrained(
    'TencentARC/GeometryCrafter',
    subfolder='unet_diff' if model_type == 'diff' else 'unet_determ',
    low_cpu_mem_usage=True,
    torch_dtype=torch.float16,
    cache_dir=cache_dir
).requires_grad_(False).to("cuda", dtype=torch.float16)
point_map_vae = PMapAutoencoderKLTemporalDecoder.from_pretrained(
    'TencentARC/GeometryCrafter',
    subfolder='point_map_vae',
    low_cpu_mem_usage=True,
    torch_dtype=torch.float32,
    cache_dir=cache_dir
).requires_grad_(False).to("cuda", dtype=torch.float32)
prior_model = MoGe(
    cache_dir=cache_dir,
).requires_grad_(False).to('cuda', dtype=torch.float32)
if model_type == 'diff':
    pipe = GeometryCrafterDiffPipeline.from_pretrained(
        'stabilityai/stable-video-diffusion-img2vid-xt',
        unet=unet,
        torch_dtype=torch.float16,
        variant="fp16",
        cache_dir=cache_dir
    ).to("cuda")
else:
    pipe = GeometryCrafterDetermPipeline.from_pretrained(
        'stabilityai/stable-video-diffusion-img2vid-xt',
        unet=unet,
        torch_dtype=torch.float16,
        variant="fp16",
        cache_dir=cache_dir
    ).to("cuda")

try:
    pipe.enable_xformers_memory_efficient_attention()
except Exception as e:
    print(e)
    print("Xformers is not enabled")
    # bugs at https://github.com/continue-revolution/sd-webui-animatediff/issues/101
    # pipe.enable_xformers_memory_efficient_attention()
pipe.enable_attention_slicing()

mesh_seqs = []
frame_seqs = []
cur_mesh_idx = None

def read_video_frames(video_path, process_length, max_res):
    print("==> processing video: ", video_path)
    vid = VideoReader(video_path, ctx=cpu(0))
    fps = vid.get_avg_fps()
    print("==> original video shape: ", (len(vid), *vid.get_batch([0]).shape[1:]))
    original_height, original_width = vid.get_batch([0]).shape[1:3]
    if max(original_height, original_width) > max_res:
        scale = max_res / max(original_height, original_width)
        original_height, original_width = round(original_height * scale), round(original_width * scale)
    else:
        scale = 1.0
    height = round(original_height * scale / 64) * 64
    width = round(original_width * scale / 64) * 64
    vid = VideoReader(video_path, ctx=cpu(0), width=original_width, height=original_height)
    frames_idx = list(range(0, min(len(vid), process_length) if process_length != -1 else len(vid)))
    print(
        f"==> final processing shape: {len(frames_idx), *vid.get_batch([0]).shape[1:]}"
    )
    frames = vid.get_batch(frames_idx).asnumpy().astype("float32") / 255.0
    return frames, height, width, fps


def compute_edge_mask(depth: torch.Tensor, edge_dilation_radius: int):
    magnitude, edges = canny(depth[None, None, :, :], low_threshold=0.4, high_threshold=0.5)
    magnitude = magnitude[0, 0]
    edges = edges[0, 0]
    mask = (edges > 0).float()
    mask = dilation(mask[None, None, :, :], torch.ones((edge_dilation_radius, edge_dilation_radius), device=mask.device))
    return mask[0, 0] > 0.5

@spaces.GPU(duration=120)
@torch.inference_mode()
def infer_geometry(
    video: str,
    process_length: int,
    max_res: int,
    num_inference_steps: int,
    guidance_scale: float,
    window_size: int,
    decode_chunk_size: int,
    overlap: int,
    downsample_ratio: float = 1.0,  # downsample pcd for visualization
    num_sample_frames: int = 8,  # downsample frames for visualization
    remove_edge: bool = True,  # remove edge for visualization
    save_folder: str = os.path.join('workspace', 'GeometryCrafterApp'),
):
    try:
        global cur_mesh_idx, mesh_seqs, frame_seqs
        run_id = str(uuid.uuid4())
        set_seed(42)
        pipe.enable_xformers_memory_efficient_attention()

        frames, height, width, fps = read_video_frames(video, process_length, max_res)
        aspect_ratio = width / height
        assert 0.5 <= aspect_ratio and aspect_ratio <= 2.0
        frames_tensor = torch.tensor(frames.astype("float32"), device='cuda').float().permute(0, 3, 1, 2)
        window_size = min(window_size, len(frames))
        if window_size == len(frames):
            overlap = 0

        point_maps, valid_masks = pipe(
            frames_tensor,
            point_map_vae,
            prior_model,
            height=height,
            width=width,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            window_size=window_size,
            decode_chunk_size=decode_chunk_size,
            overlap=overlap,
            force_projection=True,
            force_fixed_focal=True,
        )
        frames_tensor = frames_tensor.cpu()
        point_maps = point_maps.cpu()
        valid_masks = valid_masks.cpu()

        gc.collect()
        torch.cuda.empty_cache()
        output_npz_path = Path(save_folder, run_id, f'point_maps.npz')
        output_npz_path.parent.mkdir(exist_ok=True)

        np.savez_compressed(
            output_npz_path,
            point_map=point_maps.cpu().numpy().astype(np.float16),
            valid_mask=valid_masks.cpu().numpy().astype(np.bool_)
        )

        output_disp_path = Path(save_folder, run_id, f'disp.mp4')
        output_disp_path.parent.mkdir(exist_ok=True)

        colored_disp = pmap_to_disp(point_maps, valid_masks)
        imageio.mimsave(
            output_disp_path, (colored_disp*255).cpu().numpy().astype(np.uint8), fps=fps, macro_block_size=1)

        # downsample for visualization
        if downsample_ratio > 1.0:
            H, W = point_maps.shape[1:3]
            H, W = round(H / downsample_ratio), round(W / downsample_ratio)
            point_maps = F.interpolate(point_maps.permute(0,3,1,2), (H, W)).permute(0,2,3,1)
            frames = F.interpolate(frames_tensor, (H, W)).permute(0,2,3,1)
            valid_masks = F.interpolate(valid_masks.float()[:, None], (H, W))[:, 0] > 0.5
        else:
            H, W = point_maps.shape[1:3]
            frames = frames_tensor.permute(0,2,3,1)

        if remove_edge:
            for i in range(len(valid_masks)):
                edge_mask = compute_edge_mask(point_maps[i, :, :, 2], 3)
                valid_masks[i] = valid_masks[i] & (~edge_mask)

        indices = np.linspace(0, len(point_maps)-1, num_sample_frames)
        indices = np.round(indices).astype(np.int32)

        mesh_seqs.clear()
        cur_mesh_idx = None

        for index in indices:

            valid_mask = valid_masks[index].cpu().numpy()
            point_map = point_maps[index].cpu().numpy()
            frame = frames[index].cpu().numpy()
            output_glb_path = Path(save_folder, run_id, f'{index:04}.glb')
            output_glb_path.parent.mkdir(exist_ok=True)
            glbscene = pmap_to_glb(point_map, valid_mask, frame)
            glbscene.export(file_obj=output_glb_path)
            mesh_seqs.append(output_glb_path)
            frame_seqs.append(index)

        cur_mesh_idx = 0

        gc.collect()
        torch.cuda.empty_cache()

        return [
            gr.Model3D(value=mesh_seqs[cur_mesh_idx], label=f"Frame: {frame_seqs[cur_mesh_idx]}"),
            gr.Video(value=output_disp_path, label="Disparity", interactive=False),
            gr.DownloadButton("Download Npz File", value=output_npz_path, visible=True)
        ]
    except Exception as e:
        mesh_seqs.clear()
        frame_seqs.clear()
        cur_mesh_idx = None
        gc.collect()
        torch.cuda.empty_cache()
        raise gr.Error(str(e))
        # return [
        #     gr.Model3D(
        #         label="Point Map",
        #         clear_color=[1.0, 1.0, 1.0, 1.0],
        #         interactive=False
        #     ),
        #     gr.Video(label="Disparity", interactive=False),
        #     gr.DownloadButton("Download Npz File", visible=False)
        # ]

def goto_prev_frame():
    global cur_mesh_idx, mesh_seqs, frame_seqs
    if cur_mesh_idx is not None and len(mesh_seqs) > 0:
        if cur_mesh_idx > 0:
            cur_mesh_idx -= 1
        return gr.Model3D(value=mesh_seqs[cur_mesh_idx], label=f"Frame: {frame_seqs[cur_mesh_idx]}")


def goto_next_frame():
    global cur_mesh_idx, mesh_seqs, frame_seqs
    if cur_mesh_idx is not None and len(mesh_seqs) > 0:
        if cur_mesh_idx < len(mesh_seqs)-1:
            cur_mesh_idx += 1
        return gr.Model3D(value=mesh_seqs[cur_mesh_idx], label=f"Frame: {frame_seqs[cur_mesh_idx]}")

def download_file():
    return gr.DownloadButton(visible=False)

def build_demo():
    with gr.Blocks(analytics_enabled=False) as gradio_demo:
        gr.Markdown(
            """
            <div align='center'>
            <h1> GeometryCrafter: Consistent Geometry Estimation for Open-world Videos with Diffusion Priors </h1> \
            <h2 style='font-weight: 450; font-size: 1rem; margin: 0rem'>\
            <a href='https://scholar.google.com/citations?user=zHp0rMIAAAAJ'>Tian-Xing Xu</a>, \
            <a href='https://scholar.google.com/citations?user=qgdesEcAAAAJ'>Xiangjun Gao</a>, \
            <a href='https://wbhu.github.io'>Wenbo Hu</a>, \
            <a href='https://xiaoyu258.github.io/'>Xiaoyu Li</a>, \
            <a href='https://scholar.google.com/citations?user=AWtV-EQAAAAJ'>Song-Hai Zhang</a>,\
            <a href='https://scholar.google.com/citations?user=4oXBp9UAAAAJ'>Ying Shan</a>\
            </h2> \
            <span style='font-size:18px'>If you find GeometryCrafter useful, please help ⭐ the \
            <a style='font-size:18px' href='https://github.com/TencentARC/GeometryCrafter/'>[Github Repo]</a>\
            , which is important to Open-Source projects. Thanks!\
            <a style='font-size:18px' href='https://arxiv.org'> [ArXivTODO] </a>\
            <a style='font-size:18px' href='https://geometrycrafter.github.io'> [Project Page] </a>
            </span>
            </div>
            """
        )

        with gr.Row(equal_height=True):
            with gr.Column(scale=1):
                input_video = gr.Video(
                    label="Input Video",
                    sources=['upload']
                )
                with gr.Row(equal_height=False):
                    with gr.Accordion("Advanced Settings", open=False):
                        process_length = gr.Slider(
                            label="process length",
                            minimum=-1,
                            maximum=280,
                            value=110,
                            step=1,
                        )
                        max_res = gr.Slider(
                            label="max resolution",
                            minimum=512,
                            maximum=2048,
                            value=1024,
                            step=64,
                        )
                        num_denoising_steps = gr.Slider(
                            label="num denoising steps",
                            minimum=1,
                            maximum=25,
                            value=5,
                            step=1,
                        )
                        guidance_scale = gr.Slider(
                            label="cfg scale",
                            minimum=1.0,
                            maximum=1.2,
                            value=1.0,
                            step=0.1,
                        )
                        window_size = gr.Slider(
                            label="shift window size",
                            minimum=10,
                            maximum=110,
                            value=110,
                            step=10,
                        )
                        decode_chunk_size = gr.Slider(
                            label="decode chunk size",
                            minimum=1,
                            maximum=16,
                            value=6,
                            step=1,
                        )
                        overlap = gr.Slider(
                            label="overlap",
                            minimum=1,
                            maximum=50,
                            value=25,
                            step=1,
                        )
                generate_btn = gr.Button("Generate")

            with gr.Column(scale=1):
                output_point_maps = gr.Model3D(
                    label="Point Map",
                    clear_color=[1.0, 1.0, 1.0, 1.0],
                    # display_mode="solid"
                    interactive=False
                )
                with gr.Row():
                    prev_btn = gr.Button("Prev")
                    next_btn = gr.Button("Next")

            with gr.Column(scale=1):
                output_disp_video = gr.Video(
                    label="Disparity",
                    interactive=False
                )
                download_btn = gr.DownloadButton("Download Npz File", visible=False)

        gr.Examples(
            examples=examples,
            fn=infer_geometry,
            inputs=[
                input_video,
                process_length,
                max_res,
                num_denoising_steps,
                guidance_scale,
                window_size,
                decode_chunk_size,
                overlap,
            ],
            outputs=[output_point_maps, output_disp_video, download_btn],
            # cache_examples="lazy",
        )
        gr.Markdown(
            """
            <span style='font-size:18px'>Note:
            For time quota consideration, we set the default parameters to be more efficient here,
            with a trade-off of shorter video length and slightly lower quality.
            You may adjust the parameters according to our
            <a style='font-size:18px' href='https://github.com/TencentARC/GeometryCrafter/'>[Github Repo]</a>
            for better results if you have enough time quota. We only provide a simplified visualization
            script in this page due to the lack of support for point cloud sequences. You can download
            the npz file and open it with Viser backend in our repo for better visualization.
            </span>
            """
        )

        generate_btn.click(
            fn=infer_geometry,
            inputs=[
                input_video,
                process_length,
                max_res,
                num_denoising_steps,
                guidance_scale,
                window_size,
                decode_chunk_size,
                overlap,
            ],
            outputs=[output_point_maps, output_disp_video, download_btn],
        )

        prev_btn.click(
            fn=goto_prev_frame,
            outputs=output_point_maps,
        )
        next_btn.click(
            fn=goto_next_frame,
            outputs=output_point_maps,
        )
        download_btn.click(
            fn=download_file,
            outputs=download_btn
        )

    return gradio_demo


if __name__ == "__main__":
    demo = build_demo()
    demo.queue()
    demo.launch(server_name="0.0.0.0", server_port=12345, debug=True, share=False)
    # demo.launch(share=True)
geometrycrafter/__init__.py
ADDED
@@ -0,0 +1,4 @@
from .pmap_vae import PMapAutoencoderKLTemporalDecoder
from .unet import UNetSpatioTemporalConditionModelVid2vid
from .diff_ppl import GeometryCrafterDiffPipeline
from .determ_ppl import GeometryCrafterDetermPipeline
geometrycrafter/determ_ppl.py
ADDED
@@ -0,0 +1,453 @@
from typing import Optional, Union
import gc

import numpy as np
import torch
import torch.nn.functional as F

from diffusers.pipelines.stable_video_diffusion.pipeline_stable_video_diffusion import (
    _resize_with_antialiasing,
    StableVideoDiffusionPipeline,
)
from diffusers.utils import logging
from kornia.utils import create_meshgrid
from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

@torch.no_grad()
def normalize_point_map(point_map, valid_mask):
    # T,H,W,3 T,H,W
    norm_factor = (point_map[..., 2] * valid_mask.float()).mean() / (valid_mask.float().mean() + 1e-8)
    norm_factor = norm_factor.clip(min=1e-3)
    return point_map / norm_factor

def point_map_xy2intrinsic_map(point_map_xy):
    # *,h,w,2
    height, width = point_map_xy.shape[-3], point_map_xy.shape[-2]
    assert height % 2 == 0
    assert width % 2 == 0
    mesh_grid = create_meshgrid(
        height=height,
        width=width,
        normalized_coordinates=True,
        device=point_map_xy.device,
        dtype=point_map_xy.dtype
    )[0]  # h,w,2
    assert mesh_grid.abs().min() > 1e-4
    # *,h,w,2
    mesh_grid = mesh_grid.expand_as(point_map_xy)
    nc = point_map_xy.mean(dim=-2).mean(dim=-2)  # *, 2
    nc_map = nc[..., None, None, :].expand_as(point_map_xy)
    nf = ((point_map_xy - nc_map) / mesh_grid).mean(dim=-2).mean(dim=-2)
    nf_map = nf[..., None, None, :].expand_as(point_map_xy)
    # print((mesh_grid * nf_map + nc_map - point_map_xy).abs().max())

    return torch.cat([nc_map, nf_map], dim=-1)

def robust_min_max(tensor, quantile=0.99):
    T, H, W = tensor.shape
    min_vals = []
    max_vals = []
    for i in range(T):
        min_vals.append(torch.quantile(tensor[i], q=1-quantile, interpolation='nearest').item())
        max_vals.append(torch.quantile(tensor[i], q=quantile, interpolation='nearest').item())
    return min(min_vals), max(max_vals)

class GeometryCrafterDetermPipeline(StableVideoDiffusionPipeline):

    @torch.inference_mode()
    def encode_video(
        self,
        video: torch.Tensor,
        chunk_size: int = 14,
    ) -> torch.Tensor:
        """
        :param video: [b, c, h, w] in range [-1, 1], the b may contain multiple videos or frames
        :param chunk_size: the chunk size to encode video
        :return: image_embeddings in shape of [b, 1024]
        """

        video_224 = _resize_with_antialiasing(video.float(), (224, 224))
        video_224 = (video_224 + 1.0) / 2.0  # [-1, 1] -> [0, 1]
        embeddings = []
        for i in range(0, video_224.shape[0], chunk_size):
            emb = self.feature_extractor(
                images=video_224[i : i + chunk_size],
                do_normalize=True,
                do_center_crop=False,
                do_resize=False,
                do_rescale=False,
                return_tensors="pt",
            ).pixel_values.to(video.device, dtype=video.dtype)
            embeddings.append(self.image_encoder(emb).image_embeds)  # [b, 1024]

        embeddings = torch.cat(embeddings, dim=0)  # [t, 1024]
        return embeddings

    @torch.inference_mode()
    def encode_vae_video(
        self,
        video: torch.Tensor,
        chunk_size: int = 14,
    ):
        """
        :param video: [b, c, h, w] in range [-1, 1], the b may contain multiple videos or frames
        :param chunk_size: the chunk size to encode video
        :return: vae latents in shape of [b, c, h, w]
        """
        video_latents = []
        for i in range(0, video.shape[0], chunk_size):
            video_latents.append(
                self.vae.encode(video[i : i + chunk_size]).latent_dist.mode()
            )
        video_latents = torch.cat(video_latents, dim=0)
        return video_latents


    @torch.inference_mode()
    def produce_priors(self, prior_model, frame, chunk_size=8):
        T, _, H, W = frame.shape
        frame = (frame + 1) / 2
        pred_point_maps = []
        pred_masks = []
        for i in range(0, len(frame), chunk_size):
            pred_p, pred_m = prior_model.forward_image(frame[i:i+chunk_size])
            pred_point_maps.append(pred_p)
            pred_masks.append(pred_m)
        pred_point_maps = torch.cat(pred_point_maps, dim=0)
        pred_masks = torch.cat(pred_masks, dim=0)

        pred_masks = pred_masks.float() * 2 - 1

        # T,H,W,3 T,H,W
        pred_point_maps = normalize_point_map(pred_point_maps, pred_masks > 0)

        pred_disps = 1.0 / pred_point_maps[..., 2].clamp_min(1e-3)
        pred_disps = pred_disps * (pred_masks > 0)
        min_disparity, max_disparity = robust_min_max(pred_disps)
        pred_disps = ((pred_disps - min_disparity) / (max_disparity - min_disparity+1e-4)).clamp(0, 1)
        pred_disps = pred_disps * 2 - 1

        pred_point_maps[..., :2] = pred_point_maps[..., :2] / (pred_point_maps[..., 2:3] + 1e-7)
        pred_point_maps[..., 2] = torch.log(pred_point_maps[..., 2] + 1e-7) * (pred_masks > 0)  # [x/z, y/z, log(z)]

        pred_intr_maps = point_map_xy2intrinsic_map(pred_point_maps[..., :2]).permute(0,3,1,2)  # T,H,W,2
        pred_point_maps = pred_point_maps.permute(0,3,1,2)

        return pred_disps, pred_masks, pred_point_maps, pred_intr_maps

    @torch.inference_mode()
    def encode_point_map(self, point_map_vae, disparity, valid_mask, point_map, intrinsic_map, chunk_size=8):
        T, _, H, W = point_map.shape
        latents = []

        psedo_image = disparity[:, None].repeat(1,3,1,1)
        intrinsic_map = torch.norm(intrinsic_map[:, 2:4], p=2, dim=1, keepdim=False)

        for i in range(0, T, chunk_size):
            latent_dist = self.vae.encode(psedo_image[i : i + chunk_size].to(self.vae.dtype)).latent_dist
            latent_dist = point_map_vae.encode(
                torch.cat([
                    intrinsic_map[i:i+chunk_size, None],
                    point_map[i:i+chunk_size, 2:3],
                    disparity[i:i+chunk_size, None],
                    valid_mask[i:i+chunk_size, None]], dim=1),
                latent_dist
            )
            if isinstance(latent_dist, DiagonalGaussianDistribution):
                latent = latent_dist.mode()
            else:
                latent = latent_dist

            assert isinstance(latent, torch.Tensor)
            latents.append(latent)
        latents = torch.cat(latents, dim=0)
        latents = latents * self.vae.config.scaling_factor
        return latents

    @torch.no_grad()
    def decode_point_map(self, point_map_vae, latents, chunk_size=8, force_projection=True, force_fixed_focal=True, use_extract_interp=False, need_resize=False, height=None, width=None):
        T = latents.shape[0]
        rec_intrinsic_maps = []
        rec_depth_maps = []
        rec_valid_masks = []
        for i in range(0, T, chunk_size):
            lat = latents[i:i+chunk_size]
            rec_imap, rec_dmap, rec_vmask = point_map_vae.decode(
                lat,
                num_frames=lat.shape[0],
            )
            rec_intrinsic_maps.append(rec_imap)
            rec_depth_maps.append(rec_dmap)
            rec_valid_masks.append(rec_vmask)

        rec_intrinsic_maps = torch.cat(rec_intrinsic_maps, dim=0)
        rec_depth_maps = torch.cat(rec_depth_maps, dim=0)
        rec_valid_masks = torch.cat(rec_valid_masks, dim=0)

        if need_resize:
            rec_depth_maps = F.interpolate(rec_depth_maps, (height, width), mode='nearest-exact') if use_extract_interp else F.interpolate(rec_depth_maps, (height, width), mode='bilinear', align_corners=False)
            rec_valid_masks = F.interpolate(rec_valid_masks, (height, width), mode='nearest-exact') if use_extract_interp else F.interpolate(rec_valid_masks, (height, width), mode='bilinear', align_corners=False)
            rec_intrinsic_maps = F.interpolate(rec_intrinsic_maps, (height, width), mode='bilinear', align_corners=False)

        H, W = rec_intrinsic_maps.shape[-2], rec_intrinsic_maps.shape[-1]
        mesh_grid = create_meshgrid(
            H, W,
            normalized_coordinates=True
        ).to(rec_intrinsic_maps.device, rec_intrinsic_maps.dtype, non_blocking=True)
        # 1,h,w,2
        rec_intrinsic_maps = torch.cat([rec_intrinsic_maps * W / np.sqrt(W**2+H**2), rec_intrinsic_maps * H / np.sqrt(W**2+H**2)], dim=1)  # t,2,h,w
        mesh_grid = mesh_grid.permute(0,3,1,2)
        rec_valid_masks = rec_valid_masks.squeeze(1) > 0

        if force_projection:
            if force_fixed_focal:
                nfx = (rec_intrinsic_maps[:, 0, :, :] * rec_valid_masks.float()).mean() / (rec_valid_masks.float().mean() + 1e-4)
                nfy = (rec_intrinsic_maps[:, 1, :, :] * rec_valid_masks.float()).mean() / (rec_valid_masks.float().mean() + 1e-4)
                rec_intrinsic_maps = torch.tensor([nfx, nfy], device=rec_intrinsic_maps.device)[None, :, None, None].repeat(T, 1, 1, 1)
            else:
                nfx = (rec_intrinsic_maps[:, 0, :, :] * rec_valid_masks.float()).mean(dim=[-1, -2]) / (rec_valid_masks.float().mean(dim=[-1, -2]) + 1e-4)
                nfy = (rec_intrinsic_maps[:, 1, :, :] * rec_valid_masks.float()).mean(dim=[-1, -2]) / (rec_valid_masks.float().mean(dim=[-1, -2]) + 1e-4)
                rec_intrinsic_maps = torch.stack([nfx, nfy], dim=-1)[:, :, None, None]
            # t,2,1,1

        rec_point_maps = torch.cat([rec_intrinsic_maps * mesh_grid, rec_depth_maps], dim=1).permute(0,2,3,1)
        xy, z = rec_point_maps.split([2, 1], dim=-1)
        z = torch.clamp_max(z, 10)  # for numerical stability
        z = torch.exp(z)
        rec_point_maps = torch.cat([xy * z, z], dim=-1)

        return rec_point_maps, rec_valid_masks


    @torch.no_grad()
    def __call__(
        self,
        video: Union[np.ndarray, torch.Tensor],
        point_map_vae,
        prior_model,
        height: int = 576,
        width: int = 1024,
        window_size: Optional[int] = 14,
        noise_aug_strength: float = 0.02,
        decode_chunk_size: Optional[int] = None,
        overlap: int = 4,
        force_projection: bool = True,
        force_fixed_focal: bool = True,
        use_extract_interp: bool = False,
        track_time: bool = False,
        **kwargs
    ):
        # video: in shape [t, h, w, c] if np.ndarray or [t, c, h, w] if torch.Tensor, in range [0, 1]

        # 0. Define height and width for preprocessing

        if isinstance(video, np.ndarray):
            video = torch.from_numpy(video.transpose(0, 3, 1, 2))
        else:
            assert isinstance(video, torch.Tensor)

        height = height or video.shape[-2]
        width = width or video.shape[-1]
        original_height = video.shape[-2]
        original_width = video.shape[-1]
        num_frames = video.shape[0]
        decode_chunk_size = decode_chunk_size if decode_chunk_size is not None else 8
        if num_frames <= window_size:
            window_size = num_frames
            overlap = 0
        stride = window_size - overlap

        # 1. Check inputs. Raise error if not correct
        assert height % 64 == 0 and width % 64 == 0
        if original_height != height or original_width != width:
            need_resize = True
        else:
            need_resize = False

        # 2. Define call parameters
        batch_size = 1
        device = self._execution_device
        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        self._guidance_scale = 1.0

        if track_time:
            start_event = torch.cuda.Event(enable_timing=True)
            prior_event = torch.cuda.Event(enable_timing=True)
            encode_event = torch.cuda.Event(enable_timing=True)
            denoise_event = torch.cuda.Event(enable_timing=True)
            decode_event = torch.cuda.Event(enable_timing=True)
            start_event.record()

        # 3. Compute prior latents under original resolutions
        pred_disparity, pred_valid_mask, pred_point_map, pred_intrinsic_map = self.produce_priors(
            prior_model,
            video.to(device=device, dtype=torch.float32),
            chunk_size=decode_chunk_size
        )  # T,H,W T,H,W T,3,H,W T,2,H,W

        if need_resize:
            pred_disparity = F.interpolate(pred_disparity.unsqueeze(1), (height, width), mode='bilinear', align_corners=False).squeeze(1)
            pred_valid_mask = F.interpolate(pred_valid_mask.unsqueeze(1), (height, width), mode='bilinear', align_corners=False).squeeze(1)
            pred_point_map = F.interpolate(pred_point_map, (height, width), mode='bilinear', align_corners=False)
            pred_intrinsic_map = F.interpolate(pred_intrinsic_map, (height, width), mode='bilinear', align_corners=False)

        if track_time:
            prior_event.record()
            torch.cuda.synchronize()
            elapsed_time_ms = start_event.elapsed_time(prior_event)
            print(f"Elapsed time for computing per-frame prior: {elapsed_time_ms} ms")
        else:
            gc.collect()
            torch.cuda.empty_cache()



        # 3. Encode input video
        if need_resize:
            video = F.interpolate(video, (height, width), mode="bicubic", align_corners=False, antialias=True).clamp(0, 1)

        video = video.to(device=device, dtype=self.dtype)
        video = video * 2.0 - 1.0  # [0,1] -> [-1,1], in [t, c, h, w]


        video_embeddings = self.encode_video(video, chunk_size=decode_chunk_size).unsqueeze(0)

        prior_latents = self.encode_point_map(
            point_map_vae,
            pred_disparity,
            pred_valid_mask,
            pred_point_map,
            pred_intrinsic_map,
            chunk_size=decode_chunk_size
        ).unsqueeze(0).to(video_embeddings.dtype)  # 1,T,C,H,W


        # 4. Encode input image using VAE

        # pdb.set_trace()
        needs_upcasting = (
            self.vae.dtype == torch.float16 and self.vae.config.force_upcast
        )
        if needs_upcasting:
            self.vae.to(dtype=torch.float32)

        video_latents = self.encode_vae_video(
            video.to(self.vae.dtype),
            chunk_size=decode_chunk_size,
        ).unsqueeze(0).to(video_embeddings.dtype)  # [1, t, c, h, w]


        if track_time:
            encode_event.record()
            torch.cuda.synchronize()
            elapsed_time_ms = prior_event.elapsed_time(encode_event)
            print(f"Elapsed time for encode prior and frames: {elapsed_time_ms} ms")
        else:
            gc.collect()
            torch.cuda.empty_cache()

        # cast back to fp16 if needed
        if needs_upcasting:
            self.vae.to(dtype=torch.float16)

        # 5. Get Added Time IDs
        added_time_ids = self._get_add_time_ids(
            7,
            127,
            noise_aug_strength,
            video_embeddings.dtype,
            batch_size,
            1,
            False,
        )  # [1 or 2, 3]
        added_time_ids = added_time_ids.to(device)

        # 6. Prepare timesteps
        timestep = 1.6378
        self._num_timesteps = 1

        # 7. Prepare latent variables
        num_channels_latents = self.unet.config.in_channels
        latents_init = prior_latents  # [1, t, c, h, w]
        latents_all = None

        idx_start = 0
        if overlap > 0:
            weights = torch.linspace(0, 1, overlap, device=device)
            weights = weights.view(1, overlap, 1, 1, 1)
        else:
            weights = None

        while idx_start < num_frames - overlap:
            idx_end = min(idx_start + window_size, num_frames)
            # 9. Denoising loop
            # latents_init = latents_init.flip(1)
            latents = latents_init[:, idx_start:idx_end]
            video_latents_current = video_latents[:, idx_start:idx_end]
            video_embeddings_current = video_embeddings[:, idx_start:idx_end]

            latent_model_input = torch.cat(
                [latents, video_latents_current], dim=2
            )

            model_pred = self.unet(
                latent_model_input,
                timestep,
                encoder_hidden_states=video_embeddings_current,
                added_time_ids=added_time_ids,
                return_dict=False,
            )[0]

            c_out = -1
            latents = model_pred * c_out

            if latents_all is None:
                latents_all = latents.clone()
            else:
                if overlap > 0:
                    latents_all[:, -overlap:] = latents[
                        :, :overlap
                    ] * weights + latents_all[:, -overlap:] * (1 - weights)
                latents_all = torch.cat([latents_all, latents[:, overlap:]], dim=1)

            idx_start += stride

        latents_all = 1 / self.vae.config.scaling_factor * latents_all.squeeze(0).to(torch.float32)

        if track_time:
            denoise_event.record()
            torch.cuda.synchronize()
            elapsed_time_ms = encode_event.elapsed_time(denoise_event)
            print(f"Elapsed time for denoise latent: {elapsed_time_ms} ms")
        else:
            gc.collect()
            torch.cuda.empty_cache()

        point_map, valid_mask = self.decode_point_map(
            point_map_vae,
            latents_all,
            chunk_size=decode_chunk_size,
            force_projection=force_projection,
            force_fixed_focal=force_fixed_focal,
            use_extract_interp=use_extract_interp,
            need_resize=need_resize,
            height=original_height,
            width=original_width)

        if track_time:
            decode_event.record()
            torch.cuda.synchronize()
            elapsed_time_ms = denoise_event.elapsed_time(decode_event)
            print(f"Elapsed time for decode latent: {elapsed_time_ms} ms")
        else:
            gc.collect()
            torch.cuda.empty_cache()

        self.maybe_free_model_hooks()
        # t,h,w,3 t,h,w
        return point_map, valid_mask
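The __call__ method above processes long videos in sliding windows that advance by stride = window_size - overlap, cross-fading the latents of overlapping frames with a linear ramp. A standalone sketch of that blending step, with hypothetical tensor shapes chosen only for illustration:

import torch

def blend_windows(latents_all, latents_new, overlap):
    # latents_all, latents_new: [1, T_win, C, H, W]; linear cross-fade over the shared frames,
    # matching the weights = torch.linspace(0, 1, overlap) ramp used in __call__ above.
    weights = torch.linspace(0, 1, overlap).view(1, overlap, 1, 1, 1)
    latents_all[:, -overlap:] = latents_new[:, :overlap] * weights + latents_all[:, -overlap:] * (1 - weights)
    return torch.cat([latents_all, latents_new[:, overlap:]], dim=1)

merged = blend_windows(torch.randn(1, 110, 4, 40, 72), torch.randn(1, 110, 4, 40, 72), overlap=25)
print(merged.shape)  # torch.Size([1, 195, 4, 40, 72])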
geometrycrafter/diff_ppl.py
ADDED
@@ -0,0 +1,526 @@
1 |
+
from typing import Callable, Dict, List, Optional, Union
|
2 |
+
import gc
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
import torch.nn.functional as F
|
7 |
+
|
8 |
+
from diffusers.pipelines.stable_video_diffusion.pipeline_stable_video_diffusion import (
|
9 |
+
_resize_with_antialiasing,
|
10 |
+
StableVideoDiffusionPipeline,
|
11 |
+
retrieve_timesteps,
|
12 |
+
)
|
13 |
+
from diffusers.utils import logging
|
14 |
+
from kornia.utils import create_meshgrid
|
15 |
+
from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
|
16 |
+
|
17 |
+
|
18 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
19 |
+
|
20 |
+
@torch.no_grad()
|
21 |
+
def normalize_point_map(point_map, valid_mask):
|
22 |
+
# T,H,W,3 T,H,W
|
23 |
+
norm_factor = (point_map[..., 2] * valid_mask.float()).mean() / (valid_mask.float().mean() + 1e-8)
|
24 |
+
norm_factor = norm_factor.clip(min=1e-3)
|
25 |
+
return point_map / norm_factor
|
26 |
+
|
27 |
+
def point_map_xy2intrinsic_map(point_map_xy):
|
28 |
+
# *,h,w,2
|
29 |
+
height, width = point_map_xy.shape[-3], point_map_xy.shape[-2]
|
30 |
+
assert height % 2 == 0
|
31 |
+
assert width % 2 == 0
|
32 |
+
mesh_grid = create_meshgrid(
|
33 |
+
height=height,
|
34 |
+
width=width,
|
35 |
+
normalized_coordinates=True,
|
36 |
+
device=point_map_xy.device,
|
37 |
+
dtype=point_map_xy.dtype
|
38 |
+
)[0] # h,w,2
|
39 |
+
assert mesh_grid.abs().min() > 1e-4
|
40 |
+
# *,h,w,2
|
41 |
+
mesh_grid = mesh_grid.expand_as(point_map_xy)
|
42 |
+
nc = point_map_xy.mean(dim=-2).mean(dim=-2) # *, 2
|
43 |
+
nc_map = nc[..., None, None, :].expand_as(point_map_xy)
|
44 |
+
nf = ((point_map_xy - nc_map) / mesh_grid).mean(dim=-2).mean(dim=-2)
|
45 |
+
nf_map = nf[..., None, None, :].expand_as(point_map_xy)
|
46 |
+
# print((mesh_grid * nf_map + nc_map - point_map_xy).abs().max())
|
47 |
+
|
48 |
+
return torch.cat([nc_map, nf_map], dim=-1)
|
49 |
+
|
50 |
+
def robust_min_max(tensor, quantile=0.99):
|
51 |
+
T, H, W = tensor.shape
|
52 |
+
min_vals = []
|
53 |
+
max_vals = []
|
54 |
+
for i in range(T):
|
55 |
+
min_vals.append(torch.quantile(tensor[i], q=1-quantile, interpolation='nearest').item())
|
56 |
+
max_vals.append(torch.quantile(tensor[i], q=quantile, interpolation='nearest').item())
|
57 |
+
return min(min_vals), max(max_vals)
|
58 |
+
|
59 |
+
class GeometryCrafterDiffPipeline(StableVideoDiffusionPipeline):
|
60 |
+
|
61 |
+
@torch.inference_mode()
|
62 |
+
def encode_video(
|
63 |
+
self,
|
64 |
+
video: torch.Tensor,
|
65 |
+
chunk_size: int = 14,
|
66 |
+
) -> torch.Tensor:
|
67 |
+
"""
|
68 |
+
:param video: [b, c, h, w] in range [-1, 1], the b may contain multiple videos or frames
|
69 |
+
:param chunk_size: the chunk size to encode video
|
70 |
+
:return: image_embeddings in shape of [b, 1024]
|
71 |
+
"""
|
72 |
+
|
73 |
+
video_224 = _resize_with_antialiasing(video.float(), (224, 224))
|
74 |
+
video_224 = (video_224 + 1.0) / 2.0 # [-1, 1] -> [0, 1]
|
75 |
+
embeddings = []
|
76 |
+
for i in range(0, video_224.shape[0], chunk_size):
|
77 |
+
emb = self.feature_extractor(
|
78 |
+
images=video_224[i : i + chunk_size],
|
79 |
+
do_normalize=True,
|
80 |
+
do_center_crop=False,
|
81 |
+
do_resize=False,
|
82 |
+
do_rescale=False,
|
83 |
+
return_tensors="pt",
|
84 |
+
).pixel_values.to(video.device, dtype=video.dtype)
|
85 |
+
embeddings.append(self.image_encoder(emb).image_embeds) # [b, 1024]
|
86 |
+
|
87 |
+
embeddings = torch.cat(embeddings, dim=0) # [t, 1024]
|
88 |
+
return embeddings
|
89 |
+
|
90 |
+
@torch.inference_mode()
|
91 |
+
def encode_vae_video(
|
92 |
+
self,
|
93 |
+
video: torch.Tensor,
|
94 |
+
chunk_size: int = 14,
|
95 |
+
):
|
96 |
+
"""
|
97 |
+
:param video: [b, c, h, w] in range [-1, 1], the b may contain multiple videos or frames
|
98 |
+
:param chunk_size: the chunk size to encode video
|
99 |
+
:return: vae latents in shape of [b, c, h, w]
|
100 |
+
"""
|
101 |
+
video_latents = []
|
102 |
+
for i in range(0, video.shape[0], chunk_size):
|
103 |
+
video_latents.append(
|
104 |
+
self.vae.encode(video[i : i + chunk_size]).latent_dist.mode()
|
105 |
+
)
|
106 |
+
video_latents = torch.cat(video_latents, dim=0)
|
107 |
+
return video_latents
|
108 |
+
|
109 |
+
@torch.inference_mode()
|
110 |
+
def produce_priors(self, prior_model, frame, chunk_size=8):
|
111 |
+
T, _, H, W = frame.shape
|
112 |
+
frame = (frame + 1) / 2
|
113 |
+
pred_point_maps = []
|
114 |
+
pred_masks = []
|
115 |
+
for i in range(0, len(frame), chunk_size):
|
116 |
+
pred_p, pred_m = prior_model.forward_image(frame[i:i+chunk_size])
|
117 |
+
pred_point_maps.append(pred_p)
|
118 |
+
pred_masks.append(pred_m)
|
119 |
+
pred_point_maps = torch.cat(pred_point_maps, dim=0)
|
120 |
+
pred_masks = torch.cat(pred_masks, dim=0)
|
121 |
+
|
122 |
+
pred_masks = pred_masks.float() * 2 - 1
|
123 |
+
|
124 |
+
# T,H,W,3 T,H,W
|
125 |
+
pred_point_maps = normalize_point_map(pred_point_maps, pred_masks > 0)
|
126 |
+
|
127 |
+
pred_disps = 1.0 / pred_point_maps[..., 2].clamp_min(1e-3)
|
128 |
+
pred_disps = pred_disps * (pred_masks > 0)
|
129 |
+
min_disparity, max_disparity = robust_min_max(pred_disps)
|
130 |
+
pred_disps = ((pred_disps - min_disparity) / (max_disparity - min_disparity+1e-4)).clamp(0, 1)
|
131 |
+
pred_disps = pred_disps * 2 - 1
|
132 |
+
|
133 |
+
pred_point_maps[..., :2] = pred_point_maps[..., :2] / (pred_point_maps[..., 2:3] + 1e-7)
|
134 |
+
pred_point_maps[..., 2] = torch.log(pred_point_maps[..., 2] + 1e-7) * (pred_masks > 0) # [x/z, y/z, log(z)]
|
135 |
+
|
136 |
+
pred_intr_maps = point_map_xy2intrinsic_map(pred_point_maps[..., :2]).permute(0,3,1,2) # T,H,W,2
|
137 |
+
pred_point_maps = pred_point_maps.permute(0,3,1,2)
|
138 |
+
|
139 |
+
return pred_disps, pred_masks, pred_point_maps, pred_intr_maps
|
140 |
+
|
141 |
+
@torch.inference_mode()
|
142 |
+
def encode_point_map(self, point_map_vae, disparity, valid_mask, point_map, intrinsic_map, chunk_size=8):
|
143 |
+
T, _, H, W = point_map.shape
|
144 |
+
latents = []
|
145 |
+
|
146 |
+
psedo_image = disparity[:, None].repeat(1,3,1,1)
|
147 |
+
intrinsic_map = torch.norm(intrinsic_map[:, 2:4], p=2, dim=1, keepdim=False)
|
148 |
+
|
149 |
+
for i in range(0, T, chunk_size):
|
150 |
+
latent_dist = self.vae.encode(psedo_image[i : i + chunk_size].to(self.vae.dtype)).latent_dist
|
151 |
+
latent_dist = point_map_vae.encode(
|
152 |
+
torch.cat([
|
153 |
+
intrinsic_map[i:i+chunk_size, None],
|
154 |
+
point_map[i:i+chunk_size, 2:3],
|
155 |
+
disparity[i:i+chunk_size, None],
|
156 |
+
valid_mask[i:i+chunk_size, None]], dim=1),
|
157 |
+
latent_dist
|
158 |
+
)
|
159 |
+
if isinstance(latent_dist, DiagonalGaussianDistribution):
|
160 |
+
latent = latent_dist.mode()
|
161 |
+
else:
|
162 |
+
latent = latent_dist
|
163 |
+
|
164 |
+
assert isinstance(latent, torch.Tensor)
|
165 |
+
latents.append(latent)
|
166 |
+
latents = torch.cat(latents, dim=0)
|
167 |
+
latents = latents * self.vae.config.scaling_factor
|
168 |
+
return latents
|
169 |
+
|
    @torch.no_grad()
    def decode_point_map(self, point_map_vae, latents, chunk_size=8, force_projection=True, force_fixed_focal=True, use_extract_interp=False, need_resize=False, height=None, width=None):
        T = latents.shape[0]
        rec_intrinsic_maps = []
        rec_depth_maps = []
        rec_valid_masks = []
        for i in range(0, T, chunk_size):
            lat = latents[i:i+chunk_size]
            rec_imap, rec_dmap, rec_vmask = point_map_vae.decode(
                lat,
                num_frames=lat.shape[0],
            )
            rec_intrinsic_maps.append(rec_imap)
            rec_depth_maps.append(rec_dmap)
            rec_valid_masks.append(rec_vmask)

        rec_intrinsic_maps = torch.cat(rec_intrinsic_maps, dim=0)
        rec_depth_maps = torch.cat(rec_depth_maps, dim=0)
        rec_valid_masks = torch.cat(rec_valid_masks, dim=0)

        if need_resize:
            rec_depth_maps = F.interpolate(rec_depth_maps, (height, width), mode='nearest-exact') if use_extract_interp else F.interpolate(rec_depth_maps, (height, width), mode='bilinear', align_corners=False)
            rec_valid_masks = F.interpolate(rec_valid_masks, (height, width), mode='nearest-exact') if use_extract_interp else F.interpolate(rec_valid_masks, (height, width), mode='bilinear', align_corners=False)
            rec_intrinsic_maps = F.interpolate(rec_intrinsic_maps, (height, width), mode='bilinear', align_corners=False)

        H, W = rec_intrinsic_maps.shape[-2], rec_intrinsic_maps.shape[-1]
        mesh_grid = create_meshgrid(
            H, W,
            normalized_coordinates=True
        ).to(rec_intrinsic_maps.device, rec_intrinsic_maps.dtype, non_blocking=True)
        # 1,h,w,2
        rec_intrinsic_maps = torch.cat([rec_intrinsic_maps * W / np.sqrt(W**2+H**2), rec_intrinsic_maps * H / np.sqrt(W**2+H**2)], dim=1) # t,2,h,w
        mesh_grid = mesh_grid.permute(0,3,1,2)
        rec_valid_masks = rec_valid_masks.squeeze(1) > 0

        if force_projection:
            if force_fixed_focal:
                nfx = (rec_intrinsic_maps[:, 0, :, :] * rec_valid_masks.float()).mean() / (rec_valid_masks.float().mean() + 1e-4)
                nfy = (rec_intrinsic_maps[:, 1, :, :] * rec_valid_masks.float()).mean() / (rec_valid_masks.float().mean() + 1e-4)
                rec_intrinsic_maps = torch.tensor([nfx, nfy], device=rec_intrinsic_maps.device)[None, :, None, None].repeat(T, 1, 1, 1)
            else:
                nfx = (rec_intrinsic_maps[:, 0, :, :] * rec_valid_masks.float()).mean(dim=[-1, -2]) / (rec_valid_masks.float().mean(dim=[-1, -2]) + 1e-4)
                nfy = (rec_intrinsic_maps[:, 1, :, :] * rec_valid_masks.float()).mean(dim=[-1, -2]) / (rec_valid_masks.float().mean(dim=[-1, -2]) + 1e-4)
                rec_intrinsic_maps = torch.stack([nfx, nfy], dim=-1)[:, :, None, None]
            # t,2,1,1

        rec_point_maps = torch.cat([rec_intrinsic_maps * mesh_grid, rec_depth_maps], dim=1).permute(0,2,3,1)
        xy, z = rec_point_maps.split([2, 1], dim=-1)
        z = torch.clamp_max(z, 10) # for numerical stability
        z = torch.exp(z)
        rec_point_maps = torch.cat([xy * z, z], dim=-1)

        return rec_point_maps, rec_valid_masks

    @torch.no_grad()
    def __call__(
        self,
        video: Union[np.ndarray, torch.Tensor],
        point_map_vae,
        prior_model,
        height: int = 320,
        width: int = 640,
        num_inference_steps: int = 5,
        guidance_scale: float = 1.0,
        window_size: Optional[int] = 14,
        noise_aug_strength: float = 0.02,
        decode_chunk_size: Optional[int] = None,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        overlap: int = 4,
        force_projection: bool = True,
        force_fixed_focal: bool = True,
        use_extract_interp: bool = False,
        track_time: bool = False,
    ):

        # video: in shape [t, h, w, c] if np.ndarray or [t, c, h, w] if torch.Tensor, in range [0, 1]

        # 0. Default height and width to unet
        if isinstance(video, np.ndarray):
            video = torch.from_numpy(video.transpose(0, 3, 1, 2))
        else:
            assert isinstance(video, torch.Tensor)
        height = height or video.shape[-2]
        width = width or video.shape[-1]
        original_height = video.shape[-2]
        original_width = video.shape[-1]
        num_frames = video.shape[0]
        decode_chunk_size = decode_chunk_size if decode_chunk_size is not None else 8
        if num_frames <= window_size:
            window_size = num_frames
            overlap = 0
        stride = window_size - overlap

        # 1. Check inputs. Raise error if not correct
        assert height % 64 == 0 and width % 64 == 0
        if original_height != height or original_width != width:
            need_resize = True
        else:
            need_resize = False

        # 2. Define call parameters
        batch_size = 1
        device = self._execution_device
        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        self._guidance_scale = guidance_scale

        if track_time:
            start_event = torch.cuda.Event(enable_timing=True)
            prior_event = torch.cuda.Event(enable_timing=True)
            encode_event = torch.cuda.Event(enable_timing=True)
            denoise_event = torch.cuda.Event(enable_timing=True)
            decode_event = torch.cuda.Event(enable_timing=True)
            start_event.record()

        # 3. Compute per-frame geometric priors
        pred_disparity, pred_valid_mask, pred_point_map, pred_intrinsic_map = self.produce_priors(
            prior_model,
            video.to(device=device, dtype=torch.float32),
            chunk_size=decode_chunk_size
        ) # T,H,W T,H,W T,3,H,W T,2,H,W

        if need_resize:
            pred_disparity = F.interpolate(pred_disparity.unsqueeze(1), (height, width), mode='bilinear', align_corners=False).squeeze(1)
            pred_valid_mask = F.interpolate(pred_valid_mask.unsqueeze(1), (height, width), mode='bilinear', align_corners=False).squeeze(1)
            pred_point_map = F.interpolate(pred_point_map, (height, width), mode='bilinear', align_corners=False)
            pred_intrinsic_map = F.interpolate(pred_intrinsic_map, (height, width), mode='bilinear', align_corners=False)

        if track_time:
            prior_event.record()
            torch.cuda.synchronize()
            elapsed_time_ms = start_event.elapsed_time(prior_event)
            print(f"Elapsed time for computing per-frame prior: {elapsed_time_ms} ms")
        else:
            gc.collect()
            torch.cuda.empty_cache()

        # 4. Encode input video
        if need_resize:
            video = F.interpolate(video, (height, width), mode="bicubic", align_corners=False, antialias=True).clamp(0, 1)
        video = video.to(device=device, dtype=self.dtype)
        video = video * 2.0 - 1.0 # [0,1] -> [-1,1], in [t, c, h, w]

        video_embeddings = self.encode_video(video, chunk_size=decode_chunk_size).unsqueeze(0)
        prior_latents = self.encode_point_map(
            point_map_vae,
            pred_disparity,
            pred_valid_mask,
            pred_point_map,
            pred_intrinsic_map,
            chunk_size=decode_chunk_size
        ).unsqueeze(0).to(video_embeddings.dtype) # 1,T,C,H,W

        # Encode input frames with the VAE
        needs_upcasting = (
            self.vae.dtype == torch.float16 and self.vae.config.force_upcast
        )
        if needs_upcasting:
            self.vae.to(dtype=torch.float32)

        video_latents = self.encode_vae_video(
            video.to(self.vae.dtype),
            chunk_size=decode_chunk_size,
        ).unsqueeze(0).to(video_embeddings.dtype) # [1, t, c, h, w]

        torch.cuda.empty_cache()

        if track_time:
            encode_event.record()
            torch.cuda.synchronize()
            elapsed_time_ms = prior_event.elapsed_time(encode_event)
            print(f"Elapsed time for encode prior and frames: {elapsed_time_ms} ms")
        else:
            gc.collect()
            torch.cuda.empty_cache()

        # cast back to fp16 if needed
        if needs_upcasting:
            self.vae.to(dtype=torch.float16)

        # 5. Get Added Time IDs
        added_time_ids = self._get_add_time_ids(
            7,
            127,
            noise_aug_strength,
            video_embeddings.dtype,
            batch_size,
            1,
            False,
        ) # [1 or 2, 3]
        added_time_ids = added_time_ids.to(device)

        # 6. Prepare timesteps
        timesteps, num_inference_steps = retrieve_timesteps(
            self.scheduler, num_inference_steps, device, None, None
        )
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        self._num_timesteps = len(timesteps)

        # 7. Prepare latent variables
        # num_channels_latents = self.unet.config.in_channels - prior_latents.shape[1]
        num_channels_latents = 8
        latents_init = self.prepare_latents(
            batch_size,
            window_size,
            num_channels_latents,
            height,
            width,
            video_embeddings.dtype,
            device,
            generator,
            latents,
        ) # [1, t, c, h, w]
        latents_all = None

        idx_start = 0
        if overlap > 0:
            weights = torch.linspace(0, 1, overlap, device=device)
            weights = weights.view(1, overlap, 1, 1, 1)
        else:
            weights = None

        # 8. Window-by-window denoising with overlap blending
        while idx_start < num_frames - overlap:
            idx_end = min(idx_start + window_size, num_frames)
            self.scheduler.set_timesteps(num_inference_steps, device=device)
            # latents_init = latents_init.flip(1)
            latents = latents_init[:, : idx_end - idx_start].clone()
            latents_init = torch.cat(
                [latents_init[:, -overlap:], latents_init[:, :stride]], dim=1
            )

            video_latents_current = video_latents[:, idx_start:idx_end]
            prior_latents_current = prior_latents[:, idx_start:idx_end]
            video_embeddings_current = video_embeddings[:, idx_start:idx_end]

            with self.progress_bar(total=num_inference_steps) as progress_bar:
                for i, t in enumerate(timesteps):
                    if latents_all is not None and i == 0:
                        latents[:, :overlap] = (
                            latents_all[:, -overlap:]
                            + latents[:, :overlap]
                            / self.scheduler.init_noise_sigma
                            * self.scheduler.sigmas[i]
                        )

                    latent_model_input = latents

                    latent_model_input = self.scheduler.scale_model_input(
                        latent_model_input, t
                    ) # [1 or 2, t, c, h, w]
                    latent_model_input = torch.cat(
                        [latent_model_input, video_latents_current, prior_latents_current], dim=2
                    )
                    noise_pred = self.unet(
                        latent_model_input,
                        t,
                        encoder_hidden_states=video_embeddings_current,
                        added_time_ids=added_time_ids,
                        return_dict=False,
                    )[0]

                    # perform guidance
                    if self.do_classifier_free_guidance:
                        latent_model_input = latents
                        latent_model_input = self.scheduler.scale_model_input(
                            latent_model_input, t
                        )
                        latent_model_input = torch.cat(
                            [latent_model_input, torch.zeros_like(latent_model_input), torch.zeros_like(latent_model_input)],
                            dim=2,
                        )
                        noise_pred_uncond = self.unet(
                            latent_model_input,
                            t,
                            encoder_hidden_states=torch.zeros_like(
                                video_embeddings_current
                            ),
                            added_time_ids=added_time_ids,
                            return_dict=False,
                        )[0]
                        noise_pred = noise_pred_uncond + self.guidance_scale * (
                            noise_pred - noise_pred_uncond
                        )
                    latents = self.scheduler.step(noise_pred, t, latents).prev_sample

                    if callback_on_step_end is not None:
                        callback_kwargs = {}
                        for k in callback_on_step_end_tensor_inputs:
                            callback_kwargs[k] = locals()[k]
                        callback_outputs = callback_on_step_end(
                            self, i, t, callback_kwargs
                        )

                        latents = callback_outputs.pop("latents", latents)

                    if i == len(timesteps) - 1 or (
                        (i + 1) > num_warmup_steps
                        and (i + 1) % self.scheduler.order == 0
                    ):
                        progress_bar.update()

            if latents_all is None:
                latents_all = latents.clone()
            else:
                if overlap > 0:
                    latents_all[:, -overlap:] = latents[
                        :, :overlap
                    ] * weights + latents_all[:, -overlap:] * (1 - weights)
                latents_all = torch.cat([latents_all, latents[:, overlap:]], dim=1)

            idx_start += stride

        latents_all = 1 / self.vae.config.scaling_factor * latents_all.squeeze(0).to(torch.float32)

        if track_time:
            denoise_event.record()
            torch.cuda.synchronize()
            elapsed_time_ms = encode_event.elapsed_time(denoise_event)
            print(f"Elapsed time for denoise latent: {elapsed_time_ms} ms")
        else:
            gc.collect()
            torch.cuda.empty_cache()

        point_map, valid_mask = self.decode_point_map(
            point_map_vae,
            latents_all,
            chunk_size=decode_chunk_size,
            force_projection=force_projection,
            force_fixed_focal=force_fixed_focal,
            use_extract_interp=use_extract_interp,
            need_resize=need_resize,
            height=original_height,
            width=original_width)

        if track_time:
            decode_event.record()
            torch.cuda.synchronize()
            elapsed_time_ms = denoise_event.elapsed_time(decode_event)
            print(f"Elapsed time for decode latent: {elapsed_time_ms} ms")
        else:
            gc.collect()
            torch.cuda.empty_cache()

        self.maybe_free_model_hooks()
        # t,h,w,3 t,h,w
        return point_map, valid_mask
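A minimal usage sketch of the pipeline above. The names `pipe`, `pmap_vae` and `moge` are assumptions (they are constructed elsewhere in app.py, which is not part of this excerpt), and the clip length and resolution are only illustrative; the call signature itself follows `__call__` as defined above.

# Hedged usage sketch -- `pipe`, `pmap_vae`, `moge` are assumed to be built elsewhere (e.g. in app.py).
import torch
from decord import VideoReader, cpu

vr = VideoReader("example.mp4", ctx=cpu(0))
frames = torch.from_numpy(vr.get_batch(range(30)).asnumpy())   # t, h, w, c in [0, 255]
video = frames.float().permute(0, 3, 1, 2) / 255.0             # t, c, h, w in [0, 1]

with torch.inference_mode():
    point_map, valid_mask = pipe(
        video, point_map_vae=pmap_vae, prior_model=moge,
        height=320, width=640,
        num_inference_steps=5, window_size=14, overlap=4,
    )
# point_map: [t, h, w, 3] camera-space points, valid_mask: [t, h, w] booleans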
geometrycrafter/pmap_vae.py
ADDED
@@ -0,0 +1,330 @@
from typing import Dict, Tuple, Union

import torch
import torch.nn as nn
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.utils.accelerate_utils import apply_forward_hook
from diffusers.models.attention_processor import CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnProcessor
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution, Encoder
from diffusers.utils import is_torch_version
from diffusers.models.unets.unet_3d_blocks import UpBlockTemporalDecoder, MidBlockTemporalDecoder
from diffusers.models.resnet import SpatioTemporalResBlock

def zero_module(module):
    """
    Zero out the parameters of a module and return it.
    """
    for p in module.parameters():
        p.detach().zero_()
    return module

class PMapTemporalDecoder(nn.Module):
    def __init__(
        self,
        in_channels: int = 4,
        out_channels: Tuple[int] = (1, 1, 1),
        block_out_channels: Tuple[int] = (128, 256, 512, 512),
        layers_per_block: int = 2,
    ):
        super().__init__()

        self.conv_in = nn.Conv2d(
            in_channels,
            block_out_channels[-1],
            kernel_size=3,
            stride=1,
            padding=1
        )
        self.mid_block = MidBlockTemporalDecoder(
            num_layers=layers_per_block,
            in_channels=block_out_channels[-1],
            out_channels=block_out_channels[-1],
            attention_head_dim=block_out_channels[-1],
        )

        # up
        self.up_blocks = nn.ModuleList([])
        reversed_block_out_channels = list(reversed(block_out_channels))
        output_channel = reversed_block_out_channels[0]
        for i in range(len(block_out_channels)):
            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1
            up_block = UpBlockTemporalDecoder(
                num_layers=layers_per_block + 1,
                in_channels=prev_output_channel,
                out_channels=output_channel,
                add_upsample=not is_final_block,
            )
            self.up_blocks.append(up_block)
            prev_output_channel = output_channel

        self.out_blocks = nn.ModuleList([])
        self.time_conv_outs = nn.ModuleList([])
        for out_channel in out_channels:
            self.out_blocks.append(
                nn.ModuleList([
                    nn.GroupNorm(num_channels=block_out_channels[0], num_groups=32, eps=1e-6),
                    nn.ReLU(inplace=True),
                    nn.Conv2d(
                        block_out_channels[0],
                        block_out_channels[0] // 2,
                        kernel_size=3,
                        padding=1
                    ),
                    SpatioTemporalResBlock(
                        in_channels=block_out_channels[0] // 2,
                        out_channels=block_out_channels[0] // 2,
                        temb_channels=None,
                        eps=1e-6,
                        temporal_eps=1e-5,
                        merge_factor=0.0,
                        merge_strategy="learned",
                        switch_spatial_to_temporal_mix=True
                    ),
                    nn.ReLU(inplace=True),
                    nn.Conv2d(
                        block_out_channels[0] // 2,
                        out_channel,
                        kernel_size=1,
                    )
                ])
            )

            conv_out_kernel_size = (3, 1, 1)
            padding = [int(k // 2) for k in conv_out_kernel_size]
            self.time_conv_outs.append(nn.Conv3d(
                in_channels=out_channel,
                out_channels=out_channel,
                kernel_size=conv_out_kernel_size,
                padding=padding,
            ))

        self.gradient_checkpointing = False

    def forward(
        self,
        sample: torch.Tensor,
        image_only_indicator: torch.Tensor,
        num_frames: int = 1,
    ):
        sample = self.conv_in(sample)

        upscale_dtype = next(iter(self.up_blocks.parameters())).dtype

        if self.training and self.gradient_checkpointing:
            def create_custom_forward(module):
                def custom_forward(*inputs):
                    return module(*inputs)

                return custom_forward

            if is_torch_version(">=", "1.11.0"):
                # middle
                sample = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(self.mid_block),
                    sample,
                    image_only_indicator,
                    use_reentrant=False,
                )
                sample = sample.to(upscale_dtype)

                # up
                for up_block in self.up_blocks:
                    sample = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(up_block),
                        sample,
                        image_only_indicator,
                        use_reentrant=False,
                    )
            else:
                # middle
                sample = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(self.mid_block),
                    sample,
                    image_only_indicator,
                )
                sample = sample.to(upscale_dtype)

                # up
                for up_block in self.up_blocks:
                    sample = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(up_block),
                        sample,
                        image_only_indicator,
                    )
        else:
            # middle
            sample = self.mid_block(sample, image_only_indicator=image_only_indicator)
            sample = sample.to(upscale_dtype)

            # up
            for up_block in self.up_blocks:
                sample = up_block(sample, image_only_indicator=image_only_indicator)

        # post-process
        output = []

        for out_block, time_conv_out in zip(self.out_blocks, self.time_conv_outs):
            x = sample
            for layer in out_block:
                if isinstance(layer, SpatioTemporalResBlock):
                    x = layer(x, None, image_only_indicator)
                else:
                    x = layer(x)

            batch_frames, channels, height, width = x.shape
            batch_size = batch_frames // num_frames
            x = x[None, :].reshape(batch_size, num_frames, channels, height, width).permute(0, 2, 1, 3, 4)
            x = time_conv_out(x)
            x = x.permute(0, 2, 1, 3, 4).reshape(batch_frames, channels, height, width)
            output.append(x)

        return output

class PMapAutoencoderKLTemporalDecoder(ModelMixin, ConfigMixin):

    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        in_channels: int = 4,
        latent_channels: int = 4,
        enc_down_block_types: Tuple[str] = (
            "DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D"
        ),
        enc_block_out_channels: Tuple[int] = (128, 256, 512, 512),
        enc_layers_per_block: int = 2,
        dec_block_out_channels: Tuple[int] = (128, 256, 512, 512),
        dec_layers_per_block: int = 2,
        out_channels: Tuple[int] = (1, 1, 1),
        mid_block_add_attention: bool = True,
        offset_scale_factor: float = 0.1,
        **kwargs
    ):
        super().__init__()

        self.encoder = Encoder(
            in_channels=in_channels,
            out_channels=latent_channels,
            down_block_types=enc_down_block_types,
            block_out_channels=enc_block_out_channels,
            layers_per_block=enc_layers_per_block,
            double_z=False,
            mid_block_add_attention=mid_block_add_attention
        )
        zero_module(self.encoder.conv_out)

        self.offset_scale_factor = offset_scale_factor

        self.decoder = PMapTemporalDecoder(
            in_channels=latent_channels,
            block_out_channels=dec_block_out_channels,
            layers_per_block=dec_layers_per_block,
            out_channels=out_channels
        )

    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, (Encoder, PMapTemporalDecoder)):
            module.gradient_checkpointing = value

    @property
    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model with
            indexed by its weight name.
        """
        # set recursively
        processors = {}

        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
            if hasattr(module, "get_processor"):
                processors[f"{name}.processor"] = module.get_processor()

            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)

        return processors

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
        r"""
        Sets the attention processor to use to compute attention.

        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the processor
                for **all** `Attention` layers.

                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor. This is strongly recommended when setting trainable attention processors.

        """
        count = len(self.attn_processors.keys())

        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    module.set_processor(processor.pop(f"{name}.processor"))

            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)

    def set_default_attn_processor(self):
        """
        Disables custom attention processors and sets the default attention implementation.
        """
        if all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
            processor = AttnProcessor()
        else:
            raise ValueError(
                f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
            )

        self.set_attn_processor(processor)

    @apply_forward_hook
    def encode(
        self,
        x: torch.Tensor,
        latent_dist: DiagonalGaussianDistribution
    ) -> DiagonalGaussianDistribution:
        h = self.encoder(x)
        offset = h * self.offset_scale_factor
        param = latent_dist.parameters.to(h.dtype)
        mean, logvar = torch.chunk(param, 2, dim=1)
        posterior = DiagonalGaussianDistribution(torch.cat([mean + offset, logvar], dim=1))
        return posterior

    @apply_forward_hook
    def decode(
        self,
        z: torch.Tensor,
        num_frames: int
    ) -> torch.Tensor:
        batch_size = z.shape[0] // num_frames
        image_only_indicator = torch.zeros(batch_size, num_frames, dtype=z.dtype, device=z.device)
        decoded = self.decoder(z, num_frames=num_frames, image_only_indicator=image_only_indicator)
        return decoded
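The point-map VAE does not encode from scratch: `encode` takes the posterior of the frozen video VAE and shifts only its mean by a scaled offset predicted from the geometric inputs, with the encoder's `conv_out` zero-initialized so the shift starts at zero. The sketch below illustrates that round trip with a randomly initialized model and dummy shapes (the spatial/channel sizes are assumptions chosen to be self-consistent, not values fixed by this file).

# Hedged shape sketch: offset-on-prior encoding followed by decoding into (intrinsic, depth, mask) maps.
import torch
from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution

prior_params = torch.randn(2, 8, 40, 80)                     # mean+logvar of the frozen video VAE posterior (assumed 4+4 channels)
prior_dist = DiagonalGaussianDistribution(prior_params)

pmap_vae = PMapAutoencoderKLTemporalDecoder()                 # randomly initialized, shapes only
x = torch.randn(2, 4, 320, 640)                               # per-frame [intrinsic, log-depth, disparity, mask] stack
posterior = pmap_vae.encode(x, prior_dist)                    # mean shifted by the zero-initialized encoder output
imap, dmap, vmask = pmap_vae.decode(posterior.mode(), num_frames=2)   # three [2, 1, 320, 640] heads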
geometrycrafter/unet.py
ADDED
@@ -0,0 +1,281 @@
1 |
+
from typing import Union, Tuple
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from diffusers import UNetSpatioTemporalConditionModel
|
5 |
+
from diffusers.models.unets.unet_spatio_temporal_condition import UNetSpatioTemporalConditionOutput
|
6 |
+
from diffusers.utils import is_torch_version
|
7 |
+
|
8 |
+
|
9 |
+
class UNetSpatioTemporalConditionModelVid2vid(
|
10 |
+
UNetSpatioTemporalConditionModel
|
11 |
+
):
|
12 |
+
def enable_gradient_checkpointing(self):
|
13 |
+
self.gradient_checkpointing = True
|
14 |
+
|
15 |
+
def disable_gradient_checkpointing(self):
|
16 |
+
self.gradient_checkpointing = False
|
17 |
+
|
18 |
+
def forward(
|
19 |
+
self,
|
20 |
+
sample: torch.Tensor,
|
21 |
+
timestep: Union[torch.Tensor, float, int],
|
22 |
+
encoder_hidden_states: torch.Tensor,
|
23 |
+
added_time_ids: torch.Tensor,
|
24 |
+
return_dict: bool = True,
|
25 |
+
) -> Union[UNetSpatioTemporalConditionOutput, Tuple]:
|
26 |
+
|
27 |
+
# 1. time
|
28 |
+
timesteps = timestep
|
29 |
+
if not torch.is_tensor(timesteps):
|
30 |
+
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
|
31 |
+
# This would be a good case for the `match` statement (Python 3.10+)
|
32 |
+
is_mps = sample.device.type == "mps"
|
33 |
+
if isinstance(timestep, float):
|
34 |
+
dtype = torch.float32 if is_mps else torch.float64
|
35 |
+
else:
|
36 |
+
dtype = torch.int32 if is_mps else torch.int64
|
37 |
+
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
|
38 |
+
elif len(timesteps.shape) == 0:
|
39 |
+
timesteps = timesteps[None].to(sample.device)
|
40 |
+
|
41 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
42 |
+
batch_size, num_frames = sample.shape[:2]
|
43 |
+
timesteps = timesteps.expand(batch_size)
|
44 |
+
|
45 |
+
t_emb = self.time_proj(timesteps)
|
46 |
+
|
47 |
+
# `Timesteps` does not contain any weights and will always return f32 tensors
|
48 |
+
# but time_embedding might actually be running in fp16. so we need to cast here.
|
49 |
+
# there might be better ways to encapsulate this.
|
50 |
+
t_emb = t_emb.to(dtype=self.conv_in.weight.dtype)
|
51 |
+
|
52 |
+
emb = self.time_embedding(t_emb) # [batch_size * num_frames, channels]
|
53 |
+
|
54 |
+
time_embeds = self.add_time_proj(added_time_ids.flatten())
|
55 |
+
time_embeds = time_embeds.reshape((batch_size, -1))
|
56 |
+
time_embeds = time_embeds.to(emb.dtype)
|
57 |
+
aug_emb = self.add_embedding(time_embeds)
|
58 |
+
emb = emb + aug_emb
|
59 |
+
|
60 |
+
# Flatten the batch and frames dimensions
|
61 |
+
# sample: [batch, frames, channels, height, width] -> [batch * frames, channels, height, width]
|
62 |
+
sample = sample.flatten(0, 1)
|
63 |
+
# Repeat the embeddings num_video_frames times
|
64 |
+
# emb: [batch, channels] -> [batch * frames, channels]
|
65 |
+
emb = emb.repeat_interleave(num_frames, dim=0)
|
66 |
+
# encoder_hidden_states: [batch, frames, channels] -> [batch * frames, 1, channels]
|
67 |
+
encoder_hidden_states = encoder_hidden_states.flatten(0, 1).unsqueeze(1)
|
68 |
+
|
69 |
+
# 2. pre-process
|
70 |
+
sample = sample.to(dtype=self.conv_in.weight.dtype)
|
71 |
+
assert sample.dtype == self.conv_in.weight.dtype, (
|
72 |
+
f"sample.dtype: {sample.dtype}, "
|
73 |
+
f"self.conv_in.weight.dtype: {self.conv_in.weight.dtype}"
|
74 |
+
)
|
75 |
+
sample = self.conv_in(sample)
|
76 |
+
|
77 |
+
image_only_indicator = torch.zeros(
|
78 |
+
batch_size, num_frames, dtype=sample.dtype, device=sample.device
|
79 |
+
)
|
80 |
+
|
81 |
+
down_block_res_samples = (sample,)
|
82 |
+
|
83 |
+
if self.training and self.gradient_checkpointing:
|
84 |
+
def create_custom_forward(module):
|
85 |
+
def custom_forward(*inputs):
|
86 |
+
return module(*inputs)
|
87 |
+
|
88 |
+
return custom_forward
|
89 |
+
|
90 |
+
if is_torch_version(">=", "1.11.0"):
|
91 |
+
|
92 |
+
for downsample_block in self.down_blocks:
|
93 |
+
if (
|
94 |
+
hasattr(downsample_block, "has_cross_attention")
|
95 |
+
and downsample_block.has_cross_attention
|
96 |
+
):
|
97 |
+
sample, res_samples = torch.utils.checkpoint.checkpoint(
|
98 |
+
create_custom_forward(downsample_block),
|
99 |
+
sample,
|
100 |
+
emb,
|
101 |
+
encoder_hidden_states,
|
102 |
+
image_only_indicator,
|
103 |
+
use_reentrant=False,
|
104 |
+
)
|
105 |
+
else:
|
106 |
+
sample, res_samples = torch.utils.checkpoint.checkpoint(
|
107 |
+
create_custom_forward(downsample_block),
|
108 |
+
sample,
|
109 |
+
emb,
|
110 |
+
image_only_indicator,
|
111 |
+
use_reentrant=False,
|
112 |
+
)
|
113 |
+
down_block_res_samples += res_samples
|
114 |
+
|
115 |
+
# 4. mid
|
116 |
+
sample = torch.utils.checkpoint.checkpoint(
|
117 |
+
create_custom_forward(self.mid_block),
|
118 |
+
sample,
|
119 |
+
emb,
|
120 |
+
encoder_hidden_states,
|
121 |
+
image_only_indicator,
|
122 |
+
use_reentrant=False,
|
123 |
+
)
|
124 |
+
|
125 |
+
# 5. up
|
126 |
+
for i, upsample_block in enumerate(self.up_blocks):
|
127 |
+
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
|
128 |
+
down_block_res_samples = down_block_res_samples[
|
129 |
+
: -len(upsample_block.resnets)
|
130 |
+
]
|
131 |
+
|
132 |
+
if (
|
133 |
+
hasattr(upsample_block, "has_cross_attention")
|
134 |
+
and upsample_block.has_cross_attention
|
135 |
+
):
|
136 |
+
sample = torch.utils.checkpoint.checkpoint(
|
137 |
+
create_custom_forward(upsample_block),
|
138 |
+
sample,
|
139 |
+
res_samples,
|
140 |
+
emb,
|
141 |
+
encoder_hidden_states,
|
142 |
+
image_only_indicator,
|
143 |
+
use_reentrant=False,
|
144 |
+
)
|
145 |
+
else:
|
146 |
+
sample = torch.utils.checkpoint.checkpoint(
|
147 |
+
create_custom_forward(upsample_block),
|
148 |
+
sample,
|
149 |
+
res_samples,
|
150 |
+
emb,
|
151 |
+
image_only_indicator,
|
152 |
+
use_reentrant=False,
|
153 |
+
)
|
154 |
+
else:
|
155 |
+
|
156 |
+
for downsample_block in self.down_blocks:
|
157 |
+
if (
|
158 |
+
hasattr(downsample_block, "has_cross_attention")
|
159 |
+
and downsample_block.has_cross_attention
|
160 |
+
):
|
161 |
+
sample, res_samples = torch.utils.checkpoint.checkpoint(
|
162 |
+
create_custom_forward(downsample_block),
|
163 |
+
sample,
|
164 |
+
emb,
|
165 |
+
encoder_hidden_states,
|
166 |
+
image_only_indicator,
|
167 |
+
)
|
168 |
+
else:
|
169 |
+
sample, res_samples = torch.utils.checkpoint.checkpoint(
|
170 |
+
create_custom_forward(downsample_block),
|
171 |
+
sample,
|
172 |
+
emb,
|
173 |
+
image_only_indicator,
|
174 |
+
)
|
175 |
+
down_block_res_samples += res_samples
|
176 |
+
|
177 |
+
# 4. mid
|
178 |
+
sample = torch.utils.checkpoint.checkpoint(
|
179 |
+
create_custom_forward(self.mid_block),
|
180 |
+
sample,
|
181 |
+
emb,
|
182 |
+
encoder_hidden_states,
|
183 |
+
image_only_indicator,
|
184 |
+
)
|
185 |
+
|
186 |
+
# 5. up
|
187 |
+
for i, upsample_block in enumerate(self.up_blocks):
|
188 |
+
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
|
189 |
+
down_block_res_samples = down_block_res_samples[
|
190 |
+
: -len(upsample_block.resnets)
|
191 |
+
]
|
192 |
+
|
193 |
+
if (
|
194 |
+
hasattr(upsample_block, "has_cross_attention")
|
195 |
+
and upsample_block.has_cross_attention
|
196 |
+
):
|
197 |
+
sample = torch.utils.checkpoint.checkpoint(
|
198 |
+
create_custom_forward(upsample_block),
|
199 |
+
sample,
|
200 |
+
res_samples,
|
201 |
+
emb,
|
202 |
+
encoder_hidden_states,
|
203 |
+
image_only_indicator,
|
204 |
+
)
|
205 |
+
else:
|
206 |
+
sample = torch.utils.checkpoint.checkpoint(
|
207 |
+
create_custom_forward(upsample_block),
|
208 |
+
sample,
|
209 |
+
res_samples,
|
210 |
+
emb,
|
211 |
+
image_only_indicator,
|
212 |
+
)
|
213 |
+
|
214 |
+
else:
|
215 |
+
for downsample_block in self.down_blocks:
|
216 |
+
if (
|
217 |
+
hasattr(downsample_block, "has_cross_attention")
|
218 |
+
and downsample_block.has_cross_attention
|
219 |
+
):
|
220 |
+
sample, res_samples = downsample_block(
|
221 |
+
hidden_states=sample,
|
222 |
+
temb=emb,
|
223 |
+
encoder_hidden_states=encoder_hidden_states,
|
224 |
+
image_only_indicator=image_only_indicator,
|
225 |
+
)
|
226 |
+
|
227 |
+
else:
|
228 |
+
sample, res_samples = downsample_block(
|
229 |
+
hidden_states=sample,
|
230 |
+
temb=emb,
|
231 |
+
image_only_indicator=image_only_indicator,
|
232 |
+
)
|
233 |
+
|
234 |
+
down_block_res_samples += res_samples
|
235 |
+
|
236 |
+
# 4. mid
|
237 |
+
sample = self.mid_block(
|
238 |
+
hidden_states=sample,
|
239 |
+
temb=emb,
|
240 |
+
encoder_hidden_states=encoder_hidden_states,
|
241 |
+
image_only_indicator=image_only_indicator,
|
242 |
+
)
|
243 |
+
|
244 |
+
# 5. up
|
245 |
+
for i, upsample_block in enumerate(self.up_blocks):
|
246 |
+
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
|
247 |
+
down_block_res_samples = down_block_res_samples[
|
248 |
+
: -len(upsample_block.resnets)
|
249 |
+
]
|
250 |
+
|
251 |
+
if (
|
252 |
+
hasattr(upsample_block, "has_cross_attention")
|
253 |
+
and upsample_block.has_cross_attention
|
254 |
+
):
|
255 |
+
sample = upsample_block(
|
256 |
+
hidden_states=sample,
|
257 |
+
res_hidden_states_tuple=res_samples,
|
258 |
+
temb=emb,
|
259 |
+
encoder_hidden_states=encoder_hidden_states,
|
260 |
+
image_only_indicator=image_only_indicator,
|
261 |
+
)
|
262 |
+
else:
|
263 |
+
sample = upsample_block(
|
264 |
+
hidden_states=sample,
|
265 |
+
res_hidden_states_tuple=res_samples,
|
266 |
+
temb=emb,
|
267 |
+
image_only_indicator=image_only_indicator,
|
268 |
+
)
|
269 |
+
|
270 |
+
# 6. post-process
|
271 |
+
sample = self.conv_norm_out(sample)
|
272 |
+
sample = self.conv_act(sample)
|
273 |
+
sample = self.conv_out(sample)
|
274 |
+
|
275 |
+
# 7. Reshape back to original shape
|
276 |
+
sample = sample.reshape(batch_size, num_frames, *sample.shape[1:])
|
277 |
+
|
278 |
+
if not return_dict:
|
279 |
+
return (sample,)
|
280 |
+
|
281 |
+
return UNetSpatioTemporalConditionOutput(sample=sample)
|
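The vid2vid UNet keeps the Stable Video Diffusion forward interface; the pipeline feeds it the noisy latents concatenated channel-wise with the video latents and the point-map prior latents. A shape-only sketch of one forward call follows; `unet` is assumed to be loaded elsewhere, and the channel counts and embedding width are illustrative assumptions rather than values fixed by this file.

# Hedged shape sketch of a single UNet forward pass (dummy data, assumed channel layout).
import torch

b, t, h, w = 1, 14, 40, 80
sample = torch.randn(b, t, 8 + 4 + 4, h, w)        # noisy latents ++ video latents ++ prior latents (assumed split)
encoder_hidden_states = torch.randn(b, t, 1024)    # per-frame video embeddings (assumed width)
added_time_ids = torch.tensor([[7, 127, 0.02]])

noise_pred = unet(
    sample, timestep=500,
    encoder_hidden_states=encoder_hidden_states,
    added_time_ids=added_time_ids,
    return_dict=False,
)[0]                                               # [b, t, c_out, h, w]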
requirements.txt
ADDED
@@ -0,0 +1,16 @@
torch==2.3.1
diffusers==0.31.0
numpy==2.0.1
matplotlib==3.9.2
transformers==4.48.0
accelerate==1.1.1
xformers==0.0.27
mediapy==1.2.2
fire==0.7.0
decord==0.6.0
OpenEXR==3.3.2
kornia==0.7.4
opencv-python==4.10.0.84
h5py==3.12.1
moderngl==5.12.0
piqp==0.4.2
third_party/__init__.py
ADDED
@@ -0,0 +1,22 @@
import torch
import torch.nn as nn
import sys

sys.path.append('third_party/moge')
from .moge.moge.model.moge_model import MoGeModel

class MoGe(nn.Module):

    def __init__(self, cache_dir):
        super().__init__()
        self.model = MoGeModel.from_pretrained(
            'Ruicheng/moge-vitl', cache_dir=cache_dir).eval()

    @torch.no_grad()
    def forward_image(self, image: torch.Tensor, **kwargs):
        # image: b, 3, h, w 0,1
        output = self.model.infer(image, resolution_level=9, apply_mask=False, **kwargs)
        points = output['points'] # b,h,w,3
        masks = output['mask'] # b,h,w
        return points, masks
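The MoGe wrapper supplies the per-frame point-map priors consumed by `produce_priors`. A minimal, hedged sketch of calling it on a batch of frames (the cache directory and frame sizes are placeholders; the checkpoint downloads from the Hub on first use):

# Hedged usage sketch of the MoGe prior on dummy frames.
import torch

moge = MoGe(cache_dir="./checkpoints").cuda()
frames = torch.rand(2, 3, 320, 640, device="cuda")   # b, 3, h, w in [0, 1]
points, masks = moge.forward_image(frames)           # points: b, h, w, 3; masks: b, h, w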
third_party/moge
ADDED
@@ -0,0 +1 @@
Subproject commit dd158c05461f2353287a182afb2adf0fda46436f
utils/__init__.py
ADDED
File without changes
|
utils/disp_utils.py
ADDED
@@ -0,0 +1,43 @@
import torch
from matplotlib import cm

def robust_min_max(tensor, quantile=0.99):
    T, H, W = tensor.shape
    min_vals = []
    max_vals = []
    for i in range(T):
        min_vals.append(torch.quantile(tensor[i], q=1-quantile, interpolation='nearest').item())
        max_vals.append(torch.quantile(tensor[i], q=quantile, interpolation='nearest').item())
    return min(min_vals), max(max_vals)


class ColorMapper:
    def __init__(self, colormap: str = "inferno"):
        self.colormap = torch.tensor(cm.get_cmap(colormap).colors)

    def apply(self, image: torch.Tensor, v_min=None, v_max=None):
        # assert len(image.shape) == 2
        if v_min is None:
            v_min = image.min()
        if v_max is None:
            v_max = image.max()
        image = (image - v_min) / (v_max - v_min)
        image = (image * 255).long()
        colormap = self.colormap.to(image.device)
        image = colormap[image]
        return image

def color_video_disp(disp):
    visualizer = ColorMapper()
    disp_img = visualizer.apply(disp, v_min=0, v_max=1)
    return disp_img

def pmap_to_disp(point_maps, valid_masks):
    disp_map = 1.0 / (point_maps[..., 2] + 1e-4)
    min_disparity, max_disparity = robust_min_max(disp_map)
    disp_map = torch.clamp((disp_map - min_disparity) / (max_disparity - min_disparity+1e-4), 0, 1)

    disp_map = color_video_disp(disp_map)
    disp_map[~valid_masks] = 0
    return disp_map
    # imageio.mimsave(os.path.join(args.save_dir, os.path.basename(args.data[:-4])+'_disp.mp4'), disp, fps=24, quality=9, macro_block_size=1)
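`pmap_to_disp` converts the z channel of a point-map video into a robustly normalized, colorized disparity video. A hedged sketch of writing that visualization to disk with mediapy (which is listed in requirements.txt); the tensor shapes and output path are placeholders:

# Hedged sketch: colorize a point-map sequence and save it as an mp4.
import torch
import mediapy

point_maps = torch.rand(30, 320, 640, 3) + 0.1        # t, h, w, 3 dummy camera-space points
valid_masks = torch.ones(30, 320, 640, dtype=torch.bool)

disp_video = pmap_to_disp(point_maps, valid_masks)    # t, h, w, 3 colors in [0, 1]
mediapy.write_video("disp.mp4", disp_video.numpy(), fps=24)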
utils/glb_utils.py
ADDED
@@ -0,0 +1,19 @@
import trimesh
import numpy as np

def pmap_to_glb(point_map, valid_mask, frame) -> trimesh.Scene:

    pts_3d = point_map[valid_mask] * np.array([-1, -1, 1])
    pts_rgb = frame[valid_mask]

    # Initialize a 3D scene
    scene_3d = trimesh.Scene()

    # Add point cloud data to the scene
    point_cloud_data = trimesh.PointCloud(
        vertices=pts_3d, colors=pts_rgb
    )

    scene_3d.add_geometry(point_cloud_data)
    return scene_3d
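`pmap_to_glb` packs one frame's valid points into a colored trimesh point cloud, flipping the x and y axes so the cloud displays upright in the 3D viewer. A hedged sketch of exporting a single frame to a .glb file (all inputs are dummy arrays; the output path is a placeholder):

# Hedged sketch: export one frame of a predicted point map as a .glb point cloud.
import numpy as np

point_map = np.random.rand(320, 640, 3).astype(np.float32)        # h, w, 3 camera-space points
valid_mask = np.ones((320, 640), dtype=bool)                      # h, w validity mask
frame = np.random.randint(0, 255, (320, 640, 3), dtype=np.uint8)  # h, w, 3 RGB colors

scene = pmap_to_glb(point_map, valid_mask, frame)
scene.export("frame0.glb")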