Spaces: Running on Zero

Beijia11 committed · Commit 3aba902 · Parent(s): 686bb9b

init

Browse files. This view is limited to 50 files because it contains too many changes. See raw diff.
- .gitignore +200 -0
- .gitmodules +3 -0
- app.py +577 -0
- config/__init__.py +0 -0
- config/base_cfg.py +410 -0
- config/ssm_cfg.py +347 -0
- config/yacs.py +506 -0
- demo.py +206 -0
- models/cogvideox_tracking.py +1020 -0
- models/pipelines.py +1040 -0
- models/spatracker/__init__.py +5 -0
- models/spatracker/models/__init__.py +5 -0
- models/spatracker/models/build_spatracker.py +51 -0
- models/spatracker/models/core/__init__.py +5 -0
- models/spatracker/models/core/embeddings.py +250 -0
- models/spatracker/models/core/model_utils.py +477 -0
- models/spatracker/models/core/spatracker/__init__.py +5 -0
- models/spatracker/models/core/spatracker/blocks.py +999 -0
- models/spatracker/models/core/spatracker/dpt/__init__.py +0 -0
- models/spatracker/models/core/spatracker/dpt/base_model.py +16 -0
- models/spatracker/models/core/spatracker/dpt/blocks.py +394 -0
- models/spatracker/models/core/spatracker/dpt/midas_net.py +77 -0
- models/spatracker/models/core/spatracker/dpt/models.py +231 -0
- models/spatracker/models/core/spatracker/dpt/transforms.py +231 -0
- models/spatracker/models/core/spatracker/dpt/vit.py +596 -0
- models/spatracker/models/core/spatracker/feature_net.py +915 -0
- models/spatracker/models/core/spatracker/loftr/__init__.py +1 -0
- models/spatracker/models/core/spatracker/loftr/linear_attention.py +81 -0
- models/spatracker/models/core/spatracker/loftr/transformer.py +142 -0
- models/spatracker/models/core/spatracker/losses.py +90 -0
- models/spatracker/models/core/spatracker/softsplat.py +539 -0
- models/spatracker/models/core/spatracker/spatracker.py +732 -0
- models/spatracker/models/core/spatracker/unet.py +258 -0
- models/spatracker/models/core/spatracker/vit/__init__.py +0 -0
- models/spatracker/models/core/spatracker/vit/common.py +43 -0
- models/spatracker/models/core/spatracker/vit/encoder.py +397 -0
- models/spatracker/predictor.py +284 -0
- models/spatracker/utils/__init__.py +5 -0
- models/spatracker/utils/basic.py +397 -0
- models/spatracker/utils/geom.py +547 -0
- models/spatracker/utils/improc.py +1447 -0
- models/spatracker/utils/misc.py +166 -0
- models/spatracker/utils/samp.py +152 -0
- models/spatracker/utils/visualizer.py +409 -0
- models/spatracker/utils/vox.py +500 -0
- requirements.txt +32 -0
- submodules/MoGe/.gitignore +425 -0
- submodules/MoGe/CHANGELOG.md +15 -0
- submodules/MoGe/CODE_OF_CONDUCT.md +9 -0
- submodules/MoGe/LICENSE +224 -0
.gitignore
ADDED
@@ -0,0 +1,200 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# JetBrains
.idea

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

# manually added
wandb/
dump*

!requirements.txt
env/
datasets/
validation/
ckpts/
.vscode/
output.mp4
outputs/
camctrl_output
*.code-workspace

**/*/.DS_Store
**/*/__pycache__/*
.DS_Store
__pycache__
vis_results
checkpoints
**/*/.pth
**/*/.pt
**/*/.mp4
**/*/.npy

/assets/**
./vis_results/** */
models/monoD/zoeDepth/ckpts/*
slurm-*.out
.vscode

data/
tmp/
.gitmodules
ADDED
@@ -0,0 +1,3 @@
[submodule "submodules/MoGe"]
    path = submodules/MoGe
    url = https://github.com/microsoft/MoGe.git
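Because MoGe is pulled in as a git submodule, a fresh checkout of this Space also needs the submodule fetched, e.g. `git clone --recurse-submodules <repo-url>` or, after a plain clone, `git submodule update --init submodules/MoGe` (standard git commands, not part of this commit).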
app.py
ADDED
@@ -0,0 +1,577 @@
import os
import sys
import gradio as gr
import torch
import subprocess
import argparse
import glob

project_root = os.path.dirname(os.path.abspath(__file__))
os.environ["GRADIO_TEMP_DIR"] = os.path.join(project_root, "tmp", "gradio")
sys.path.append(project_root)

HERE_PATH = os.path.normpath(os.path.dirname(__file__))
sys.path.insert(0, HERE_PATH)
from huggingface_hub import hf_hub_download
hf_hub_download(repo_id="EXCAI/Diffusion-As-Shader", filename='spatracker/spaT_final.pth', local_dir=f'{HERE_PATH}/checkpoints/')


# Parse command line arguments
parser = argparse.ArgumentParser(description="Diffusion as Shader Web UI")
parser.add_argument("--port", type=int, default=7860, help="Port to run the web UI on")
parser.add_argument("--share", action="store_true", help="Share the web UI")
parser.add_argument("--gpu", type=int, default=0, help="GPU device ID")
parser.add_argument("--model_path", type=str, default="EXCAI/Diffusion-As-Shader", help="Path to model checkpoint")
parser.add_argument("--output_dir", type=str, default="tmp", help="Output directory")
args = parser.parse_args()

# Use the original GPU ID throughout the entire code for consistency
GPU_ID = args.gpu

# Set environment variables - this used to remap the GPU, but we're removing this for consistency
# Instead, we'll pass the original GPU ID to all commands
# os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)  # Commented out to ensure consistent GPU ID usage

# Check if CUDA is available
CUDA_AVAILABLE = torch.cuda.is_available()
if CUDA_AVAILABLE:
    GPU_COUNT = torch.cuda.device_count()
    GPU_NAMES = [f"{i}: {torch.cuda.get_device_name(i)}" for i in range(GPU_COUNT)]
else:
    GPU_COUNT = 0
    GPU_NAMES = ["CPU (CUDA not available)"]
    GPU_ID = "CPU"

DEFAULT_MODEL_PATH = args.model_path
OUTPUT_DIR = args.output_dir

# Create necessary directories
os.makedirs("outputs", exist_ok=True)
# Create project tmp directory instead of using system temp
os.makedirs(os.path.join(project_root, "tmp"), exist_ok=True)
os.makedirs(os.path.join(project_root, "tmp", "gradio"), exist_ok=True)

def save_uploaded_file(file):
    if file is None:
        return None

    # Use project tmp directory instead of system temp
    temp_dir = os.path.join(project_root, "tmp")

    if hasattr(file, 'name'):
        filename = file.name
    else:
        # Generate a unique filename if name attribute is missing
        import uuid
        ext = ".tmp"
        if hasattr(file, 'content_type'):
            if "image" in file.content_type:
                ext = ".png"
            elif "video" in file.content_type:
                ext = ".mp4"
        filename = f"{uuid.uuid4()}{ext}"

    temp_path = os.path.join(temp_dir, filename)

    try:
        # Check if file is a FileStorage object or already a path
        if hasattr(file, 'save'):
            file.save(temp_path)
        elif isinstance(file, str):
            # It's already a path
            return file
        else:
            # Try to read and save the file
            with open(temp_path, 'wb') as f:
                f.write(file.read() if hasattr(file, 'read') else file)
    except Exception as e:
        print(f"Error saving file: {e}")
        return None

    return temp_path

def create_run_command(args):
    """Create command based on input parameters"""
    cmd = ["python", "demo.py"]

    if "prompt" not in args or args["prompt"] is None or args["prompt"] == "":
        args["prompt"] = ""
    if "checkpoint_path" not in args or args["checkpoint_path"] is None or args["checkpoint_path"] == "":
        args["checkpoint_path"] = DEFAULT_MODEL_PATH

    # Debug output
    print(f"DEBUG: Command args: {args}")

    for key, value in args.items():
        if value is not None:
            # Handle boolean values correctly - for repaint, we need to pass true/false
            if isinstance(value, bool):
                cmd.append(f"--{key}")
                cmd.append(str(value).lower())  # Convert True/False to true/false
            else:
                cmd.append(f"--{key}")
                cmd.append(str(value))

    return cmd

def run_process(cmd):
    """Run command and return output"""
    print(f"Running command: {' '.join(cmd)}")
    process = subprocess.Popen(
        cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        universal_newlines=True
    )

    output = []
    for line in iter(process.stdout.readline, ""):
        print(line, end="")
        output.append(line)
        if not line:
            break

    process.stdout.close()
    return_code = process.wait()

    if return_code:
        stderr = process.stderr.read()
        print(f"Error: {stderr}")
        raise subprocess.CalledProcessError(return_code, cmd, output="\n".join(output), stderr=stderr)

    return "\n".join(output)

# Process functions for each tab
def process_motion_transfer(source, prompt, mt_repaint_option, mt_repaint_image):
    """Process video motion transfer task"""
    try:
        # Save uploaded files
        input_video_path = save_uploaded_file(source)
        if input_video_path is None:
            return None

        print(f"DEBUG: Repaint option: {mt_repaint_option}")
        print(f"DEBUG: Repaint image: {mt_repaint_image}")

        args = {
            "input_path": input_video_path,
            "prompt": f"\"{prompt}\"",
            "checkpoint_path": DEFAULT_MODEL_PATH,
            "output_dir": OUTPUT_DIR,
            "gpu": GPU_ID
        }

        # Priority: Custom Image > Yes > No
        if mt_repaint_image is not None:
            # Custom image takes precedence if provided
            repaint_path = save_uploaded_file(mt_repaint_image)
            print(f"DEBUG: Repaint path: {repaint_path}")
            args["repaint"] = repaint_path
        elif mt_repaint_option == "Yes":
            # Otherwise use Yes/No selection
            args["repaint"] = "true"

        # Create and run command
        cmd = create_run_command(args)
        output = run_process(cmd)

        # Find generated video files
        output_files = glob.glob(os.path.join(OUTPUT_DIR, "*.mp4"))
        if output_files:
            # Sort by modification time, return the latest file
            latest_file = max(output_files, key=os.path.getmtime)
            return latest_file
        else:
            return None
    except Exception as e:
        import traceback
        print(f"Processing failed: {str(e)}\n{traceback.format_exc()}")
        return None

def process_camera_control(source, prompt, camera_motion, tracking_method):
    """Process camera control task"""
    try:
        # Save uploaded files
        input_media_path = save_uploaded_file(source)
        if input_media_path is None:
            return None

        print(f"DEBUG: Camera motion: '{camera_motion}'")
        print(f"DEBUG: Tracking method: '{tracking_method}'")

        args = {
            "input_path": input_media_path,
            "prompt": prompt,
            "checkpoint_path": DEFAULT_MODEL_PATH,
            "output_dir": OUTPUT_DIR,
            "gpu": GPU_ID,
            "tracking_method": tracking_method
        }

        if camera_motion and camera_motion.strip():
            args["camera_motion"] = camera_motion

        # Create and run command
        cmd = create_run_command(args)
        output = run_process(cmd)

        # Find generated video files
        output_files = glob.glob(os.path.join(OUTPUT_DIR, "*.mp4"))
        if output_files:
            # Sort by modification time, return the latest file
            latest_file = max(output_files, key=os.path.getmtime)
            return latest_file
        else:
            return None
    except Exception as e:
        import traceback
        print(f"Processing failed: {str(e)}\n{traceback.format_exc()}")
        return None

def process_object_manipulation(source, prompt, object_motion, object_mask, tracking_method):
    """Process object manipulation task"""
    try:
        # Save uploaded files
        input_image_path = save_uploaded_file(source)
        if input_image_path is None:
            return None

        object_mask_path = save_uploaded_file(object_mask)

        args = {
            "input_path": input_image_path,
            "prompt": prompt,
            "checkpoint_path": DEFAULT_MODEL_PATH,
            "output_dir": OUTPUT_DIR,
            "gpu": GPU_ID,
            "object_motion": object_motion,
            "object_mask": object_mask_path,
            "tracking_method": tracking_method
        }

        # Create and run command
        cmd = create_run_command(args)
        output = run_process(cmd)

        # Find generated video files
        output_files = glob.glob(os.path.join(OUTPUT_DIR, "*.mp4"))
        if output_files:
            # Sort by modification time, return the latest file
            latest_file = max(output_files, key=os.path.getmtime)
            return latest_file
        else:
            return None
    except Exception as e:
        import traceback
        print(f"Processing failed: {str(e)}\n{traceback.format_exc()}")
        return None

def process_mesh_animation(source, prompt, tracking_video, ma_repaint_option, ma_repaint_image):
    """Process mesh animation task"""
    try:
        # Save uploaded files
        input_video_path = save_uploaded_file(source)
        if input_video_path is None:
            return None

        tracking_video_path = save_uploaded_file(tracking_video)
        if tracking_video_path is None:
            return None

        args = {
            "input_path": input_video_path,
            "prompt": prompt,
            "checkpoint_path": DEFAULT_MODEL_PATH,
            "output_dir": OUTPUT_DIR,
            "gpu": GPU_ID,
            "tracking_path": tracking_video_path
        }

        # Priority: Custom Image > Yes > No
        if ma_repaint_image is not None:
            # Custom image takes precedence if provided
            repaint_path = save_uploaded_file(ma_repaint_image)
            args["repaint"] = repaint_path
        elif ma_repaint_option == "Yes":
            # Otherwise use Yes/No selection
            args["repaint"] = "true"

        # Create and run command
        cmd = create_run_command(args)
        output = run_process(cmd)

        # Find generated video files
        output_files = glob.glob(os.path.join(OUTPUT_DIR, "*.mp4"))
        if output_files:
            # Sort by modification time, return the latest file
            latest_file = max(output_files, key=os.path.getmtime)
            return latest_file
        else:
            return None
    except Exception as e:
        import traceback
        print(f"Processing failed: {str(e)}\n{traceback.format_exc()}")
        return None

# Create Gradio interface with updated layout
with gr.Blocks(title="Diffusion as Shader") as demo:
    gr.Markdown("# Diffusion as Shader Web UI")
    gr.Markdown("### [Project Page](https://igl-hkust.github.io/das/) | [GitHub](https://github.com/IGL-HKUST/DiffusionAsShader)")

    with gr.Row():
        left_column = gr.Column(scale=1)
        right_column = gr.Column(scale=1)

    with right_column:
        output_video = gr.Video(label="Generated Video")

    with left_column:
        source = gr.File(label="Source", file_types=["image", "video"])
        common_prompt = gr.Textbox(label="Prompt", lines=2)
        gr.Markdown(f"**Using GPU: {GPU_ID}**")

        with gr.Tabs() as task_tabs:
            # Motion Transfer tab
            with gr.TabItem("Motion Transfer"):
                gr.Markdown("## Motion Transfer")

                # Simplified controls - Radio buttons for Yes/No and separate file upload
                with gr.Row():
                    mt_repaint_option = gr.Radio(
                        label="Repaint First Frame",
                        choices=["No", "Yes"],
                        value="No"
                    )
                gr.Markdown("### Note: If you want to use your own image as repainted first frame, please upload the image in below.")
                # Custom image uploader (always visible)
                mt_repaint_image = gr.File(
                    label="Custom Repaint Image",
                    file_types=["image"]
                )

                # Add run button for Motion Transfer tab
                mt_run_btn = gr.Button("Run Motion Transfer", variant="primary", size="lg")

                # Connect to process function
                mt_run_btn.click(
                    fn=process_motion_transfer,
                    inputs=[
                        source, common_prompt,
                        mt_repaint_option, mt_repaint_image
                    ],
                    outputs=[output_video]
                )

            # Camera Control tab
            with gr.TabItem("Camera Control"):
                gr.Markdown("## Camera Control")

                cc_camera_motion = gr.Textbox(
                    label="Current Camera Motion Sequence",
                    placeholder="Your camera motion sequence will appear here...",
                    interactive=False
                )

                # Use tabs for different motion types
                with gr.Tabs() as cc_motion_tabs:
                    # Translation tab
                    with gr.TabItem("Translation (trans)"):
                        with gr.Row():
                            cc_trans_x = gr.Slider(minimum=-1.0, maximum=1.0, value=0.0, step=0.05, label="X-axis Movement")
                            cc_trans_y = gr.Slider(minimum=-1.0, maximum=1.0, value=0.0, step=0.05, label="Y-axis Movement")
                            cc_trans_z = gr.Slider(minimum=-1.0, maximum=1.0, value=0.0, step=0.05, label="Z-axis Movement (depth)")

                        with gr.Row():
                            cc_trans_start = gr.Number(minimum=0, maximum=48, value=0, step=1, label="Start Frame", precision=0)
                            cc_trans_end = gr.Number(minimum=0, maximum=48, value=48, step=1, label="End Frame", precision=0)

                        cc_trans_note = gr.Markdown("""
                        **Translation Notes:**
                        - Positive X: Move right, Negative X: Move left
                        - Positive Y: Move down, Negative Y: Move up
                        - Positive Z: Zoom in, Negative Z: Zoom out
                        """)

                        # Add translation button in the Translation tab
                        cc_add_trans = gr.Button("Add Camera Translation", variant="secondary")

                        # Function to add translation motion
                        def add_translation_motion(current_motion, trans_x, trans_y, trans_z, trans_start, trans_end):
                            # Format: trans dx dy dz [start_frame end_frame]
                            frame_range = f" {int(trans_start)} {int(trans_end)}" if trans_start != 0 or trans_end != 48 else ""
                            new_motion = f"trans {trans_x:.2f} {trans_y:.2f} {trans_z:.2f}{frame_range}"

                            # Append to existing motion string with semicolon separator if needed
                            if current_motion and current_motion.strip():
                                updated_motion = f"{current_motion}; {new_motion}"
                            else:
                                updated_motion = new_motion

                            return updated_motion

                        # Connect translation button
                        cc_add_trans.click(
                            fn=add_translation_motion,
                            inputs=[
                                cc_camera_motion,
                                cc_trans_x, cc_trans_y, cc_trans_z, cc_trans_start, cc_trans_end
                            ],
                            outputs=[cc_camera_motion]
                        )

                    # Rotation tab
                    with gr.TabItem("Rotation (rot)"):
                        with gr.Row():
                            cc_rot_axis = gr.Dropdown(choices=["x", "y", "z"], value="y", label="Rotation Axis")
                            cc_rot_angle = gr.Slider(minimum=-30, maximum=30, value=5, step=1, label="Rotation Angle (degrees)")

                        with gr.Row():
                            cc_rot_start = gr.Number(minimum=0, maximum=48, value=0, step=1, label="Start Frame", precision=0)
                            cc_rot_end = gr.Number(minimum=0, maximum=48, value=48, step=1, label="End Frame", precision=0)

                        cc_rot_note = gr.Markdown("""
                        **Rotation Notes:**
                        - X-axis rotation: Tilt camera up/down
                        - Y-axis rotation: Pan camera left/right
                        - Z-axis rotation: Roll camera
                        """)

                        # Add rotation button in the Rotation tab
                        cc_add_rot = gr.Button("Add Camera Rotation", variant="secondary")

                        # Function to add rotation motion
                        def add_rotation_motion(current_motion, rot_axis, rot_angle, rot_start, rot_end):
                            # Format: rot axis angle [start_frame end_frame]
                            frame_range = f" {int(rot_start)} {int(rot_end)}" if rot_start != 0 or rot_end != 48 else ""
                            new_motion = f"rot {rot_axis} {rot_angle}{frame_range}"

                            # Append to existing motion string with semicolon separator if needed
                            if current_motion and current_motion.strip():
                                updated_motion = f"{current_motion}; {new_motion}"
                            else:
                                updated_motion = new_motion

                            return updated_motion

                        # Connect rotation button
                        cc_add_rot.click(
                            fn=add_rotation_motion,
                            inputs=[
                                cc_camera_motion,
                                cc_rot_axis, cc_rot_angle, cc_rot_start, cc_rot_end
                            ],
                            outputs=[cc_camera_motion]
                        )

                # Add a clear button to reset the motion sequence
                cc_clear_motion = gr.Button("Clear All Motions", variant="stop")

                def clear_camera_motion():
                    return ""

                cc_clear_motion.click(
                    fn=clear_camera_motion,
                    inputs=[],
                    outputs=[cc_camera_motion]
                )

                cc_tracking_method = gr.Radio(
                    label="Tracking Method",
                    choices=["spatracker", "moge"],
                    value="moge"
                )

                # Add run button for Camera Control tab
                cc_run_btn = gr.Button("Run Camera Control", variant="primary", size="lg")

                # Connect to process function
                cc_run_btn.click(
                    fn=process_camera_control,
                    inputs=[
                        source, common_prompt,
                        cc_camera_motion, cc_tracking_method
                    ],
                    outputs=[output_video]
                )

            # Object Manipulation tab
            with gr.TabItem("Object Manipulation"):
                gr.Markdown("## Object Manipulation")
                om_object_mask = gr.File(
                    label="Object Mask Image",
                    file_types=["image"]
                )
                gr.Markdown("Upload a binary mask image, white areas indicate the object to manipulate")
                om_object_motion = gr.Dropdown(
                    label="Object Motion Type",
                    choices=["up", "down", "left", "right", "front", "back", "rot"],
                    value="up"
                )
                om_tracking_method = gr.Radio(
                    label="Tracking Method",
                    choices=["spatracker", "moge"],
                    value="moge"
                )

                # Add run button for Object Manipulation tab
                om_run_btn = gr.Button("Run Object Manipulation", variant="primary", size="lg")

                # Connect to process function
                om_run_btn.click(
                    fn=process_object_manipulation,
                    inputs=[
                        source, common_prompt,
                        om_object_motion, om_object_mask, om_tracking_method
                    ],
                    outputs=[output_video]
                )

            # Animating meshes to video tab
            with gr.TabItem("Animating meshes to video"):
                gr.Markdown("## Mesh Animation to Video")
                gr.Markdown("""
                Note: Currently only supports tracking videos generated with Blender (version > 4.0).
                Please run the script `scripts/blender.py` in your Blender project to generate tracking videos.
                """)
                ma_tracking_video = gr.File(
                    label="Tracking Video",
                    file_types=["video"]
                )
                gr.Markdown("Tracking video needs to be generated from Blender")

                # Simplified controls - Radio buttons for Yes/No and separate file upload
                with gr.Row():
                    ma_repaint_option = gr.Radio(
                        label="Repaint First Frame",
                        choices=["No", "Yes"],
                        value="No"
                    )
                gr.Markdown("### Note: If you want to use your own image as repainted first frame, please upload the image in below.")
                # Custom image uploader (always visible)
                ma_repaint_image = gr.File(
                    label="Custom Repaint Image",
                    file_types=["image"]
                )

                # Add run button for Mesh Animation tab
                ma_run_btn = gr.Button("Run Mesh Animation", variant="primary", size="lg")

                # Connect to process function
                ma_run_btn.click(
                    fn=process_mesh_animation,
                    inputs=[
                        source, common_prompt,
                        ma_tracking_video, ma_repaint_option, ma_repaint_image
                    ],
                    outputs=[output_video]
                )

# Launch interface
if __name__ == "__main__":
    print(f"Using GPU: {GPU_ID}")
    print(f"Web UI will start on port {args.port}")
    if args.share:
        print("Creating public link for remote access")

    # Launch interface
    demo.launch(share=args.share, server_port=args.port)
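For reference, the camera-motion string that the Camera Control tab assembles (and that process_camera_control forwards to demo.py via --camera_motion) follows the "trans dx dy dz [start end]; rot axis angle [start end]" convention implemented by add_translation_motion and add_rotation_motion above. The short sketch below only mirrors those two helpers to show what the final string looks like; how demo.py parses it is not shown in this commit, so treat it as illustrative.

# Minimal sketch of the motion-string convention used by the Camera Control tab.
# It reproduces the string building of add_translation_motion / add_rotation_motion;
# the demo.py side of the contract is assumed, not shown here.
motion = "trans 0.10 0.00 -0.25"      # "trans dx dy dz" — no frame range means the default 0..48 window
motion = f"{motion}; rot y 5 12 36"   # "rot axis angle start end" — 5 degrees around Y over frames 12..36
print(motion)                          # trans 0.10 0.00 -0.25; rot y 5 12 36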
config/__init__.py
ADDED
File without changes
config/base_cfg.py
ADDED
@@ -0,0 +1,410 @@
#python3.10
"""Hierachical configuration for different pipelines, using `yacs`
(refered to https://github.com/rbgirshick/yacs)

This projects contain the configuration for three aspects:
    the regular config for experiment setting

NOTE: Each experiment will be assigned a seperate working space, and the
intermediate results will be saved in the working space. The experimentes
folder structure is as follows:
{
    /${ROOT_WORK_DIR}/
    └── ${PIPELINES_NAME}/
        ├── ${EXP_NAME}/
            ├── ${CHECKPOINT_DIR}/
            ├── ${RESULT_DIR}/
            ├── meta.json/
            └── ${LOG_DIR}
}

"""

import os, sys
from .yacs import CfgNode as CN
import argparse
import numpy as np

# the parser for boolean
def bool_parser(arg):
    """Parses an argument to boolean."""
    if isinstance(arg, bool):
        return arg
    if arg is None:
        return False
    if arg.lower() in ['1', 'true', 't', 'yes', 'y']:
        return True
    if arg.lower() in ['0', 'false', 'f', 'no', 'n']:
        return False
    raise ValueError(f'`{arg}` cannot be converted to boolean!')

# -----------------------------------------------------------------------------
# base cfg
# -----------------------------------------------------------------------------
cfg = CN()

# configuration for basic experiments
cfg.save_dir = "./checkpoints"
cfg.restore_ckpt = ""
cfg.model_name = "cotracker"
cfg.exp_name = ""

# NOTE: configuration for datasets and augmentation
cfg.dataset_root = ""
cfg.eval_datasets = [""]
cfg.dont_use_augs = False
cfg.crop_size = [384, 512]
cfg.traj_per_sample = 384
cfg.sample_vis_1st_frame = False
cfg.depth_near = 0.01  # meter
cfg.depth_far = 65.0  # meter
cfg.sequence_len = 24

# NOTE: configuration for network arch
cfg.sliding_window_len = 8
cfg.remove_space_attn = False
cfg.updateformer_hidden_size = 384
cfg.updateformer_num_heads = 8
cfg.updateformer_space_depth = 6
cfg.updateformer_time_depth = 6
cfg.model_stride = 4
cfg.train_iters = 4
cfg.if_ARAP = False
cfg.Embed3D = False
cfg.Loss_W_feat = 5e-1
cfg.Loss_W_cls = 1e-4
cfg.depth_color = False
cfg.flash_attn = False
cfg.corr_dp = True
cfg.support_grid = 0
cfg.backbone = "CNN"
cfg.enc_only = False
cfg.init_match = False
cfg.Nblock = 4

# NOTE: configuration for training and saving
cfg.nodes_num = 1
cfg.batch_size = 1
cfg.num_workers = 6
cfg.mixed_precision = False
cfg.lr = 0.0005
cfg.wdecay = 0.00001
cfg.num_steps = 200000
cfg.evaluate_every_n_epoch = 1
cfg.save_every_n_epoch = 1
cfg.validate_at_start = False
cfg.save_freq = 100
cfg.eval_max_seq_len = 1000
cfg.debug = False
cfg.fine_tune = False
cfg.aug_wind_sample = False
cfg.use_video_flip = False
cfg.fix_backbone = False
cfg.tune_backbone = False
cfg.tune_arap = False
cfg.tune_per_scene = False
cfg.use_hier_encoder = False
cfg.scales = [4, 2]


# NOTE: configuration for monocular depth estimator
cfg.mde_name = "zoedepth_nk"

# -----------------------------------------------------------------------------

# configurations for the command line
parser = argparse.ArgumentParser()

# config for the basic experiment
parser.add_argument("--save_dir", default="./checkpoints", type=str, help="path to save checkpoints")
parser.add_argument("--restore_ckpt", default="", help="path to restore a checkpoint")
parser.add_argument("--model_name", default="cotracker", help="model name")
parser.add_argument("--exp_name", type=str, default="base",
    help="the name for experiment",
)
# config for dataset and augmentation
parser.add_argument(
    "--dataset_root", type=str, help="path lo all the datasets (train and eval)"
)
parser.add_argument(
    "--eval_datasets", nargs="+", default=["things", "badja"],
    help="what datasets to use for evaluation",
)
parser.add_argument(
    "--dont_use_augs", action="store_true", default=False,
    help="don't apply augmentations during training",
)
parser.add_argument(
    "--crop_size", type=int, nargs="+", default=[384, 512],
    help="crop videos to this resolution during training",
)
parser.add_argument(
    "--traj_per_sample", type=int, default=768,
    help="the number of trajectories to sample for training",
)
parser.add_argument(
    "--depth_near", type=float, default=0.01, help="near plane depth"
)
parser.add_argument(
    "--depth_far", type=float, default=65.0, help="far plane depth"
)
parser.add_argument(
    "--sample_vis_1st_frame",
    action="store_true",
    default=False,
    help="only sample trajectories with points visible on the first frame",
)
parser.add_argument(
    "--sequence_len", type=int, default=24, help="train sequence length"
)
# configuration for network arch
parser.add_argument(
    "--sliding_window_len",
    type=int,
    default=8,
    help="length of the CoTracker sliding window",
)
parser.add_argument(
    "--remove_space_attn",
    action="store_true",
    default=False,
    help="remove space attention from CoTracker",
)
parser.add_argument(
    "--updateformer_hidden_size",
    type=int,
    default=384,
    help="hidden dimension of the CoTracker transformer model",
)
parser.add_argument(
    "--updateformer_num_heads",
    type=int,
    default=8,
    help="number of heads of the CoTracker transformer model",
)
parser.add_argument(
    "--updateformer_space_depth",
    type=int,
    default=6,
    help="number of group attention layers in the CoTracker transformer model",
)
parser.add_argument(
    "--updateformer_time_depth",
    type=int,
    default=6,
    help="number of time attention layers in the CoTracker transformer model",
)
parser.add_argument(
    "--model_stride",
    type=int,
    default=4,
    help="stride of the CoTracker feature network",
)
parser.add_argument(
    "--train_iters",
    type=int,
    default=4,
    help="number of updates to the disparity field in each forward pass.",
)
parser.add_argument(
    "--if_ARAP",
    action="store_true",
    default=False,
    help="if using ARAP loss in the optimization",
)
parser.add_argument(
    "--Embed3D",
    action="store_true",
    default=False,
    help="if using the 3D embedding for image",
)
parser.add_argument(
    "--Loss_W_feat",
    type=float,
    default=5e-1,
    help="weight for the feature loss",
)
parser.add_argument(
    "--Loss_W_cls",
    type=float,
    default=1e-4,
    help="weight for the classification loss",
)
parser.add_argument(
    "--depth_color",
    action="store_true",
    default=False,
    help="if using the color for depth",
)
parser.add_argument(
    "--flash_attn",
    action="store_true",
    default=False,
    help="if using the flash attention",
)
parser.add_argument(
    "--corr_dp",
    action="store_true",
    default=False,
    help="if using the correlation of depth",
)
parser.add_argument(
    "--support_grid",
    type=int,
    default=0,
    help="if using the support grid",
)
parser.add_argument(
    "--backbone",
    type=str,
    default="CNN",
    help="backbone for the CoTracker feature network",
)
parser.add_argument(
    "--enc_only",
    action="store_true",
    default=False,
    help="if using the encoder only",
)
parser.add_argument(
    "--init_match",
    action="store_true",
    default=False,
    help="if using the initial matching",
)
parser.add_argument(
    "--Nblock",
    type=int,
    default=4,
    help="number of blocks in the CoTracker feature network",
)

# configuration for training and saving
parser.add_argument(
    "--nodes_num", type=int, default=1, help="number of nodes used for training."
)
parser.add_argument(
    "--batch_size", type=int, default=1, help="batch size used during training."
)
parser.add_argument(
    "--num_workers", type=int, default=6, help="number of dataloader workers"
)

parser.add_argument(
    "--mixed_precision",
    action="store_true", default=False,
    help="use mixed precision"
)
parser.add_argument("--lr", type=float, default=0.0005, help="max learning rate.")
parser.add_argument(
    "--wdecay", type=float, default=0.00001, help="Weight decay in optimizer."
)
parser.add_argument(
    "--num_steps", type=int, default=200000, help="length of training schedule."
)
parser.add_argument(
    "--evaluate_every_n_epoch",
    type=int,
    default=1,
    help="evaluate during training after every n epochs, after every epoch by default",
)
parser.add_argument(
    "--save_every_n_epoch",
    type=int,
    default=1,
    help="save checkpoints during training after every n epochs, after every epoch by default",
)
parser.add_argument(
    "--validate_at_start",
    action="store_true",
    default=False,
    help="whether to run evaluation before training starts",
)
parser.add_argument(
    "--save_freq",
    type=int,
    default=100,
    help="frequency of trajectory visualization during training",
)
parser.add_argument(
    "--eval_max_seq_len",
    type=int,
    default=1000,
    help="maximum length of evaluation videos",
)
parser.add_argument(
    "--debug",
    action="store_true",
    default=False,
    help="if using the visibility mask",
)
parser.add_argument(
    "--fine_tune",
    action="store_true",
    default=False,
    help="if fine tune the model",
)
parser.add_argument(
    "--aug_wind_sample",
    action="store_true",
    default=False,
    help="if using the window sampling",
)
parser.add_argument(
    "--use_video_flip",
    action="store_true",
    default=False,
    help="if using the video flip",
)
parser.add_argument(
    "--fix_backbone",
    action="store_true",
    default=False,
    help="if fix the backbone",
)
parser.add_argument(
    "--tune_backbone",
    action="store_true",
    default=False,
    help="if tune the backbone",
)
parser.add_argument(
    "--tune_arap",
    action="store_true",
    default=False,
    help="if fix the backbone",
)
parser.add_argument(
    "--tune_per_scene",
    action="store_true",
    default=False,
    help="if tune one scene",
)
parser.add_argument(
    "--use_hier_encoder",
    action="store_true",
    default=False,
    help="if using the hierarchical encoder",
)
parser.add_argument(
    "--scales",
    type=int,
    nargs="+",
    default=[4, 2],
    help="scales for the CoTracker feature network",
)

# config for monocular depth estimator
parser.add_argument(
    "--mde_name", type=str, default="zoedepth_nk", help="name of the MDE model"
)
args = parser.parse_args()
args_dict = vars(args)

# -----------------------------------------------------------------------------

# merge the `args` to the `cfg`
cfg.merge_from_dict(args_dict)

cfg.ckpt_path = os.path.join(args.save_dir, args.model_name, args.exp_name)
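Because base_cfg.py builds cfg, calls parser.parse_args(), and merges the result at import time, any script that imports it picks up the command-line overrides directly. A minimal usage sketch follows; the train.py entry point is hypothetical and not part of this commit, while merge_from_dict and ckpt_path behave as defined in the file above.

# Hypothetical train.py — illustrates how importing config.base_cfg consumes sys.argv.
# Invoked as:
#   python train.py --model_name cotracker --exp_name base --batch_size 2
from config.base_cfg import cfg  # parse_args() and cfg.merge_from_dict() run on import

print(cfg.batch_size)  # 2, taken from the command line instead of the default 1
print(cfg.ckpt_path)   # ./checkpoints/cotracker/base, i.e. save_dir/model_name/exp_name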
config/ssm_cfg.py
ADDED
@@ -0,0 +1,347 @@
#python3.10
"""Hierachical configuration for different pipelines, using `yacs`
(refered to https://github.com/rbgirshick/yacs)

This projects contain the configuration for three aspects:
    the regular config for experiment setting

NOTE: Each experiment will be assigned a seperate working space, and the
intermediate results will be saved in the working space. The experimentes
folder structure is as follows:
{
    /${ROOT_WORK_DIR}/
    └── ${PIPELINES_NAME}/
        ├── ${EXP_NAME}/
            ├── ${CHECKPOINT_DIR}/
            ├── ${RESULT_DIR}/
            ├── meta.json/
            └── ${LOG_DIR}
}

"""

import os, sys
from .yacs import CfgNode as CN
import argparse
import numpy as np

# the parser for boolean
def bool_parser(arg):
    """Parses an argument to boolean."""
    if isinstance(arg, bool):
        return arg
    if arg is None:
        return False
    if arg.lower() in ['1', 'true', 't', 'yes', 'y']:
        return True
    if arg.lower() in ['0', 'false', 'f', 'no', 'n']:
        return False
    raise ValueError(f'`{arg}` cannot be converted to boolean!')

# -----------------------------------------------------------------------------
# base cfg
# -----------------------------------------------------------------------------
cfg = CN()

# configuration for basic experiments
cfg.save_dir = "./checkpoints"
cfg.restore_ckpt = ""
cfg.model_name = "cotracker"
cfg.exp_name = ""

# NOTE: configuration for datasets and augmentation
cfg.dataset_root = ""
cfg.eval_datasets = [""]
cfg.dont_use_augs = False
cfg.crop_size = [384, 512]
cfg.traj_per_sample = 384
cfg.sample_vis_1st_frame = False
cfg.depth_near = 0.01  # meter
cfg.depth_far = 65.0  # meter
cfg.sequence_len = 24

# NOTE: configuration for network arch
cfg.hidden_size = 384
cfg.mamba_depth = 8
cfg.model_stride = 4
cfg.train_iters = 4
cfg.updateformer_num_heads = 8
cfg.updateformer_hidden_size = 384
cfg.if_ARAP = False
cfg.Embed3D = False
cfg.Loss_W_feat = 5e-1
cfg.Loss_W_cls = 1e-4
cfg.depth_color = False
cfg.flash_attn = False
cfg.corr_dp = True
cfg.support_grid = 0
cfg.backbone = "CNN"
cfg.enc_only = False

# NOTE: configuration for training and saving
cfg.nodes_num = 1
cfg.batch_size = 1
cfg.num_workers = 6
cfg.mixed_precision = False
cfg.lr = 0.0005
cfg.wdecay = 0.00001
cfg.num_steps = 200000
cfg.evaluate_every_n_epoch = 1
cfg.save_every_n_epoch = 1
cfg.validate_at_start = False
cfg.save_freq = 100
cfg.eval_max_seq_len = 1000
cfg.debug = False
cfg.fine_tune = False
cfg.aug_wind_sample = False
cfg.use_video_flip = False
cfg.fix_backbone = False
cfg.tune_backbone = False


# NOTE: configuration for monocular depth estimator
cfg.mde_name = "zoedepth_nk"

# -----------------------------------------------------------------------------

# configurations for the command line
parser = argparse.ArgumentParser()

# config for the basic experiment
parser.add_argument("--save_dir", default="./checkpoints", type=str, help="path to save checkpoints")
parser.add_argument("--restore_ckpt", default="", help="path to restore a checkpoint")
parser.add_argument("--model_name", default="cotracker", help="model name")
parser.add_argument("--exp_name", type=str, default="base",
    help="the name for experiment",
)
# config for dataset and augmentation
parser.add_argument(
    "--dataset_root", type=str, help="path lo all the datasets (train and eval)"
)
parser.add_argument(
    "--eval_datasets", nargs="+", default=["things", "badja"],
    help="what datasets to use for evaluation",
)
parser.add_argument(
    "--dont_use_augs", action="store_true", default=False,
    help="don't apply augmentations during training",
)
parser.add_argument(
    "--crop_size", type=int, nargs="+", default=[384, 512],
    help="crop videos to this resolution during training",
)
parser.add_argument(
    "--traj_per_sample", type=int, default=768,
    help="the number of trajectories to sample for training",
)
parser.add_argument(
    "--depth_near", type=float, default=0.01, help="near plane depth"
)
parser.add_argument(
    "--depth_far", type=float, default=65.0, help="far plane depth"
)
parser.add_argument(
    "--sample_vis_1st_frame",
    action="store_true",
    default=False,
    help="only sample trajectories with points visible on the first frame",
)
parser.add_argument(
    "--sequence_len", type=int, default=24, help="train sequence length"
)
# configuration for network arch
parser.add_argument(
    "--hidden_size",
    type=int,
    default=384,
    help="hidden dimension of the CoTracker transformer model",
)
parser.add_argument(
    "--mamba_depth",
    type=int,
    default=6,
    help="number of group attention layers in the CoTracker transformer model",
)
parser.add_argument(
    "--updateformer_num_heads",
    type=int,
    default=8,
    help="number of heads of the CoTracker transformer model",
)
parser.add_argument(
    "--updateformer_hidden_size",
    type=int,
    default=384,
    help="hidden dimension of the CoTracker transformer model",
)
parser.add_argument(
    "--model_stride",
    type=int,
    default=4,
    help="stride of the CoTracker feature network",
)
parser.add_argument(
    "--train_iters",
    type=int,
    default=4,
    help="number of updates to the disparity field in each forward pass.",
)
parser.add_argument(
    "--if_ARAP",
    action="store_true",
    default=False,
    help="if using ARAP loss in the optimization",
)
parser.add_argument(
    "--Embed3D",
    action="store_true",
    default=False,
    help="if using the 3D embedding for image",
)
parser.add_argument(
    "--Loss_W_feat",
    type=float,
    default=5e-1,
    help="weight for the feature loss",
)
parser.add_argument(
    "--Loss_W_cls",
    type=float,
    default=1e-4,
    help="weight for the classification loss",
)
parser.add_argument(
    "--depth_color",
    action="store_true",
    default=False,
    help="if using the color for depth",
)
parser.add_argument(
    "--flash_attn",
    action="store_true",
    default=False,
    help="if using the flash attention",
)
parser.add_argument(
    "--corr_dp",
    action="store_true",
    default=False,
    help="if using the correlation of depth",
)
parser.add_argument(
    "--support_grid",
    type=int,
    default=0,
    help="if using the support grid",
)
parser.add_argument(
    "--backbone",
    type=str,
    default="CNN",
    help="backbone for the CoTracker feature network",
)
parser.add_argument(
    "--enc_only",
    action="store_true",
    default=False,
    help="if using the encoder only",
)

# configuration for training and saving
parser.add_argument(
    "--nodes_num", type=int, default=1, help="number of nodes used for training."
)
parser.add_argument(
    "--batch_size", type=int, default=1, help="batch size used during training."
)
parser.add_argument(
    "--num_workers", type=int, default=6, help="number of dataloader workers"
)

parser.add_argument(
    "--mixed_precision",
    action="store_true", default=False,
    help="use mixed precision"
)
parser.add_argument("--lr", type=float, default=0.0005, help="max learning rate.")
parser.add_argument(
    "--wdecay", type=float, default=0.00001, help="Weight decay in optimizer."
)
parser.add_argument(
    "--num_steps", type=int, default=200000, help="length of training schedule."
)
parser.add_argument(
    "--evaluate_every_n_epoch",
    type=int,
    default=1,
    help="evaluate during training after every n epochs, after every epoch by default",
)
parser.add_argument(
    "--save_every_n_epoch",
    type=int,
    default=1,
    help="save checkpoints during training after every n epochs, after every epoch by default",
)
parser.add_argument(
    "--validate_at_start",
    action="store_true",
    default=False,
    help="whether to run evaluation before training starts",
|
290 |
+
)
|
291 |
+
parser.add_argument(
|
292 |
+
"--save_freq",
|
293 |
+
type=int,
|
294 |
+
default=100,
|
295 |
+
help="frequency of trajectory visualization during training",
|
296 |
+
)
|
297 |
+
parser.add_argument(
|
298 |
+
"--eval_max_seq_len",
|
299 |
+
type=int,
|
300 |
+
default=1000,
|
301 |
+
help="maximum length of evaluation videos",
|
302 |
+
)
|
303 |
+
parser.add_argument(
|
304 |
+
"--debug",
|
305 |
+
action="store_true",
|
306 |
+
default=False,
|
307 |
+
help="if using the visibility mask",
|
308 |
+
)
|
309 |
+
parser.add_argument(
|
310 |
+
"--fine_tune",
|
311 |
+
action="store_true",
|
312 |
+
default=False,
|
313 |
+
help="if fine tune the model",
|
314 |
+
)
|
315 |
+
parser.add_argument(
|
316 |
+
"--aug_wind_sample",
|
317 |
+
action="store_true",
|
318 |
+
default=False,
|
319 |
+
help="if using the window sampling",
|
320 |
+
)
|
321 |
+
parser.add_argument(
|
322 |
+
"--use_video_flip",
|
323 |
+
action="store_true",
|
324 |
+
default=False,
|
325 |
+
help="if using the video flip",
|
326 |
+
)
|
327 |
+
parser.add_argument(
|
328 |
+
"--fix_backbone",
|
329 |
+
action="store_true",
|
330 |
+
default=False,
|
331 |
+
help="if fix the backbone",
|
332 |
+
)
|
333 |
+
|
334 |
+
# config for monocular depth estimator
|
335 |
+
parser.add_argument(
|
336 |
+
"--mde_name", type=str, default="zoedepth_nk", help="name of the MDE model"
|
337 |
+
)
|
338 |
+
args = parser.parse_args()
|
339 |
+
args_dict = vars(args)
|
340 |
+
|
341 |
+
# -----------------------------------------------------------------------------
|
342 |
+
|
343 |
+
# merge the `args` to the `cfg`
|
344 |
+
cfg.merge_from_dict(args_dict)
|
345 |
+
|
346 |
+
cfg.ckpt_path=os.path.join(args.save_dir, args.model_name ,args.exp_name)
|
347 |
+
|
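Note (illustrative aside, not part of the committed file): the block above defines defaults on `cfg`, parses command-line flags with argparse, and then lets the parsed values clobber the defaults via `merge_from_dict`. A minimal sketch of that flow, assuming the repo root is on `sys.path` so `config.yacs` resolves, and using made-up keys/values:

import argparse
from config.yacs import CfgNode as CN

cfg = CN()
cfg.lr = 0.0005        # default
cfg.batch_size = 1     # default

parser = argparse.ArgumentParser()
parser.add_argument("--lr", type=float, default=0.0005)
parser.add_argument("--batch_size", type=int, default=1)
args = parser.parse_args(["--lr", "0.001"])   # e.g. invoked from the shell as: --lr 0.001

cfg.merge_from_dict(vars(args))   # command-line values override the defaults
print(cfg.lr)                     # -> 0.001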
config/yacs.py
ADDED
@@ -0,0 +1,506 @@
# Copyright (c) 2018-present, Facebook, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##############################################################################

"""YACS -- Yet Another Configuration System is designed to be a simple
configuration management system for academic and industrial research
projects.

See README.md for usage and examples.
"""

import copy
import io
import logging
import os
from ast import literal_eval

import yaml


# Flag for py2 and py3 compatibility to use when separate code paths are necessary
# When _PY2 is False, we assume Python 3 is in use
_PY2 = False

# Filename extensions for loading configs from files
_YAML_EXTS = {"", ".yaml", ".yml"}
_PY_EXTS = {".py"}

# py2 and py3 compatibility for checking file object type
# We simply use this to infer py2 vs py3
try:
    _FILE_TYPES = (file, io.IOBase)
    _PY2 = True
except NameError:
    _FILE_TYPES = (io.IOBase,)

# CfgNodes can only contain a limited set of valid types
_VALID_TYPES = {tuple, list, str, int, float, bool}
# py2 allow for str and unicode
if _PY2:
    _VALID_TYPES = _VALID_TYPES.union({unicode})  # noqa: F821

# Utilities for importing modules from file paths
if _PY2:
    # imp is available in both py2 and py3 for now, but is deprecated in py3
    import imp
else:
    import importlib.util

logger = logging.getLogger(__name__)


class CfgNode(dict):
    """
    CfgNode represents an internal node in the configuration tree. It's a simple
    dict-like container that allows for attribute-based access to keys.
    """

    IMMUTABLE = "__immutable__"
    DEPRECATED_KEYS = "__deprecated_keys__"
    RENAMED_KEYS = "__renamed_keys__"

    def __init__(self, init_dict=None, key_list=None):
        # Recursively convert nested dictionaries in init_dict into CfgNodes
        init_dict = {} if init_dict is None else init_dict
        key_list = [] if key_list is None else key_list
        for k, v in init_dict.items():
            if type(v) is dict:
                # Convert dict to CfgNode
                init_dict[k] = CfgNode(v, key_list=key_list + [k])
            else:
                # Check for valid leaf type or nested CfgNode
                _assert_with_logging(
                    _valid_type(v, allow_cfg_node=True),
                    "Key {} with value {} is not a valid type; valid types: {}".format(
                        ".".join(key_list + [k]), type(v), _VALID_TYPES
                    ),
                )
        super(CfgNode, self).__init__(init_dict)
        # Manage if the CfgNode is frozen or not
        self.__dict__[CfgNode.IMMUTABLE] = False
        # Deprecated options
        # If an option is removed from the code and you don't want to break existing
        # yaml configs, you can add the full config key as a string to the set below.
        self.__dict__[CfgNode.DEPRECATED_KEYS] = set()
        # Renamed options
        # If you rename a config option, record the mapping from the old name to the new
        # name in the dictionary below. Optionally, if the type also changed, you can
        # make the value a tuple that specifies first the renamed key and then
        # instructions for how to edit the config file.
        self.__dict__[CfgNode.RENAMED_KEYS] = {
            # 'EXAMPLE.OLD.KEY': 'EXAMPLE.NEW.KEY',  # Dummy example to follow
            # 'EXAMPLE.OLD.KEY': (                   # A more complex example to follow
            #     'EXAMPLE.NEW.KEY',
            #     "Also convert to a tuple, e.g., 'foo' -> ('foo',) or "
            #     + "'foo:bar' -> ('foo', 'bar')"
            # ),
        }

    def __getattr__(self, name):
        if name in self:
            return self[name]
        else:
            raise AttributeError(name)

    def __setattr__(self, name, value):
        if self.is_frozen():
            raise AttributeError(
                "Attempted to set {} to {}, but CfgNode is immutable".format(
                    name, value
                )
            )

        _assert_with_logging(
            name not in self.__dict__,
            "Invalid attempt to modify internal CfgNode state: {}".format(name),
        )
        _assert_with_logging(
            _valid_type(value, allow_cfg_node=True),
            "Invalid type {} for key {}; valid types = {}".format(
                type(value), name, _VALID_TYPES
            ),
        )

        self[name] = value

    def __str__(self):
        def _indent(s_, num_spaces):
            s = s_.split("\n")
            if len(s) == 1:
                return s_
            first = s.pop(0)
            s = [(num_spaces * " ") + line for line in s]
            s = "\n".join(s)
            s = first + "\n" + s
            return s

        r = ""
        s = []
        for k, v in sorted(self.items()):
            seperator = "\n" if isinstance(v, CfgNode) else " "
            attr_str = "{}:{}{}".format(str(k), seperator, str(v))
            attr_str = _indent(attr_str, 2)
            s.append(attr_str)
        r += "\n".join(s)
        return r

    def __repr__(self):
        return "{}({})".format(self.__class__.__name__, super(CfgNode, self).__repr__())

    def dump(self):
        """Dump to a string."""
        self_as_dict = _to_dict(self)
        return yaml.safe_dump(self_as_dict)

    def merge_from_file(self, cfg_filename):
        """Load a yaml config file and merge it this CfgNode."""
        with open(cfg_filename, "r") as f:
            cfg = load_cfg(f)
        self.merge_from_other_cfg(cfg)

    def merge_from_other_cfg(self, cfg_other):
        """Merge `cfg_other` into this CfgNode."""
        _merge_a_into_b(cfg_other, self, self, [])

    def merge_from_list(self, cfg_list):
        """Merge config (keys, values) in a list (e.g., from command line) into
        this CfgNode. For example, `cfg_list = ['FOO.BAR', 0.5]`.
        """
        _assert_with_logging(
            len(cfg_list) % 2 == 0,
            "Override list has odd length: {}; it must be a list of pairs".format(
                cfg_list
            ),
        )
        root = self
        for full_key, v in zip(cfg_list[0::2], cfg_list[1::2]):
            if root.key_is_deprecated(full_key):
                continue
            if root.key_is_renamed(full_key):
                root.raise_key_rename_error(full_key)
            key_list = full_key.split(".")
            d = self
            for subkey in key_list[:-1]:
                _assert_with_logging(
                    subkey in d, "Non-existent key: {}".format(full_key)
                )
                d = d[subkey]
            subkey = key_list[-1]
            _assert_with_logging(subkey in d, "Non-existent key: {}".format(full_key))
            value = _decode_cfg_value(v)
            value = _check_and_coerce_cfg_value_type(value, d[subkey], subkey, full_key)
            d[subkey] = value

    def merge_from_dict(self, cfg_dict):
        """Merge config (keys, values) in a dict into this CfgNode."""
        cfg_dict = cfg_dict.items()
        cfg_list = []
        for pair in cfg_dict:
            cfg_list.append(pair[0])
            cfg_list.append(pair[1])
        self.merge_from_list(cfg_list)

    def freeze(self):
        """Make this CfgNode and all of its children immutable."""
        self._immutable(True)

    def defrost(self):
        """Make this CfgNode and all of its children mutable."""
        self._immutable(False)

    def is_frozen(self):
        """Return mutability."""
        return self.__dict__[CfgNode.IMMUTABLE]

    def _immutable(self, is_immutable):
        """Set immutability to is_immutable and recursively apply the setting
        to all nested CfgNodes.
        """
        self.__dict__[CfgNode.IMMUTABLE] = is_immutable
        # Recursively set immutable state
        for v in self.__dict__.values():
            if isinstance(v, CfgNode):
                v._immutable(is_immutable)
        for v in self.values():
            if isinstance(v, CfgNode):
                v._immutable(is_immutable)

    def clone(self):
        """Recursively copy this CfgNode."""
        return copy.deepcopy(self)

    def register_deprecated_key(self, key):
        """Register key (e.g. `FOO.BAR`) a deprecated option. When merging deprecated
        keys a warning is generated and the key is ignored.
        """
        _assert_with_logging(
            key not in self.__dict__[CfgNode.DEPRECATED_KEYS],
            "key {} is already registered as a deprecated key".format(key),
        )
        self.__dict__[CfgNode.DEPRECATED_KEYS].add(key)

    def register_renamed_key(self, old_name, new_name, message=None):
        """Register a key as having been renamed from `old_name` to `new_name`.
        When merging a renamed key, an exception is thrown alerting to user to
        the fact that the key has been renamed.
        """
        _assert_with_logging(
            old_name not in self.__dict__[CfgNode.RENAMED_KEYS],
            "key {} is already registered as a renamed cfg key".format(old_name),
        )
        value = new_name
        if message:
            value = (new_name, message)
        self.__dict__[CfgNode.RENAMED_KEYS][old_name] = value

    def key_is_deprecated(self, full_key):
        """Test if a key is deprecated."""
        if full_key in self.__dict__[CfgNode.DEPRECATED_KEYS]:
            logger.warning("Deprecated config key (ignoring): {}".format(full_key))
            return True
        return False

    def key_is_renamed(self, full_key):
        """Test if a key is renamed."""
        return full_key in self.__dict__[CfgNode.RENAMED_KEYS]

    def raise_key_rename_error(self, full_key):
        new_key = self.__dict__[CfgNode.RENAMED_KEYS][full_key]
        if isinstance(new_key, tuple):
            msg = " Note: " + new_key[1]
            new_key = new_key[0]
        else:
            msg = ""
        raise KeyError(
            "Key {} was renamed to {}; please update your config.{}".format(
                full_key, new_key, msg
            )
        )


def load_cfg(cfg_file_obj_or_str):
    """Load a cfg. Supports loading from:
    - A file object backed by a YAML file
    - A file object backed by a Python source file that exports an attribute
      "cfg" that is either a dict or a CfgNode
    - A string that can be parsed as valid YAML
    """
    _assert_with_logging(
        isinstance(cfg_file_obj_or_str, _FILE_TYPES + (str,)),
        "Expected first argument to be of type {} or {}, but it was {}".format(
            _FILE_TYPES, str, type(cfg_file_obj_or_str)
        ),
    )
    if isinstance(cfg_file_obj_or_str, str):
        return _load_cfg_from_yaml_str(cfg_file_obj_or_str)
    elif isinstance(cfg_file_obj_or_str, _FILE_TYPES):
        return _load_cfg_from_file(cfg_file_obj_or_str)
    else:
        raise NotImplementedError("Impossible to reach here (unless there's a bug)")


def _load_cfg_from_file(file_obj):
    """Load a config from a YAML file or a Python source file."""
    _, file_extension = os.path.splitext(file_obj.name)
    if file_extension in _YAML_EXTS:
        return _load_cfg_from_yaml_str(file_obj.read())
    elif file_extension in _PY_EXTS:
        return _load_cfg_py_source(file_obj.name)
    else:
        raise Exception(
            "Attempt to load from an unsupported file type {}; "
            "only {} are supported".format(file_obj, _YAML_EXTS.union(_PY_EXTS))
        )


def _load_cfg_from_yaml_str(str_obj):
    """Load a config from a YAML string encoding."""
    cfg_as_dict = yaml.safe_load(str_obj)
    return CfgNode(cfg_as_dict)


def _load_cfg_py_source(filename):
    """Load a config from a Python source file."""
    module = _load_module_from_file("yacs.config.override", filename)
    _assert_with_logging(
        hasattr(module, "cfg"),
        "Python module from file {} must have 'cfg' attr".format(filename),
    )
    VALID_ATTR_TYPES = {dict, CfgNode}
    _assert_with_logging(
        type(module.cfg) in VALID_ATTR_TYPES,
        "Imported module 'cfg' attr must be in {} but is {} instead".format(
            VALID_ATTR_TYPES, type(module.cfg)
        ),
    )
    if type(module.cfg) is dict:
        return CfgNode(module.cfg)
    else:
        return module.cfg


def _to_dict(cfg_node):
    """Recursively convert all CfgNode objects to dict objects."""

    def convert_to_dict(cfg_node, key_list):
        if not isinstance(cfg_node, CfgNode):
            _assert_with_logging(
                _valid_type(cfg_node),
                "Key {} with value {} is not a valid type; valid types: {}".format(
                    ".".join(key_list), type(cfg_node), _VALID_TYPES
                ),
            )
            return cfg_node
        else:
            cfg_dict = dict(cfg_node)
            for k, v in cfg_dict.items():
                cfg_dict[k] = convert_to_dict(v, key_list + [k])
            return cfg_dict

    return convert_to_dict(cfg_node, [])


def _valid_type(value, allow_cfg_node=False):
    return (type(value) in _VALID_TYPES) or (allow_cfg_node and type(value) == CfgNode)


def _merge_a_into_b(a, b, root, key_list):
    """Merge config dictionary a into config dictionary b, clobbering the
    options in b whenever they are also specified in a.
    """
    _assert_with_logging(
        isinstance(a, CfgNode),
        "`a` (cur type {}) must be an instance of {}".format(type(a), CfgNode),
    )
    _assert_with_logging(
        isinstance(b, CfgNode),
        "`b` (cur type {}) must be an instance of {}".format(type(b), CfgNode),
    )

    for k, v_ in a.items():
        full_key = ".".join(key_list + [k])
        # a must specify keys that are in b
        if k not in b:
            if root.key_is_deprecated(full_key):
                continue
            elif root.key_is_renamed(full_key):
                root.raise_key_rename_error(full_key)
            else:
                v = copy.deepcopy(v_)
                v = _decode_cfg_value(v)
                b.update({k: v})
        else:
            v = copy.deepcopy(v_)
            v = _decode_cfg_value(v)
            v = _check_and_coerce_cfg_value_type(v, b[k], k, full_key)

            # Recursively merge dicts
            if isinstance(v, CfgNode):
                try:
                    _merge_a_into_b(v, b[k], root, key_list + [k])
                except BaseException:
                    raise
            else:
                b[k] = v


def _decode_cfg_value(v):
    """Decodes a raw config value (e.g., from a yaml config files or command
    line argument) into a Python object.
    """
    # Configs parsed from raw yaml will contain dictionary keys that need to be
    # converted to CfgNode objects
    if isinstance(v, dict):
        return CfgNode(v)
    # All remaining processing is only applied to strings
    if not isinstance(v, str):
        return v
    # Try to interpret `v` as a:
    #   string, number, tuple, list, dict, boolean, or None
    try:
        v = literal_eval(v)
    # The following two excepts allow v to pass through when it represents a
    # string.
    #
    # Longer explanation:
    # The type of v is always a string (before calling literal_eval), but
    # sometimes it *represents* a string and other times a data structure, like
    # a list. In the case that v represents a string, what we got back from the
    # yaml parser is 'foo' *without quotes* (so, not '"foo"'). literal_eval is
    # ok with '"foo"', but will raise a ValueError if given 'foo'. In other
    # cases, like paths (v = 'foo/bar' and not v = '"foo/bar"'), literal_eval
    # will raise a SyntaxError.
    except ValueError:
        pass
    except SyntaxError:
        pass
    return v


def _check_and_coerce_cfg_value_type(replacement, original, key, full_key):
    """Checks that `replacement`, which is intended to replace `original` is of
    the right type. The type is correct if it matches exactly or is one of a few
    cases in which the type can be easily coerced.
    """
    original_type = type(original)
    replacement_type = type(replacement)

    # The types must match (with some exceptions)
    if replacement_type == original_type:
        return replacement

    # Cast replacement from from_type to to_type if the replacement and original
    # types match from_type and to_type
    def conditional_cast(from_type, to_type):
        if replacement_type == from_type and original_type == to_type:
            return True, to_type(replacement)
        else:
            return False, None

    # Conditionally casts
    # list <-> tuple
    casts = [(tuple, list), (list, tuple)]
    # For py2: allow converting from str (bytes) to a unicode string
    try:
        casts.append((str, unicode))  # noqa: F821
    except Exception:
        pass

    for (from_type, to_type) in casts:
        converted, converted_value = conditional_cast(from_type, to_type)
        if converted:
            return converted_value

    raise ValueError(
        "Type mismatch ({} vs. {}) with values ({} vs. {}) for config "
        "key: {}".format(
            original_type, replacement_type, original, replacement, full_key
        )
    )


def _assert_with_logging(cond, msg):
    if not cond:
        logger.debug(msg)
    assert cond, msg


def _load_module_from_file(name, filename):
    if _PY2:
        module = imp.load_source(name, filename)
    else:
        spec = importlib.util.spec_from_file_location(name, filename)
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
    return module
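Note (illustrative usage sketch, not part of the repo): the `CfgNode` above behaves like a nested dict with attribute access, string-to-type decoding, and an immutability switch. A minimal example, again assuming `config.yacs` is importable from the repo root:

from config.yacs import CfgNode as CN

cfg = CN()
cfg.model = CN()
cfg.model.hidden_size = 384
cfg.train = CN()
cfg.train.lr = 0.0005

# Command-line style overrides arrive as flat (dotted key, value) pairs.
cfg.merge_from_list(["model.hidden_size", 512, "train.lr", "0.001"])

cfg.freeze()   # make the config immutable for the rest of the run
assert cfg.model.hidden_size == 512
assert abs(cfg.train.lr - 0.001) < 1e-9   # the string "0.001" was decoded and coerced to float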
demo.py
ADDED
@@ -0,0 +1,206 @@
import os
import sys
import argparse

project_root = os.path.dirname(os.path.abspath(__file__))
try:
    sys.path.append(os.path.join(project_root, "submodules/MoGe"))
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
except Exception:
    print("Warning: MoGe not found, motion transfer will not be applied")

import torch
import numpy as np
from PIL import Image
import torchvision.transforms as transforms
from moviepy.editor import VideoFileClip
from diffusers.utils import load_image, load_video

from models.pipelines import DiffusionAsShaderPipeline, FirstFrameRepainter, CameraMotionGenerator, ObjectMotionGenerator
from submodules.MoGe.moge.model import MoGeModel

def load_media(media_path, max_frames=49, transform=None):
    """Load video or image frames and convert to tensor

    Args:
        media_path (str): Path to video or image file
        max_frames (int): Maximum number of frames to load
        transform (callable): Transform to apply to frames

    Returns:
        Tuple[torch.Tensor, float, bool]: Video tensor [T,C,H,W], FPS, and whether the input was a video
    """
    if transform is None:
        transform = transforms.Compose([
            transforms.Resize((480, 720)),
            transforms.ToTensor()
        ])

    # Determine if input is video or image based on extension
    ext = os.path.splitext(media_path)[1].lower()
    is_video = ext in ['.mp4', '.avi', '.mov']

    if is_video:
        frames = load_video(media_path)
        fps = len(frames) / VideoFileClip(media_path).duration
    else:
        # Handle image as single frame
        image = load_image(media_path)
        frames = [image]
        fps = 8  # Default fps for images

    # Ensure we have exactly max_frames
    if len(frames) > max_frames:
        frames = frames[:max_frames]
    elif len(frames) < max_frames:
        last_frame = frames[-1]
        while len(frames) < max_frames:
            frames.append(last_frame.copy())

    # Convert frames to tensor
    video_tensor = torch.stack([transform(frame) for frame in frames])

    return video_tensor, fps, is_video

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--input_path', type=str, default=None, help='Path to input video/image')
    parser.add_argument('--prompt', type=str, required=True, help='Repaint prompt')
    parser.add_argument('--output_dir', type=str, default='outputs', help='Output directory')
    parser.add_argument('--gpu', type=int, default=0, help='GPU device ID')
    parser.add_argument('--checkpoint_path', type=str, default="EXCAI/Diffusion-As-Shader", help='Path to model checkpoint')
    parser.add_argument('--depth_path', type=str, default=None, help='Path to depth image')
    parser.add_argument('--tracking_path', type=str, default=None, help='Path to tracking video, if provided, camera motion and object manipulation will not be applied')
    parser.add_argument('--repaint', type=str, default=None,
                        help='Path to repainted image, or "true" to perform repainting, if not provided use original frame')
    parser.add_argument('--camera_motion', type=str, default=None,
                        help='Camera motion mode: "trans <dx> <dy> <dz>" or "rot <axis> <angle>" or "spiral <radius>"')
    parser.add_argument('--object_motion', type=str, default=None, help='Object motion mode: up/down/left/right')
    parser.add_argument('--object_mask', type=str, default=None, help='Path to object mask image (binary image)')
    parser.add_argument('--tracking_method', type=str, default='spatracker', choices=['spatracker', 'moge'],
                        help='Tracking method to use (spatracker or moge)')
    args = parser.parse_args()

    # Load input video/image
    video_tensor, fps, is_video = load_media(args.input_path)
    if not is_video:
        args.tracking_method = "moge"
        print("Image input detected, using MoGe for tracking video generation.")

    # Initialize pipeline
    das = DiffusionAsShaderPipeline(gpu_id=args.gpu, output_dir=args.output_dir)
    if args.tracking_method == "moge" and args.tracking_path is None:
        moge = MoGeModel.from_pretrained("Ruicheng/moge-vitl").to(das.device)

    # Repaint first frame if requested
    repaint_img_tensor = None
    if args.repaint:
        if args.repaint.lower() == "true":
            repainter = FirstFrameRepainter(gpu_id=args.gpu, output_dir=args.output_dir)
            repaint_img_tensor = repainter.repaint(
                video_tensor[0],
                prompt=args.prompt,
                depth_path=args.depth_path
            )
        else:
            repaint_img_tensor, _, _ = load_media(args.repaint)
            repaint_img_tensor = repaint_img_tensor[0]  # Take first frame

    # Generate tracking if not provided
    tracking_tensor = None
    pred_tracks = None
    cam_motion = CameraMotionGenerator(args.camera_motion)

    if args.tracking_path:
        tracking_tensor, _, _ = load_media(args.tracking_path)

    elif args.tracking_method == "moge":
        # Use the first frame from previously loaded video_tensor
        infer_result = moge.infer(video_tensor[0].to(das.device))  # [C, H, W] in range [0,1]
        H, W = infer_result["points"].shape[0:2]
        pred_tracks = infer_result["points"].unsqueeze(0).repeat(49, 1, 1, 1)  # [T, H, W, 3]
        cam_motion.set_intr(infer_result["intrinsics"])

        # Apply object motion if specified
        if args.object_motion:
            if args.object_mask is None:
                raise ValueError("Object motion specified but no mask provided. Please provide a mask image with --object_mask")

            # Load mask image
            mask_image = Image.open(args.object_mask).convert('L')  # Convert to grayscale
            mask_image = transforms.Resize((480, 720))(mask_image)  # Resize to match video size
            # Convert to binary mask
            mask = torch.from_numpy(np.array(mask_image) > 127)  # Threshold at 127

            motion_generator = ObjectMotionGenerator(device=das.device)

            pred_tracks = motion_generator.apply_motion(
                pred_tracks=pred_tracks,
                mask=mask,
                motion_type=args.object_motion,
                distance=50,
                num_frames=49,
                tracking_method="moge"
            )
            print("Object motion applied")

        # Apply camera motion if specified
        if args.camera_motion:
            poses = cam_motion.get_default_motion()  # shape: [49, 4, 4]
            print("Camera motion applied")
        else:
            # no poses
            poses = torch.eye(4).unsqueeze(0).repeat(49, 1, 1)
        # change pred_tracks into screen coordinate
        pred_tracks_flatten = pred_tracks.reshape(video_tensor.shape[0], H*W, 3)
        pred_tracks = cam_motion.w2s(pred_tracks_flatten, poses).reshape([video_tensor.shape[0], H, W, 3])  # [T, H, W, 3]
        _, tracking_tensor = das.visualize_tracking_moge(
            pred_tracks.cpu().numpy(),
            infer_result["mask"].cpu().numpy()
        )
        print('export tracking video via MoGe.')

    else:
        # Generate tracking points
        pred_tracks, pred_visibility, T_Firsts = das.generate_tracking_spatracker(video_tensor)

        # Apply camera motion if specified
        if args.camera_motion:
            poses = cam_motion.get_default_motion()  # shape: [49, 4, 4]
            pred_tracks = cam_motion.apply_motion_on_pts(pred_tracks, poses)
            print("Camera motion applied")

        # Apply object motion if specified
        if args.object_motion:
            if args.object_mask is None:
                raise ValueError("Object motion specified but no mask provided. Please provide a mask image with --object_mask")

            # Load mask image
            mask_image = Image.open(args.object_mask).convert('L')  # Convert to grayscale
            mask_image = transforms.Resize((480, 720))(mask_image)  # Resize to match video size
            # Convert to binary mask
            mask = torch.from_numpy(np.array(mask_image) > 127)  # Threshold at 127

            motion_generator = ObjectMotionGenerator(device=das.device)

            pred_tracks = motion_generator.apply_motion(
                pred_tracks=pred_tracks.squeeze(),
                mask=mask,
                motion_type=args.object_motion,
                distance=50,
                num_frames=49,
                tracking_method="spatracker"
            ).unsqueeze(0)
            print(f"Object motion '{args.object_motion}' applied using mask from {args.object_mask}")

        # Generate tracking tensor from modified tracks
        _, tracking_tensor = das.visualize_tracking_spatracker(video_tensor, pred_tracks, pred_visibility, T_Firsts)

    das.apply_tracking(
        video_tensor=video_tensor,
        fps=8,
        tracking_tensor=tracking_tensor,
        img_cond_tensor=repaint_img_tensor,
        prompt=args.prompt,
        checkpoint_path=args.checkpoint_path
    )
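Note (illustrative sketch, not part of the commit): demo.py is driven from the command line, and `load_media` can also be called on its own; it resizes frames to 480x720 and pads or truncates to 49 frames. Paths and the example motion string below are placeholders, and running this assumes the project's dependencies are installed:

# Typical invocation (placeholder paths, motion string follows the help text above):
#   python demo.py --input_path assets/example.mp4 --prompt "a red car on a wet street" \
#       --camera_motion "rot y 15" --tracking_method spatracker
from demo import load_media

video_tensor, fps, is_video = load_media("assets/example.mp4")  # placeholder path
print(video_tensor.shape)   # torch.Size([49, 3, 480, 720])
print(fps, is_video)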
models/cogvideox_tracking.py
ADDED
@@ -0,0 +1,1020 @@
1 |
+
from typing import Any, Dict, Optional, Tuple, Union, List, Callable
|
2 |
+
|
3 |
+
import torch, os, math
|
4 |
+
from torch import nn
|
5 |
+
from PIL import Image
|
6 |
+
from tqdm import tqdm
|
7 |
+
|
8 |
+
from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
|
9 |
+
from diffusers.models.modeling_outputs import Transformer2DModelOutput
|
10 |
+
from diffusers.models.transformers.cogvideox_transformer_3d import CogVideoXBlock, CogVideoXTransformer3DModel
|
11 |
+
|
12 |
+
from diffusers.pipelines.cogvideo.pipeline_cogvideox import CogVideoXPipeline, CogVideoXPipelineOutput
|
13 |
+
from diffusers.pipelines.cogvideo.pipeline_cogvideox_image2video import CogVideoXImageToVideoPipeline
|
14 |
+
from diffusers.pipelines.cogvideo.pipeline_cogvideox_video2video import CogVideoXVideoToVideoPipeline
|
15 |
+
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
|
16 |
+
from diffusers.pipelines.cogvideo.pipeline_cogvideox import retrieve_timesteps
|
17 |
+
from transformers import T5EncoderModel, T5Tokenizer
|
18 |
+
from diffusers.models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel
|
19 |
+
from diffusers.schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
|
20 |
+
from diffusers.pipelines import DiffusionPipeline
|
21 |
+
from diffusers.models.modeling_utils import ModelMixin
|
22 |
+
|
23 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
24 |
+
|
25 |
+
class CogVideoXTransformer3DModelTracking(CogVideoXTransformer3DModel, ModelMixin):
|
26 |
+
"""
|
27 |
+
Add tracking maps to the CogVideoX transformer model.
|
28 |
+
|
29 |
+
Parameters:
|
30 |
+
num_tracking_blocks (`int`, defaults to `18`):
|
31 |
+
The number of tracking blocks to use. Must be less than or equal to num_layers.
|
32 |
+
"""
|
33 |
+
|
34 |
+
def __init__(
|
35 |
+
self,
|
36 |
+
num_tracking_blocks: Optional[int] = 18,
|
37 |
+
num_attention_heads: int = 30,
|
38 |
+
attention_head_dim: int = 64,
|
39 |
+
in_channels: int = 16,
|
40 |
+
out_channels: Optional[int] = 16,
|
41 |
+
flip_sin_to_cos: bool = True,
|
42 |
+
freq_shift: int = 0,
|
43 |
+
time_embed_dim: int = 512,
|
44 |
+
text_embed_dim: int = 4096,
|
45 |
+
num_layers: int = 30,
|
46 |
+
dropout: float = 0.0,
|
47 |
+
attention_bias: bool = True,
|
48 |
+
sample_width: int = 90,
|
49 |
+
sample_height: int = 60,
|
50 |
+
sample_frames: int = 49,
|
51 |
+
patch_size: int = 2,
|
52 |
+
temporal_compression_ratio: int = 4,
|
53 |
+
max_text_seq_length: int = 226,
|
54 |
+
activation_fn: str = "gelu-approximate",
|
55 |
+
timestep_activation_fn: str = "silu",
|
56 |
+
norm_elementwise_affine: bool = True,
|
57 |
+
norm_eps: float = 1e-5,
|
58 |
+
spatial_interpolation_scale: float = 1.875,
|
59 |
+
temporal_interpolation_scale: float = 1.0,
|
60 |
+
use_rotary_positional_embeddings: bool = False,
|
61 |
+
use_learned_positional_embeddings: bool = False,
|
62 |
+
**kwargs
|
63 |
+
):
|
64 |
+
super().__init__(
|
65 |
+
num_attention_heads=num_attention_heads,
|
66 |
+
attention_head_dim=attention_head_dim,
|
67 |
+
in_channels=in_channels,
|
68 |
+
out_channels=out_channels,
|
69 |
+
flip_sin_to_cos=flip_sin_to_cos,
|
70 |
+
freq_shift=freq_shift,
|
71 |
+
time_embed_dim=time_embed_dim,
|
72 |
+
text_embed_dim=text_embed_dim,
|
73 |
+
num_layers=num_layers,
|
74 |
+
dropout=dropout,
|
75 |
+
attention_bias=attention_bias,
|
76 |
+
sample_width=sample_width,
|
77 |
+
sample_height=sample_height,
|
78 |
+
sample_frames=sample_frames,
|
79 |
+
patch_size=patch_size,
|
80 |
+
temporal_compression_ratio=temporal_compression_ratio,
|
81 |
+
max_text_seq_length=max_text_seq_length,
|
82 |
+
activation_fn=activation_fn,
|
83 |
+
timestep_activation_fn=timestep_activation_fn,
|
84 |
+
norm_elementwise_affine=norm_elementwise_affine,
|
85 |
+
norm_eps=norm_eps,
|
86 |
+
spatial_interpolation_scale=spatial_interpolation_scale,
|
87 |
+
temporal_interpolation_scale=temporal_interpolation_scale,
|
88 |
+
use_rotary_positional_embeddings=use_rotary_positional_embeddings,
|
89 |
+
use_learned_positional_embeddings=use_learned_positional_embeddings,
|
90 |
+
**kwargs
|
91 |
+
)
|
92 |
+
|
93 |
+
inner_dim = num_attention_heads * attention_head_dim
|
94 |
+
self.num_tracking_blocks = num_tracking_blocks
|
95 |
+
|
96 |
+
# Ensure num_tracking_blocks is not greater than num_layers
|
97 |
+
if num_tracking_blocks > num_layers:
|
98 |
+
raise ValueError("num_tracking_blocks must be less than or equal to num_layers")
|
99 |
+
|
100 |
+
# Create linear layers for combining hidden states and tracking maps
|
101 |
+
self.combine_linears = nn.ModuleList(
|
102 |
+
[nn.Linear(inner_dim, inner_dim) for _ in range(num_tracking_blocks)]
|
103 |
+
)
|
104 |
+
|
105 |
+
# Initialize weights of combine_linears to zero
|
106 |
+
for linear in self.combine_linears:
|
107 |
+
linear.weight.data.zero_()
|
108 |
+
linear.bias.data.zero_()
|
109 |
+
|
110 |
+
# Create transformer blocks for processing tracking maps
|
111 |
+
self.transformer_blocks_copy = nn.ModuleList(
|
112 |
+
[
|
113 |
+
CogVideoXBlock(
|
114 |
+
dim=inner_dim,
|
115 |
+
num_attention_heads=self.config.num_attention_heads,
|
116 |
+
attention_head_dim=self.config.attention_head_dim,
|
117 |
+
time_embed_dim=self.config.time_embed_dim,
|
118 |
+
dropout=self.config.dropout,
|
119 |
+
activation_fn=self.config.activation_fn,
|
120 |
+
attention_bias=self.config.attention_bias,
|
121 |
+
norm_elementwise_affine=self.config.norm_elementwise_affine,
|
122 |
+
norm_eps=self.config.norm_eps,
|
123 |
+
)
|
124 |
+
for _ in range(num_tracking_blocks)
|
125 |
+
]
|
126 |
+
)
|
127 |
+
|
128 |
+
# For initial combination of hidden states and tracking maps
|
129 |
+
self.initial_combine_linear = nn.Linear(inner_dim, inner_dim)
|
130 |
+
self.initial_combine_linear.weight.data.zero_()
|
131 |
+
self.initial_combine_linear.bias.data.zero_()
|
132 |
+
|
133 |
+
# Freeze all parameters
|
134 |
+
for param in self.parameters():
|
135 |
+
param.requires_grad = False
|
136 |
+
|
137 |
+
# Unfreeze parameters that need to be trained
|
138 |
+
for linear in self.combine_linears:
|
139 |
+
for param in linear.parameters():
|
140 |
+
param.requires_grad = True
|
141 |
+
|
142 |
+
for block in self.transformer_blocks_copy:
|
143 |
+
for param in block.parameters():
|
144 |
+
param.requires_grad = True
|
145 |
+
|
146 |
+
for param in self.initial_combine_linear.parameters():
|
147 |
+
param.requires_grad = True
|
148 |
+
|
149 |
+
def forward(
|
150 |
+
self,
|
151 |
+
hidden_states: torch.Tensor,
|
152 |
+
encoder_hidden_states: torch.Tensor,
|
153 |
+
tracking_maps: torch.Tensor,
|
154 |
+
timestep: Union[int, float, torch.LongTensor],
|
155 |
+
timestep_cond: Optional[torch.Tensor] = None,
|
156 |
+
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
157 |
+
attention_kwargs: Optional[Dict[str, Any]] = None,
|
158 |
+
return_dict: bool = True,
|
159 |
+
):
|
160 |
+
if attention_kwargs is not None:
|
161 |
+
attention_kwargs = attention_kwargs.copy()
|
162 |
+
lora_scale = attention_kwargs.pop("scale", 1.0)
|
163 |
+
else:
|
164 |
+
lora_scale = 1.0
|
165 |
+
|
166 |
+
if USE_PEFT_BACKEND:
|
167 |
+
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
168 |
+
scale_lora_layers(self, lora_scale)
|
169 |
+
else:
|
170 |
+
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
|
171 |
+
logger.warning(
|
172 |
+
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
|
173 |
+
)
|
174 |
+
|
175 |
+
batch_size, num_frames, channels, height, width = hidden_states.shape
|
176 |
+
|
177 |
+
# 1. Time embedding
|
178 |
+
timesteps = timestep
|
179 |
+
t_emb = self.time_proj(timesteps)
|
180 |
+
|
181 |
+
# timesteps does not contain any weights and will always return f32 tensors
|
182 |
+
# but time_embedding might actually be running in fp16. so we need to cast here.
|
183 |
+
# there might be better ways to encapsulate this.
|
184 |
+
t_emb = t_emb.to(dtype=hidden_states.dtype)
|
185 |
+
emb = self.time_embedding(t_emb, timestep_cond)
|
186 |
+
|
187 |
+
# 2. Patch embedding
|
188 |
+
hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
|
189 |
+
hidden_states = self.embedding_dropout(hidden_states)
|
190 |
+
|
191 |
+
# Process tracking maps
|
192 |
+
prompt_embed = encoder_hidden_states.clone()
|
193 |
+
tracking_maps_hidden_states = self.patch_embed(prompt_embed, tracking_maps)
|
194 |
+
tracking_maps_hidden_states = self.embedding_dropout(tracking_maps_hidden_states)
|
195 |
+
del prompt_embed
|
196 |
+
|
197 |
+
text_seq_length = encoder_hidden_states.shape[1]
|
198 |
+
encoder_hidden_states = hidden_states[:, :text_seq_length]
|
199 |
+
hidden_states = hidden_states[:, text_seq_length:]
|
200 |
+
tracking_maps = tracking_maps_hidden_states[:, text_seq_length:]
|
201 |
+
|
202 |
+
# Combine hidden states and tracking maps initially
|
203 |
+
combined = hidden_states + tracking_maps
|
204 |
+
tracking_maps = self.initial_combine_linear(combined)
|
205 |
+
|
206 |
+
# Process transformer blocks
|
207 |
+
for i in range(len(self.transformer_blocks)):
|
208 |
+
if self.training and self.gradient_checkpointing:
|
209 |
+
# Gradient checkpointing logic for hidden states
|
210 |
+
def create_custom_forward(module):
|
211 |
+
def custom_forward(*inputs):
|
212 |
+
return module(*inputs)
|
213 |
+
return custom_forward
|
214 |
+
|
215 |
+
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
216 |
+
hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
|
217 |
+
create_custom_forward(self.transformer_blocks[i]),
|
218 |
+
hidden_states,
|
219 |
+
encoder_hidden_states,
|
220 |
+
emb,
|
221 |
+
image_rotary_emb,
|
222 |
+
**ckpt_kwargs,
|
223 |
+
)
|
224 |
+
else:
|
225 |
+
hidden_states, encoder_hidden_states = self.transformer_blocks[i](
|
226 |
+
hidden_states=hidden_states,
|
227 |
+
encoder_hidden_states=encoder_hidden_states,
|
228 |
+
temb=emb,
|
229 |
+
image_rotary_emb=image_rotary_emb,
|
230 |
+
)
|
231 |
+
|
232 |
+
if i < len(self.transformer_blocks_copy):
|
233 |
+
if self.training and self.gradient_checkpointing:
|
234 |
+
# Gradient checkpointing logic for tracking maps
|
235 |
+
tracking_maps, _ = torch.utils.checkpoint.checkpoint(
|
236 |
+
create_custom_forward(self.transformer_blocks_copy[i]),
|
237 |
+
tracking_maps,
|
238 |
+
encoder_hidden_states,
|
239 |
+
emb,
|
240 |
+
image_rotary_emb,
|
241 |
+
**ckpt_kwargs,
|
242 |
+
)
|
243 |
+
else:
|
244 |
+
tracking_maps, _ = self.transformer_blocks_copy[i](
|
245 |
+
hidden_states=tracking_maps,
|
246 |
+
encoder_hidden_states=encoder_hidden_states,
|
247 |
+
temb=emb,
|
248 |
+
image_rotary_emb=image_rotary_emb,
|
249 |
+
)
|
250 |
+
|
251 |
+
# Combine hidden states and tracking maps
|
252 |
+
tracking_maps = self.combine_linears[i](tracking_maps)
|
253 |
+
hidden_states = hidden_states + tracking_maps
|
254 |
+
|
255 |
+
|
256 |
+
if not self.config.use_rotary_positional_embeddings:
|
257 |
+
# CogVideoX-2B
|
258 |
+
hidden_states = self.norm_final(hidden_states)
|
259 |
+
else:
|
260 |
+
# CogVideoX-5B
|
261 |
+
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
262 |
+
hidden_states = self.norm_final(hidden_states)
|
263 |
+
hidden_states = hidden_states[:, text_seq_length:]
|
264 |
+
|
265 |
+
# 4. Final block
|
266 |
+
hidden_states = self.norm_out(hidden_states, temb=emb)
|
267 |
+
hidden_states = self.proj_out(hidden_states)
|
268 |
+
|
269 |
+
# 5. Unpatchify
|
270 |
+
# Note: we use `-1` instead of `channels`:
|
271 |
+
# - It is okay to `channels` use for CogVideoX-2b and CogVideoX-5b (number of input channels is equal to output channels)
|
272 |
+
# - However, for CogVideoX-5b-I2V also takes concatenated input image latents (number of input channels is twice the output channels)
|
273 |
+
p = self.config.patch_size
|
274 |
+
output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p)
|
275 |
+
output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
|
276 |
+
|
277 |
+
if USE_PEFT_BACKEND:
|
278 |
+
# remove `lora_scale` from each PEFT layer
|
279 |
+
unscale_lora_layers(self, lora_scale)
|
280 |
+
|
281 |
+
if not return_dict:
|
282 |
+
return (output,)
|
283 |
+
return Transformer2DModelOutput(sample=output)
|
284 |
+
|
285 |
+
@classmethod
|
286 |
+
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
|
287 |
+
try:
|
288 |
+
model = super().from_pretrained(pretrained_model_name_or_path, **kwargs)
|
289 |
+
print("Loaded DiffusionAsShader checkpoint directly.")
|
290 |
+
|
291 |
+
for param in model.parameters():
|
292 |
+
param.requires_grad = False
|
293 |
+
|
294 |
+
for linear in model.combine_linears:
|
295 |
+
for param in linear.parameters():
|
296 |
+
param.requires_grad = True
|
297 |
+
|
298 |
+
for block in model.transformer_blocks_copy:
|
299 |
+
for param in block.parameters():
|
300 |
+
param.requires_grad = True
|
301 |
+
|
302 |
+
for param in model.initial_combine_linear.parameters():
|
303 |
+
param.requires_grad = True
|
304 |
+
|
305 |
+
return model
|
306 |
+
|
307 |
+
except Exception as e:
|
308 |
+
print(f"Failed to load as DiffusionAsShader: {e}")
|
309 |
+
print("Attempting to load as CogVideoXTransformer3DModel and convert...")
|
310 |
+
|
311 |
+
base_model = CogVideoXTransformer3DModel.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
312 |
+
|
313 |
+
config = dict(base_model.config)
|
314 |
+
config["num_tracking_blocks"] = kwargs.pop("num_tracking_blocks", 18)
|
315 |
+
|
316 |
+
model = cls(**config)
|
317 |
+
model.load_state_dict(base_model.state_dict(), strict=False)
|
318 |
+
|
319 |
+
model.initial_combine_linear.weight.data.zero_()
|
320 |
+
model.initial_combine_linear.bias.data.zero_()
|
321 |
+
|
322 |
+
for linear in model.combine_linears:
|
323 |
+
linear.weight.data.zero_()
|
324 |
+
linear.bias.data.zero_()
|
325 |
+
|
326 |
+
for i in range(model.num_tracking_blocks):
|
327 |
+
model.transformer_blocks_copy[i].load_state_dict(model.transformer_blocks[i].state_dict())
|
328 |
+
|
329 |
+
|
330 |
+
for param in model.parameters():
|
331 |
+
param.requires_grad = False
|
332 |
+
|
333 |
+
for linear in model.combine_linears:
|
334 |
+
for param in linear.parameters():
|
335 |
+
param.requires_grad = True
|
336 |
+
|
337 |
+
for block in model.transformer_blocks_copy:
|
338 |
+
for param in block.parameters():
|
339 |
+
param.requires_grad = True
|
340 |
+
|
341 |
+
for param in model.initial_combine_linear.parameters():
|
342 |
+
param.requires_grad = True
|
343 |
+
|
344 |
+
return model
|
345 |
+
|
346 |
+
def save_pretrained(
|
347 |
+
self,
|
348 |
+
save_directory: Union[str, os.PathLike],
|
349 |
+
is_main_process: bool = True,
|
350 |
+
save_function: Optional[Callable] = None,
|
351 |
+
safe_serialization: bool = True,
|
352 |
+
variant: Optional[str] = None,
|
353 |
+
max_shard_size: Union[int, str] = "5GB",
|
354 |
+
push_to_hub: bool = False,
|
355 |
+
**kwargs,
|
356 |
+
):
|
357 |
+
super().save_pretrained(
|
358 |
+
save_directory,
|
359 |
+
is_main_process=is_main_process,
|
360 |
+
save_function=save_function,
|
361 |
+
safe_serialization=safe_serialization,
|
362 |
+
variant=variant,
|
363 |
+
max_shard_size=max_shard_size,
|
364 |
+
push_to_hub=push_to_hub,
|
365 |
+
**kwargs,
|
366 |
+
)
|
367 |
+
|
368 |
+
if is_main_process:
|
369 |
+
config_dict = dict(self.config)
|
370 |
+
config_dict.pop("_name_or_path", None)
|
371 |
+
config_dict.pop("_use_default_values", None)
|
372 |
+
config_dict["_class_name"] = "CogVideoXTransformer3DModelTracking"
|
373 |
+
config_dict["num_tracking_blocks"] = self.num_tracking_blocks
|
374 |
+
|
375 |
+
os.makedirs(save_directory, exist_ok=True)
|
376 |
+
with open(os.path.join(save_directory, "config.json"), "w", encoding="utf-8") as f:
|
377 |
+
import json
|
378 |
+
json.dump(config_dict, f, indent=2)
|
379 |
+
|
380 |
+
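# A minimal loading sketch for the tracking transformer defined above, assuming a
# diffusers-style checkpoint directory (the path below is hypothetical). When the
# checkpoint is not already a DiffusionAsShader one, `from_pretrained` falls back to
# converting a plain CogVideoXTransformer3DModel: it zero-initialises the combine
# linears and copies the first `num_tracking_blocks` blocks into the tracking branch.
# Either way, only the tracking branch is left trainable.
from models.cogvideox_tracking import CogVideoXTransformer3DModelTracking

transformer = CogVideoXTransformer3DModelTracking.from_pretrained(
    "path/to/diffusion-as-shader-checkpoint",  # hypothetical path
    subfolder="transformer",
)
trainable = sum(p.requires_grad for p in transformer.parameters())
print(f"{trainable} parameter tensors are trainable (tracking branch only)")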
class CogVideoXPipelineTracking(CogVideoXPipeline, DiffusionPipeline):
|
381 |
+
|
382 |
+
def __init__(
|
383 |
+
self,
|
384 |
+
tokenizer: T5Tokenizer,
|
385 |
+
text_encoder: T5EncoderModel,
|
386 |
+
vae: AutoencoderKLCogVideoX,
|
387 |
+
transformer: CogVideoXTransformer3DModelTracking,
|
388 |
+
scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler],
|
389 |
+
):
|
390 |
+
super().__init__(tokenizer, text_encoder, vae, transformer, scheduler)
|
391 |
+
|
392 |
+
if not isinstance(self.transformer, CogVideoXTransformer3DModelTracking):
|
393 |
+
raise ValueError("The transformer in this pipeline must be of type CogVideoXTransformer3DModelTracking")
|
394 |
+
|
395 |
+
@torch.no_grad()
|
396 |
+
def __call__(
|
397 |
+
self,
|
398 |
+
prompt: Optional[Union[str, List[str]]] = None,
|
399 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
400 |
+
height: int = 480,
|
401 |
+
width: int = 720,
|
402 |
+
num_frames: int = 49,
|
403 |
+
num_inference_steps: int = 50,
|
404 |
+
timesteps: Optional[List[int]] = None,
|
405 |
+
guidance_scale: float = 6,
|
406 |
+
use_dynamic_cfg: bool = False,
|
407 |
+
num_videos_per_prompt: int = 1,
|
408 |
+
eta: float = 0.0,
|
409 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
410 |
+
latents: Optional[torch.FloatTensor] = None,
|
411 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
412 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
413 |
+
output_type: str = "pil",
|
414 |
+
return_dict: bool = True,
|
415 |
+
attention_kwargs: Optional[Dict[str, Any]] = None,
|
416 |
+
callback_on_step_end: Optional[
|
417 |
+
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
|
418 |
+
] = None,
|
419 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
420 |
+
max_sequence_length: int = 226,
|
421 |
+
tracking_maps: Optional[torch.Tensor] = None,
|
422 |
+
) -> Union[CogVideoXPipelineOutput, Tuple]:
|
423 |
+
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
424 |
+
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
425 |
+
|
426 |
+
num_videos_per_prompt = 1
|
427 |
+
|
428 |
+
self.check_inputs(
|
429 |
+
prompt,
|
430 |
+
height,
|
431 |
+
width,
|
432 |
+
negative_prompt,
|
433 |
+
callback_on_step_end_tensor_inputs,
|
434 |
+
prompt_embeds,
|
435 |
+
negative_prompt_embeds,
|
436 |
+
)
|
437 |
+
self._guidance_scale = guidance_scale
|
438 |
+
self._attention_kwargs = attention_kwargs
|
439 |
+
self._interrupt = False
|
440 |
+
|
441 |
+
if prompt is not None and isinstance(prompt, str):
|
442 |
+
batch_size = 1
|
443 |
+
elif prompt is not None and isinstance(prompt, list):
|
444 |
+
batch_size = len(prompt)
|
445 |
+
else:
|
446 |
+
batch_size = prompt_embeds.shape[0]
|
447 |
+
|
448 |
+
device = self._execution_device
|
449 |
+
|
450 |
+
do_classifier_free_guidance = guidance_scale > 1.0
|
451 |
+
|
452 |
+
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
453 |
+
prompt,
|
454 |
+
negative_prompt,
|
455 |
+
do_classifier_free_guidance,
|
456 |
+
num_videos_per_prompt=num_videos_per_prompt,
|
457 |
+
prompt_embeds=prompt_embeds,
|
458 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
459 |
+
max_sequence_length=max_sequence_length,
|
460 |
+
device=device,
|
461 |
+
)
|
462 |
+
if do_classifier_free_guidance:
|
463 |
+
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
464 |
+
|
465 |
+
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
|
466 |
+
self._num_timesteps = len(timesteps)
|
467 |
+
|
468 |
+
latent_channels = self.transformer.config.in_channels
|
469 |
+
latents = self.prepare_latents(
|
470 |
+
batch_size * num_videos_per_prompt,
|
471 |
+
latent_channels,
|
472 |
+
num_frames,
|
473 |
+
height,
|
474 |
+
width,
|
475 |
+
prompt_embeds.dtype,
|
476 |
+
device,
|
477 |
+
generator,
|
478 |
+
latents,
|
479 |
+
)
|
480 |
+
|
481 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
482 |
+
|
483 |
+
image_rotary_emb = (
|
484 |
+
self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
|
485 |
+
if self.transformer.config.use_rotary_positional_embeddings
|
486 |
+
else None
|
487 |
+
)
|
488 |
+
|
489 |
+
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
490 |
+
|
491 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
492 |
+
old_pred_original_sample = None
|
493 |
+
for i, t in enumerate(timesteps):
|
494 |
+
if self.interrupt:
|
495 |
+
continue
|
496 |
+
|
497 |
+
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
498 |
+
tracking_maps_latent = torch.cat([tracking_maps] * 2) if do_classifier_free_guidance else tracking_maps
|
499 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
500 |
+
|
501 |
+
timestep = t.expand(latent_model_input.shape[0])
|
502 |
+
|
503 |
+
noise_pred = self.transformer(
|
504 |
+
hidden_states=latent_model_input,
|
505 |
+
encoder_hidden_states=prompt_embeds,
|
506 |
+
timestep=timestep,
|
507 |
+
image_rotary_emb=image_rotary_emb,
|
508 |
+
attention_kwargs=attention_kwargs,
|
509 |
+
tracking_maps=tracking_maps_latent,
|
510 |
+
return_dict=False,
|
511 |
+
)[0]
|
512 |
+
noise_pred = noise_pred.float()
|
513 |
+
|
514 |
+
if use_dynamic_cfg:
|
515 |
+
self._guidance_scale = 1 + guidance_scale * (
|
516 |
+
(1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
|
517 |
+
)
|
518 |
+
if do_classifier_free_guidance:
|
519 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
520 |
+
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
|
521 |
+
|
522 |
+
if not isinstance(self.scheduler, CogVideoXDPMScheduler):
|
523 |
+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
524 |
+
else:
|
525 |
+
latents, old_pred_original_sample = self.scheduler.step(
|
526 |
+
noise_pred,
|
527 |
+
old_pred_original_sample,
|
528 |
+
t,
|
529 |
+
timesteps[i - 1] if i > 0 else None,
|
530 |
+
latents,
|
531 |
+
**extra_step_kwargs,
|
532 |
+
return_dict=False,
|
533 |
+
)
|
534 |
+
latents = latents.to(prompt_embeds.dtype)
|
535 |
+
|
536 |
+
if callback_on_step_end is not None:
|
537 |
+
callback_kwargs = {}
|
538 |
+
for k in callback_on_step_end_tensor_inputs:
|
539 |
+
callback_kwargs[k] = locals()[k]
|
540 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
541 |
+
|
542 |
+
latents = callback_outputs.pop("latents", latents)
|
543 |
+
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
544 |
+
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
545 |
+
|
546 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
547 |
+
progress_bar.update()
|
548 |
+
|
549 |
+
if not output_type == "latent":
|
550 |
+
video = self.decode_latents(latents)
|
551 |
+
video = self.video_processor.postprocess_video(video=video, output_type=output_type)
|
552 |
+
else:
|
553 |
+
video = latents
|
554 |
+
|
555 |
+
self.maybe_free_model_hooks()
|
556 |
+
|
557 |
+
if not return_dict:
|
558 |
+
return (video,)
|
559 |
+
return CogVideoXPipelineOutput(frames=video)
|
560 |
+
|
561 |
+
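# A minimal sketch of driving CogVideoXPipelineTracking above. It assumes `pipe` is an
# instance of the pipeline and `tracking_video` is a [B, T, C, H, W] tensor in [-1, 1]
# (both names are placeholders). `tracking_maps` must already live in VAE latent space,
# shaped like the video latents; the encode/scale/permute below mirrors the `_infer`
# helper in models/pipelines.py.
import torch

with torch.no_grad():
    enc = pipe.vae.encode(tracking_video.permute(0, 2, 1, 3, 4))      # [B, C, T, H, W]
    tracking_latents = enc.latent_dist.sample() * pipe.vae.config.scaling_factor
    tracking_latents = tracking_latents.permute(0, 2, 1, 3, 4)        # [B, F, C, H', W']

video = pipe(
    prompt="a hypothetical prompt",
    tracking_maps=tracking_latents,
    num_frames=49,
    guidance_scale=6.0,
).frames[0]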
class CogVideoXImageToVideoPipelineTracking(CogVideoXImageToVideoPipeline, DiffusionPipeline):
|
562 |
+
|
563 |
+
def __init__(
|
564 |
+
self,
|
565 |
+
tokenizer: T5Tokenizer,
|
566 |
+
text_encoder: T5EncoderModel,
|
567 |
+
vae: AutoencoderKLCogVideoX,
|
568 |
+
transformer: CogVideoXTransformer3DModelTracking,
|
569 |
+
scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler],
|
570 |
+
):
|
571 |
+
super().__init__(tokenizer, text_encoder, vae, transformer, scheduler)
|
572 |
+
|
573 |
+
if not isinstance(self.transformer, CogVideoXTransformer3DModelTracking):
|
574 |
+
raise ValueError("The transformer in this pipeline must be of type CogVideoXTransformer3DModelTracking")
|
575 |
+
|
576 |
+
# Print the number of transformer blocks
|
577 |
+
print(f"Number of transformer blocks: {len(self.transformer.transformer_blocks)}")
|
578 |
+
print(f"Number of tracking transformer blocks: {len(self.transformer.transformer_blocks_copy)}")
|
579 |
+
self.transformer = torch.compile(self.transformer)
|
580 |
+
|
581 |
+
@torch.no_grad()
|
582 |
+
def __call__(
|
583 |
+
self,
|
584 |
+
image: Union[torch.Tensor, Image.Image],
|
585 |
+
prompt: Optional[Union[str, List[str]]] = None,
|
586 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
587 |
+
height: Optional[int] = None,
|
588 |
+
width: Optional[int] = None,
|
589 |
+
num_frames: int = 49,
|
590 |
+
num_inference_steps: int = 50,
|
591 |
+
timesteps: Optional[List[int]] = None,
|
592 |
+
guidance_scale: float = 6,
|
593 |
+
use_dynamic_cfg: bool = False,
|
594 |
+
num_videos_per_prompt: int = 1,
|
595 |
+
eta: float = 0.0,
|
596 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
597 |
+
latents: Optional[torch.FloatTensor] = None,
|
598 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
599 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
600 |
+
output_type: str = "pil",
|
601 |
+
return_dict: bool = True,
|
602 |
+
attention_kwargs: Optional[Dict[str, Any]] = None,
|
603 |
+
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
604 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
605 |
+
max_sequence_length: int = 226,
|
606 |
+
tracking_maps: Optional[torch.Tensor] = None,
|
607 |
+
tracking_image: Optional[torch.Tensor] = None,
|
608 |
+
) -> Union[CogVideoXPipelineOutput, Tuple]:
|
609 |
+
# Most of the implementation remains the same as the parent class
|
610 |
+
# We will modify the parts that need to handle tracking_maps
|
611 |
+
|
612 |
+
# 1. Check inputs and set default values
|
613 |
+
self.check_inputs(
|
614 |
+
image,
|
615 |
+
prompt,
|
616 |
+
height,
|
617 |
+
width,
|
618 |
+
negative_prompt,
|
619 |
+
callback_on_step_end_tensor_inputs,
|
620 |
+
prompt_embeds,
|
621 |
+
negative_prompt_embeds,
|
622 |
+
)
|
623 |
+
self._guidance_scale = guidance_scale
|
624 |
+
self._attention_kwargs = attention_kwargs
|
625 |
+
self._interrupt = False
|
626 |
+
|
627 |
+
if prompt is not None and isinstance(prompt, str):
|
628 |
+
batch_size = 1
|
629 |
+
elif prompt is not None and isinstance(prompt, list):
|
630 |
+
batch_size = len(prompt)
|
631 |
+
else:
|
632 |
+
batch_size = prompt_embeds.shape[0]
|
633 |
+
|
634 |
+
device = self._execution_device
|
635 |
+
|
636 |
+
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
637 |
+
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
638 |
+
# corresponds to doing no classifier free guidance.
|
639 |
+
do_classifier_free_guidance = guidance_scale > 1.0
|
640 |
+
|
641 |
+
# 3. Encode input prompt
|
642 |
+
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
643 |
+
prompt=prompt,
|
644 |
+
negative_prompt=negative_prompt,
|
645 |
+
do_classifier_free_guidance=do_classifier_free_guidance,
|
646 |
+
num_videos_per_prompt=num_videos_per_prompt,
|
647 |
+
prompt_embeds=prompt_embeds,
|
648 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
649 |
+
max_sequence_length=max_sequence_length,
|
650 |
+
device=device,
|
651 |
+
)
|
652 |
+
if do_classifier_free_guidance:
|
653 |
+
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
654 |
+
del negative_prompt_embeds
|
655 |
+
|
656 |
+
# 4. Prepare timesteps
|
657 |
+
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
|
658 |
+
self._num_timesteps = len(timesteps)
|
659 |
+
|
660 |
+
# 5. Prepare latents
|
661 |
+
image = self.video_processor.preprocess(image, height=height, width=width).to(
|
662 |
+
device, dtype=prompt_embeds.dtype
|
663 |
+
)
|
664 |
+
|
665 |
+
tracking_image = self.video_processor.preprocess(tracking_image, height=height, width=width).to(
|
666 |
+
device, dtype=prompt_embeds.dtype
|
667 |
+
)
|
668 |
+
if self.transformer.config.in_channels != 16:
|
669 |
+
latent_channels = self.transformer.config.in_channels // 2
|
670 |
+
else:
|
671 |
+
latent_channels = self.transformer.config.in_channels
|
672 |
+
latents, image_latents = self.prepare_latents(
|
673 |
+
image,
|
674 |
+
batch_size * num_videos_per_prompt,
|
675 |
+
latent_channels,
|
676 |
+
num_frames,
|
677 |
+
height,
|
678 |
+
width,
|
679 |
+
prompt_embeds.dtype,
|
680 |
+
device,
|
681 |
+
generator,
|
682 |
+
latents,
|
683 |
+
)
|
684 |
+
del image
|
685 |
+
|
686 |
+
_, tracking_image_latents = self.prepare_latents(
|
687 |
+
tracking_image,
|
688 |
+
batch_size * num_videos_per_prompt,
|
689 |
+
latent_channels,
|
690 |
+
num_frames,
|
691 |
+
height,
|
692 |
+
width,
|
693 |
+
prompt_embeds.dtype,
|
694 |
+
device,
|
695 |
+
generator,
|
696 |
+
latents=None,
|
697 |
+
)
|
698 |
+
del tracking_image
|
699 |
+
|
700 |
+
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
701 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
702 |
+
|
703 |
+
# 7. Create rotary embeds if required
|
704 |
+
image_rotary_emb = (
|
705 |
+
self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
|
706 |
+
if self.transformer.config.use_rotary_positional_embeddings
|
707 |
+
else None
|
708 |
+
)
|
709 |
+
|
710 |
+
# 8. Denoising loop
|
711 |
+
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
712 |
+
|
713 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
714 |
+
old_pred_original_sample = None
|
715 |
+
for i, t in enumerate(timesteps):
|
716 |
+
if self.interrupt:
|
717 |
+
continue
|
718 |
+
|
719 |
+
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
720 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
721 |
+
|
722 |
+
latent_image_input = torch.cat([image_latents] * 2) if do_classifier_free_guidance else image_latents
|
723 |
+
latent_model_input = torch.cat([latent_model_input, latent_image_input], dim=2)
|
724 |
+
del latent_image_input
|
725 |
+
|
726 |
+
# Handle tracking maps
|
727 |
+
if tracking_maps is not None:
|
728 |
+
latents_tracking_image = torch.cat([tracking_image_latents] * 2) if do_classifier_free_guidance else tracking_image_latents
|
729 |
+
tracking_maps_input = torch.cat([tracking_maps] * 2) if do_classifier_free_guidance else tracking_maps
|
730 |
+
tracking_maps_input = torch.cat([tracking_maps_input, latents_tracking_image], dim=2)
|
731 |
+
del latents_tracking_image
|
732 |
+
else:
|
733 |
+
tracking_maps_input = None
|
734 |
+
|
735 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
736 |
+
timestep = t.expand(latent_model_input.shape[0])
|
737 |
+
|
738 |
+
# Predict noise
|
739 |
+
self.transformer.to(dtype=latent_model_input.dtype)
|
740 |
+
noise_pred = self.transformer(
|
741 |
+
hidden_states=latent_model_input,
|
742 |
+
encoder_hidden_states=prompt_embeds,
|
743 |
+
timestep=timestep,
|
744 |
+
image_rotary_emb=image_rotary_emb,
|
745 |
+
attention_kwargs=attention_kwargs,
|
746 |
+
tracking_maps=tracking_maps_input,
|
747 |
+
return_dict=False,
|
748 |
+
)[0]
|
749 |
+
del latent_model_input
|
750 |
+
if tracking_maps_input is not None:
|
751 |
+
del tracking_maps_input
|
752 |
+
noise_pred = noise_pred.float()
|
753 |
+
|
754 |
+
# perform guidance
|
755 |
+
if use_dynamic_cfg:
|
756 |
+
self._guidance_scale = 1 + guidance_scale * (
|
757 |
+
(1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
|
758 |
+
)
|
759 |
+
if do_classifier_free_guidance:
|
760 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
761 |
+
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
|
762 |
+
del noise_pred_uncond, noise_pred_text
|
763 |
+
|
764 |
+
# compute the previous noisy sample x_t -> x_t-1
|
765 |
+
if not isinstance(self.scheduler, CogVideoXDPMScheduler):
|
766 |
+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
767 |
+
else:
|
768 |
+
latents, old_pred_original_sample = self.scheduler.step(
|
769 |
+
noise_pred,
|
770 |
+
old_pred_original_sample,
|
771 |
+
t,
|
772 |
+
timesteps[i - 1] if i > 0 else None,
|
773 |
+
latents,
|
774 |
+
**extra_step_kwargs,
|
775 |
+
return_dict=False,
|
776 |
+
)
|
777 |
+
del noise_pred
|
778 |
+
latents = latents.to(prompt_embeds.dtype)
|
779 |
+
|
780 |
+
# call the callback, if provided
|
781 |
+
if callback_on_step_end is not None:
|
782 |
+
callback_kwargs = {}
|
783 |
+
for k in callback_on_step_end_tensor_inputs:
|
784 |
+
callback_kwargs[k] = locals()[k]
|
785 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
786 |
+
|
787 |
+
latents = callback_outputs.pop("latents", latents)
|
788 |
+
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
789 |
+
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
790 |
+
|
791 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
792 |
+
progress_bar.update()
|
793 |
+
|
794 |
+
# 9. Post-processing
|
795 |
+
if not output_type == "latent":
|
796 |
+
video = self.decode_latents(latents)
|
797 |
+
video = self.video_processor.postprocess_video(video=video, output_type=output_type)
|
798 |
+
else:
|
799 |
+
video = latents
|
800 |
+
|
801 |
+
# Offload all models
|
802 |
+
self.maybe_free_model_hooks()
|
803 |
+
|
804 |
+
if not return_dict:
|
805 |
+
return (video,)
|
806 |
+
|
807 |
+
return CogVideoXPipelineOutput(frames=video)
|
808 |
+
|
809 |
+
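# Channel bookkeeping in the image-to-video variant above, as a small sketch: the
# transformer's `in_channels` is twice the VAE latent channel count because the
# first-frame latents (padded over the remaining frames by `prepare_latents`) are
# concatenated to the noisy latents along the channel axis (dim=2 of [B, F, C, H, W]);
# the tracking branch gets the same treatment with the tracking-image latents.
# The sizes below are illustrative only.
import torch

B, F, C, H, W = 1, 13, 16, 60, 90
latents = torch.randn(B, F, C, H, W)
image_latents = torch.randn(B, F, C, H, W)               # first-frame latents, padded over F
model_input = torch.cat([latents, image_latents], dim=2)
assert model_input.shape[2] == 2 * C                      # matches in_channels == 32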
class CogVideoXVideoToVideoPipelineTracking(CogVideoXVideoToVideoPipeline, DiffusionPipeline):
|
810 |
+
|
811 |
+
def __init__(
|
812 |
+
self,
|
813 |
+
tokenizer: T5Tokenizer,
|
814 |
+
text_encoder: T5EncoderModel,
|
815 |
+
vae: AutoencoderKLCogVideoX,
|
816 |
+
transformer: CogVideoXTransformer3DModelTracking,
|
817 |
+
scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler],
|
818 |
+
):
|
819 |
+
super().__init__(tokenizer, text_encoder, vae, transformer, scheduler)
|
820 |
+
|
821 |
+
if not isinstance(self.transformer, CogVideoXTransformer3DModelTracking):
|
822 |
+
raise ValueError("The transformer in this pipeline must be of type CogVideoXTransformer3DModelTracking")
|
823 |
+
|
824 |
+
@torch.no_grad()
|
825 |
+
def __call__(
|
826 |
+
self,
|
827 |
+
video: List[Image.Image] = None,
|
828 |
+
prompt: Optional[Union[str, List[str]]] = None,
|
829 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
830 |
+
height: int = 480,
|
831 |
+
width: int = 720,
|
832 |
+
num_inference_steps: int = 50,
|
833 |
+
timesteps: Optional[List[int]] = None,
|
834 |
+
strength: float = 0.8,
|
835 |
+
guidance_scale: float = 6,
|
836 |
+
use_dynamic_cfg: bool = False,
|
837 |
+
num_videos_per_prompt: int = 1,
|
838 |
+
eta: float = 0.0,
|
839 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
840 |
+
latents: Optional[torch.FloatTensor] = None,
|
841 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
842 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
843 |
+
output_type: str = "pil",
|
844 |
+
return_dict: bool = True,
|
845 |
+
attention_kwargs: Optional[Dict[str, Any]] = None,
|
846 |
+
callback_on_step_end: Optional[
|
847 |
+
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
|
848 |
+
] = None,
|
849 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
850 |
+
max_sequence_length: int = 226,
|
851 |
+
tracking_maps: Optional[torch.Tensor] = None,
|
852 |
+
) -> Union[CogVideoXPipelineOutput, Tuple]:
|
853 |
+
|
854 |
+
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
855 |
+
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
856 |
+
|
857 |
+
num_videos_per_prompt = 1
|
858 |
+
|
859 |
+
# 1. Check inputs. Raise error if not correct
|
860 |
+
self.check_inputs(
|
861 |
+
prompt=prompt,
|
862 |
+
height=height,
|
863 |
+
width=width,
|
864 |
+
strength=strength,
|
865 |
+
negative_prompt=negative_prompt,
|
866 |
+
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
867 |
+
video=video,
|
868 |
+
latents=latents,
|
869 |
+
prompt_embeds=prompt_embeds,
|
870 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
871 |
+
)
|
872 |
+
self._guidance_scale = guidance_scale
|
873 |
+
self._attention_kwargs = attention_kwargs
|
874 |
+
self._interrupt = False
|
875 |
+
|
876 |
+
# 2. Default call parameters
|
877 |
+
if prompt is not None and isinstance(prompt, str):
|
878 |
+
batch_size = 1
|
879 |
+
elif prompt is not None and isinstance(prompt, list):
|
880 |
+
batch_size = len(prompt)
|
881 |
+
else:
|
882 |
+
batch_size = prompt_embeds.shape[0]
|
883 |
+
|
884 |
+
device = self._execution_device
|
885 |
+
|
886 |
+
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
887 |
+
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
888 |
+
# corresponds to doing no classifier free guidance.
|
889 |
+
do_classifier_free_guidance = guidance_scale > 1.0
|
890 |
+
|
891 |
+
# 3. Encode input prompt
|
892 |
+
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
893 |
+
prompt,
|
894 |
+
negative_prompt,
|
895 |
+
do_classifier_free_guidance,
|
896 |
+
num_videos_per_prompt=num_videos_per_prompt,
|
897 |
+
prompt_embeds=prompt_embeds,
|
898 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
899 |
+
max_sequence_length=max_sequence_length,
|
900 |
+
device=device,
|
901 |
+
)
|
902 |
+
if do_classifier_free_guidance:
|
903 |
+
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
904 |
+
|
905 |
+
# 4. Prepare timesteps
|
906 |
+
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
|
907 |
+
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, timesteps, strength, device)
|
908 |
+
latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt)
|
909 |
+
self._num_timesteps = len(timesteps)
|
910 |
+
|
911 |
+
# 5. Prepare latents
|
912 |
+
if latents is None:
|
913 |
+
video = self.video_processor.preprocess_video(video, height=height, width=width)
|
914 |
+
video = video.to(device=device, dtype=prompt_embeds.dtype)
|
915 |
+
|
916 |
+
latent_channels = self.transformer.config.in_channels
|
917 |
+
latents = self.prepare_latents(
|
918 |
+
video,
|
919 |
+
batch_size * num_videos_per_prompt,
|
920 |
+
latent_channels,
|
921 |
+
height,
|
922 |
+
width,
|
923 |
+
prompt_embeds.dtype,
|
924 |
+
device,
|
925 |
+
generator,
|
926 |
+
latents,
|
927 |
+
latent_timestep,
|
928 |
+
)
|
929 |
+
|
930 |
+
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
931 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
932 |
+
|
933 |
+
# 7. Create rotary embeds if required
|
934 |
+
image_rotary_emb = (
|
935 |
+
self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
|
936 |
+
if self.transformer.config.use_rotary_positional_embeddings
|
937 |
+
else None
|
938 |
+
)
|
939 |
+
|
940 |
+
# 8. Denoising loop
|
941 |
+
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
942 |
+
|
943 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
944 |
+
# for DPM-solver++
|
945 |
+
old_pred_original_sample = None
|
946 |
+
for i, t in enumerate(timesteps):
|
947 |
+
if self.interrupt:
|
948 |
+
continue
|
949 |
+
|
950 |
+
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
951 |
+
tracking_maps_input = torch.cat([tracking_maps] * 2) if do_classifier_free_guidance else tracking_maps
|
952 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
953 |
+
|
954 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
955 |
+
timestep = t.expand(latent_model_input.shape[0])
|
956 |
+
|
957 |
+
# predict noise model_output
|
958 |
+
noise_pred = self.transformer(
|
959 |
+
hidden_states=latent_model_input,
|
960 |
+
encoder_hidden_states=prompt_embeds,
|
961 |
+
timestep=timestep,
|
962 |
+
image_rotary_emb=image_rotary_emb,
|
963 |
+
attention_kwargs=attention_kwargs,
|
964 |
+
tracking_maps=tracking_maps_input,
|
965 |
+
return_dict=False,
|
966 |
+
)[0]
|
967 |
+
noise_pred = noise_pred.float()
|
968 |
+
|
969 |
+
# perform guidance
|
970 |
+
if use_dynamic_cfg:
|
971 |
+
self._guidance_scale = 1 + guidance_scale * (
|
972 |
+
(1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
|
973 |
+
)
|
974 |
+
if do_classifier_free_guidance:
|
975 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
976 |
+
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
|
977 |
+
|
978 |
+
# compute the previous noisy sample x_t -> x_t-1
|
979 |
+
if not isinstance(self.scheduler, CogVideoXDPMScheduler):
|
980 |
+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
981 |
+
else:
|
982 |
+
latents, old_pred_original_sample = self.scheduler.step(
|
983 |
+
noise_pred,
|
984 |
+
old_pred_original_sample,
|
985 |
+
t,
|
986 |
+
timesteps[i - 1] if i > 0 else None,
|
987 |
+
latents,
|
988 |
+
**extra_step_kwargs,
|
989 |
+
return_dict=False,
|
990 |
+
)
|
991 |
+
latents = latents.to(prompt_embeds.dtype)
|
992 |
+
|
993 |
+
# call the callback, if provided
|
994 |
+
if callback_on_step_end is not None:
|
995 |
+
callback_kwargs = {}
|
996 |
+
for k in callback_on_step_end_tensor_inputs:
|
997 |
+
callback_kwargs[k] = locals()[k]
|
998 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
999 |
+
|
1000 |
+
latents = callback_outputs.pop("latents", latents)
|
1001 |
+
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
1002 |
+
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
1003 |
+
|
1004 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
1005 |
+
progress_bar.update()
|
1006 |
+
|
1007 |
+
if not output_type == "latent":
|
1008 |
+
video = self.decode_latents(latents)
|
1009 |
+
video = self.video_processor.postprocess_video(video=video, output_type=output_type)
|
1010 |
+
else:
|
1011 |
+
video = latents
|
1012 |
+
|
1013 |
+
# Offload all models
|
1014 |
+
self.maybe_free_model_hooks()
|
1015 |
+
|
1016 |
+
if not return_dict:
|
1017 |
+
return (video,)
|
1018 |
+
|
1019 |
+
return CogVideoXPipelineOutput(frames=video)
|
1020 |
+
|
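# All three tracking pipelines above share the same dynamic classifier-free-guidance
# schedule; a standalone restatement of that expression (the function name is ours):
import math

def dynamic_cfg_scale(guidance_scale: float, t: float, num_inference_steps: int) -> float:
    """1 + w * (1 - cos(pi * ((N - t) / N) ** 5)) / 2, as used in the denoising loops."""
    return 1 + guidance_scale * (
        (1 - math.cos(math.pi * ((num_inference_steps - t) / num_inference_steps) ** 5.0)) / 2
    )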
models/pipelines.py
ADDED
@@ -0,0 +1,1040 @@
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
import math
|
4 |
+
from tqdm import tqdm
|
5 |
+
from PIL import Image, ImageDraw
|
6 |
+
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
7 |
+
try:
|
8 |
+
sys.path.append(os.path.join(project_root, "submodules/MoGe"))
|
9 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
10 |
+
except Exception:
|
11 |
+
print("Warning: MoGe not found, motion transfer will not be applied")
|
12 |
+
|
13 |
+
import torch
|
14 |
+
import numpy as np
|
15 |
+
from PIL import Image
|
16 |
+
import torchvision.transforms as transforms
|
17 |
+
from diffusers import FluxControlPipeline, CogVideoXDPMScheduler
|
18 |
+
from diffusers.utils import export_to_video, load_image, load_video
|
19 |
+
|
20 |
+
from models.spatracker.predictor import SpaTrackerPredictor
|
21 |
+
from models.spatracker.utils.visualizer import Visualizer
|
22 |
+
from models.cogvideox_tracking import CogVideoXImageToVideoPipelineTracking
|
23 |
+
|
24 |
+
from submodules.MoGe.moge.model import MoGeModel
|
25 |
+
from image_gen_aux import DepthPreprocessor
|
26 |
+
from moviepy.editor import ImageSequenceClip
|
27 |
+
|
28 |
+
class DiffusionAsShaderPipeline:
|
29 |
+
def __init__(self, gpu_id=0, output_dir='outputs'):
|
30 |
+
"""Initialize MotionTransfer class
|
31 |
+
|
32 |
+
Args:
|
33 |
+
gpu_id (int): GPU device ID
|
34 |
+
output_dir (str): Output directory path
|
35 |
+
"""
|
36 |
+
# video parameters
|
37 |
+
self.max_depth = 65.0
|
38 |
+
self.fps = 8
|
39 |
+
|
40 |
+
# camera parameters
|
41 |
+
self.camera_motion=None
|
42 |
+
self.fov=55
|
43 |
+
|
44 |
+
# device
|
45 |
+
self.device = f"cuda:{gpu_id}"
|
46 |
+
torch.cuda.set_device(gpu_id)
|
47 |
+
|
48 |
+
# files
|
49 |
+
self.output_dir = output_dir
|
50 |
+
os.makedirs(output_dir, exist_ok=True)
|
51 |
+
|
52 |
+
# Initialize transform
|
53 |
+
self.transform = transforms.Compose([
|
54 |
+
transforms.Resize((480, 720)),
|
55 |
+
transforms.ToTensor()
|
56 |
+
])
|
57 |
+
|
58 |
+
@torch.no_grad()
|
59 |
+
def _infer(
|
60 |
+
self,
|
61 |
+
prompt: str,
|
62 |
+
model_path: str,
|
63 |
+
tracking_tensor: torch.Tensor = None,
|
64 |
+
image_tensor: torch.Tensor = None, # [C,H,W] in range [0,1]
|
65 |
+
output_path: str = "./output.mp4",
|
66 |
+
num_inference_steps: int = 50,
|
67 |
+
guidance_scale: float = 6.0,
|
68 |
+
num_videos_per_prompt: int = 1,
|
69 |
+
dtype: torch.dtype = torch.bfloat16,
|
70 |
+
fps: int = 24,
|
71 |
+
seed: int = 42,
|
72 |
+
):
|
73 |
+
"""
|
74 |
+
Generates a video based on the given prompt and saves it to the specified path.
|
75 |
+
|
76 |
+
Parameters:
|
77 |
+
- prompt (str): The description of the video to be generated.
|
78 |
+
- model_path (str): The path of the pre-trained model to be used.
|
79 |
+
- tracking_tensor (torch.Tensor): Tracking video tensor [T, C, H, W] in range [0,1]
|
80 |
+
- image_tensor (torch.Tensor): Input image tensor [C, H, W] in range [0,1]
|
81 |
+
- output_path (str): The path where the generated video will be saved.
|
82 |
+
- num_inference_steps (int): Number of steps for the inference process.
|
83 |
+
- guidance_scale (float): The scale for classifier-free guidance.
|
84 |
+
- num_videos_per_prompt (int): Number of videos to generate per prompt.
|
85 |
+
- dtype (torch.dtype): The data type for computation.
|
86 |
+
- seed (int): The seed for reproducibility.
|
87 |
+
"""
|
88 |
+
from transformers import T5EncoderModel, T5Tokenizer
|
89 |
+
from diffusers import AutoencoderKLCogVideoX, CogVideoXDDIMScheduler
|
90 |
+
from models.cogvideox_tracking import CogVideoXTransformer3DModelTracking
|
91 |
+
|
92 |
+
vae = AutoencoderKLCogVideoX.from_pretrained(model_path, subfolder="vae")
|
93 |
+
text_encoder = T5EncoderModel.from_pretrained(model_path, subfolder="text_encoder")
|
94 |
+
tokenizer = T5Tokenizer.from_pretrained(model_path, subfolder="tokenizer")
|
95 |
+
transformer = CogVideoXTransformer3DModelTracking.from_pretrained(model_path, subfolder="transformer")
|
96 |
+
scheduler = CogVideoXDDIMScheduler.from_pretrained(model_path, subfolder="scheduler")
|
97 |
+
|
98 |
+
pipe = CogVideoXImageToVideoPipelineTracking(
|
99 |
+
vae=vae,
|
100 |
+
text_encoder=text_encoder,
|
101 |
+
tokenizer=tokenizer,
|
102 |
+
transformer=transformer,
|
103 |
+
scheduler=scheduler
|
104 |
+
)
|
105 |
+
|
106 |
+
# Convert tensor to PIL Image
|
107 |
+
image_np = (image_tensor.permute(1, 2, 0).numpy() * 255).astype(np.uint8)
|
108 |
+
image = Image.fromarray(image_np)
|
109 |
+
height, width = image.height, image.width
|
110 |
+
|
111 |
+
pipe.transformer.eval()
|
112 |
+
pipe.text_encoder.eval()
|
113 |
+
pipe.vae.eval()
|
114 |
+
|
115 |
+
# Process tracking tensor
|
116 |
+
tracking_maps = tracking_tensor.float() # [T, C, H, W]
|
117 |
+
tracking_maps = tracking_maps.to(device=self.device, dtype=dtype)
|
118 |
+
tracking_first_frame = tracking_maps[0:1] # Get first frame as [1, C, H, W]
|
119 |
+
height, width = tracking_first_frame.shape[2], tracking_first_frame.shape[3]
|
120 |
+
|
121 |
+
# 2. Set Scheduler.
|
122 |
+
pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
|
123 |
+
|
124 |
+
pipe.to(self.device, dtype=dtype)
|
125 |
+
# pipe.enable_sequential_cpu_offload()
|
126 |
+
|
127 |
+
pipe.vae.enable_slicing()
|
128 |
+
pipe.vae.enable_tiling()
|
129 |
+
pipe.transformer.eval()
|
130 |
+
pipe.text_encoder.eval()
|
131 |
+
pipe.vae.eval()
|
132 |
+
|
133 |
+
pipe.transformer.gradient_checkpointing = False
|
134 |
+
|
135 |
+
print("Encoding tracking maps")
|
136 |
+
tracking_maps = tracking_maps.unsqueeze(0) # [B, T, C, H, W]
|
137 |
+
tracking_maps = tracking_maps.permute(0, 2, 1, 3, 4) # [B, C, T, H, W]
|
138 |
+
tracking_latent_dist = pipe.vae.encode(tracking_maps).latent_dist
|
139 |
+
tracking_maps = tracking_latent_dist.sample() * pipe.vae.config.scaling_factor
|
140 |
+
tracking_maps = tracking_maps.permute(0, 2, 1, 3, 4) # [B, F, C, H, W]
|
141 |
+
|
142 |
+
# 4. Generate the video frames based on the prompt.
|
143 |
+
video_generate = pipe(
|
144 |
+
prompt=prompt,
|
145 |
+
negative_prompt="The video is not of a high quality, it has a low resolution. Watermark present in each frame. The background is solid. Strange body and strange trajectory. Distortion.",
|
146 |
+
image=image,
|
147 |
+
num_videos_per_prompt=num_videos_per_prompt,
|
148 |
+
num_inference_steps=num_inference_steps,
|
149 |
+
num_frames=49,
|
150 |
+
use_dynamic_cfg=True,
|
151 |
+
guidance_scale=guidance_scale,
|
152 |
+
generator=torch.Generator().manual_seed(seed),
|
153 |
+
tracking_maps=tracking_maps,
|
154 |
+
tracking_image=tracking_first_frame,
|
155 |
+
height=height,
|
156 |
+
width=width,
|
157 |
+
).frames[0]
|
158 |
+
|
159 |
+
# 5. Export the generated frames to a video file. fps must be 8 for original video.
|
160 |
+
output_path = output_path if output_path else f"result.mp4"
|
161 |
+
os.makedirs(os.path.dirname(output_path), exist_ok=True)
|
162 |
+
export_to_video(video_generate, output_path, fps=fps)
|
163 |
+
|
164 |
+
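# A small sketch of the tensors `_infer` above expects (random placeholders; a real
# run goes through `apply_tracking` further down with an actual checkpoint path):
import torch

image_tensor = torch.rand(3, 480, 720)           # [C, H, W] in [0, 1]
tracking_tensor = torch.rand(49, 3, 480, 720)    # [T, C, H, W] in [0, 1]
# das = DiffusionAsShaderPipeline(gpu_id=0, output_dir="outputs")
# das._infer(prompt="a hypothetical prompt", model_path="path/to/checkpoint",
#            tracking_tensor=tracking_tensor, image_tensor=image_tensor,
#            output_path="outputs/result.mp4")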
#========== camera parameters ==========#
|
165 |
+
|
166 |
+
def _set_camera_motion(self, camera_motion):
|
167 |
+
self.camera_motion = camera_motion
|
168 |
+
|
169 |
+
def _get_intr(self, fov, H=480, W=720):
|
170 |
+
fov_rad = math.radians(fov)
|
171 |
+
focal_length = (W / 2) / math.tan(fov_rad / 2)
|
172 |
+
|
173 |
+
cx = W / 2
|
174 |
+
cy = H / 2
|
175 |
+
|
176 |
+
intr = torch.tensor([
|
177 |
+
[focal_length, 0, cx],
|
178 |
+
[0, focal_length, cy],
|
179 |
+
[0, 0, 1]
|
180 |
+
], dtype=torch.float32)
|
181 |
+
|
182 |
+
return intr
|
183 |
+
|
184 |
+
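# Quick sanity check of the pinhole intrinsics built above: with W = 720 and a
# 55-degree field of view, the focal length (W / 2) / tan(fov / 2) is roughly 691.5 px,
# and the principal point sits at the image centre (360, 240).
import math

W, H, fov = 720, 480, 55
f = (W / 2) / math.tan(math.radians(fov) / 2)
print(f, W / 2, H / 2)   # ~691.5 360.0 240.0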
def _apply_poses(self, pts, intr, poses):
|
185 |
+
"""
|
186 |
+
Args:
|
187 |
+
pts (torch.Tensor): pointclouds coordinates [T, N, 3]
|
188 |
+
intr (torch.Tensor): camera intrinsics [T, 3, 3]
|
189 |
+
poses (numpy.ndarray): camera poses [T, 4, 4]
|
190 |
+
"""
|
191 |
+
poses = torch.from_numpy(poses).float().to(self.device)
|
192 |
+
|
193 |
+
T, N, _ = pts.shape
|
194 |
+
ones = torch.ones(T, N, 1, device=self.device, dtype=torch.float)
|
195 |
+
pts_hom = torch.cat([pts[:, :, :2], ones], dim=-1) # (T, N, 3)
|
196 |
+
pts_cam = torch.bmm(pts_hom, torch.linalg.inv(intr).transpose(1, 2)) # (T, N, 3)
|
197 |
+
pts_cam[:,:, :3] /= pts[:, :, 2:3]
|
198 |
+
|
199 |
+
# to homogeneous
|
200 |
+
pts_cam = torch.cat([pts_cam, ones], dim=-1) # (T, N, 4)
|
201 |
+
|
202 |
+
if poses.shape[0] == 1:
|
203 |
+
poses = poses.repeat(T, 1, 1)
|
204 |
+
elif poses.shape[0] != T:
|
205 |
+
raise ValueError(f"Poses length ({poses.shape[0]}) must match sequence length ({T})")
|
206 |
+
|
207 |
+
pts_world = torch.bmm(pts_cam, poses.transpose(1, 2))[:, :, :3] # (T, N, 3)
|
208 |
+
|
209 |
+
pts_proj = torch.bmm(pts_world, intr.transpose(1, 2)) # (T, N, 3)
|
210 |
+
pts_proj[:, :, :2] /= pts_proj[:, :, 2:3]
|
211 |
+
|
212 |
+
return pts_proj
|
213 |
+
|
214 |
+
def apply_traj_on_tracking(self, pred_tracks, camera_motion=None, fov=55, frame_num=49):
|
215 |
+
intr = self._get_intr(fov).unsqueeze(0).repeat(frame_num, 1, 1).to(self.device)
|
216 |
+
tracking_pts = self._apply_poses(pred_tracks.squeeze(), intr, camera_motion).unsqueeze(0)
|
217 |
+
return tracking_pts
|
218 |
+
|
219 |
+
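# Calling convention for `apply_traj_on_tracking` above, assuming `pred_tracks` comes
# from the SpaTracker predictor below with shape [1, T, N, 3] (pixel x, pixel y, depth)
# and `camera_motion` is one 4x4 pose per frame; the sizes and names here are
# illustrative only.
import numpy as np
import torch

T, N = 49, 256
pred_tracks = torch.rand(1, T, N, 3)
camera_motion = np.tile(np.eye(4, dtype=np.float32), (T, 1, 1))   # [T, 4, 4]
# das = DiffusionAsShaderPipeline()
# moved_tracks = das.apply_traj_on_tracking(pred_tracks.to(das.device), camera_motion, fov=55, frame_num=T)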
##============= SpatialTracker =============##
|
220 |
+
|
221 |
+
def generate_tracking_spatracker(self, video_tensor, density=70):
|
222 |
+
"""Generate tracking video
|
223 |
+
|
224 |
+
Args:
|
225 |
+
video_tensor (torch.Tensor): Input video tensor
|
226 |
+
|
227 |
+
Returns:
|
228 |
+
str: Path to tracking video
|
229 |
+
"""
|
230 |
+
print("Loading tracking models...")
|
231 |
+
# Load tracking model
|
232 |
+
tracker = SpaTrackerPredictor(
|
233 |
+
checkpoint=os.path.join(project_root, 'checkpoints/spatracker/spaT_final.pth'),
|
234 |
+
interp_shape=(384, 576),
|
235 |
+
seq_length=12
|
236 |
+
).to(self.device)
|
237 |
+
|
238 |
+
# Load depth model
|
239 |
+
self.depth_preprocessor = DepthPreprocessor.from_pretrained("Intel/zoedepth-nyu-kitti")
|
240 |
+
self.depth_preprocessor.to(self.device)
|
241 |
+
|
242 |
+
try:
|
243 |
+
video = video_tensor.unsqueeze(0).to(self.device)
|
244 |
+
|
245 |
+
video_depths = []
|
246 |
+
for i in range(video_tensor.shape[0]):
|
247 |
+
frame = (video_tensor[i].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
|
248 |
+
depth = self.depth_preprocessor(Image.fromarray(frame))[0]
|
249 |
+
depth_tensor = transforms.ToTensor()(depth) # [1, H, W]
|
250 |
+
video_depths.append(depth_tensor)
|
251 |
+
video_depth = torch.stack(video_depths, dim=0).to(self.device)
|
252 |
+
# print("Video depth shape:", video_depth.shape)
|
253 |
+
|
254 |
+
segm_mask = np.ones((480, 720), dtype=np.uint8)
|
255 |
+
|
256 |
+
pred_tracks, pred_visibility, T_Firsts = tracker(
|
257 |
+
video * 255,
|
258 |
+
video_depth=video_depth,
|
259 |
+
grid_size=density,
|
260 |
+
backward_tracking=False,
|
261 |
+
depth_predictor=None,
|
262 |
+
grid_query_frame=0,
|
263 |
+
segm_mask=torch.from_numpy(segm_mask)[None, None].to(self.device),
|
264 |
+
wind_length=12,
|
265 |
+
progressive_tracking=False
|
266 |
+
)
|
267 |
+
|
268 |
+
return pred_tracks, pred_visibility, T_Firsts
|
269 |
+
|
270 |
+
finally:
|
271 |
+
# Clean up GPU memory
|
272 |
+
del tracker, self.depth_preprocessor
|
273 |
+
torch.cuda.empty_cache()
|
274 |
+
|
275 |
+
def visualize_tracking_spatracker(self, video, pred_tracks, pred_visibility, T_Firsts, save_tracking=True):
|
276 |
+
video = video.unsqueeze(0).to(self.device)
|
277 |
+
vis = Visualizer(save_dir=self.output_dir, grayscale=False, fps=24, pad_value=0)
|
278 |
+
msk_query = (T_Firsts == 0)
|
279 |
+
pred_tracks = pred_tracks[:,:,msk_query.squeeze()]
|
280 |
+
pred_visibility = pred_visibility[:,:,msk_query.squeeze()]
|
281 |
+
|
282 |
+
tracking_video = vis.visualize(video=video, tracks=pred_tracks,
|
283 |
+
visibility=pred_visibility, save_video=False,
|
284 |
+
filename="temp")
|
285 |
+
|
286 |
+
tracking_video = tracking_video.squeeze(0) # [T, C, H, W]
|
287 |
+
wide_list = list(tracking_video.unbind(0))
|
288 |
+
wide_list = [wide.permute(1, 2, 0).cpu().numpy() for wide in wide_list]
|
289 |
+
clip = ImageSequenceClip(wide_list, fps=self.fps)
|
290 |
+
|
291 |
+
tracking_path = None
|
292 |
+
if save_tracking:
|
293 |
+
try:
|
294 |
+
tracking_path = os.path.join(self.output_dir, "tracking_video.mp4")
|
295 |
+
clip.write_videofile(tracking_path, codec="libx264", fps=self.fps, logger=None)
|
296 |
+
print(f"Video saved to {tracking_path}")
|
297 |
+
except Exception as e:
|
298 |
+
print(f"Warning: Failed to save tracking video: {e}")
|
299 |
+
tracking_path = None
|
300 |
+
|
301 |
+
# Convert tracking_video back to tensor in range [0,1]
|
302 |
+
tracking_frames = np.array(list(clip.iter_frames())) / 255.0
|
303 |
+
tracking_video = torch.from_numpy(tracking_frames).permute(0, 3, 1, 2).float()
|
304 |
+
|
305 |
+
return tracking_path, tracking_video
|
306 |
+
|
307 |
+
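# A minimal sketch chaining the two SpaTracker helpers above; it assumes the SpaTracker
# checkpoint referenced in `generate_tracking_spatracker` is available and that
# `video_tensor` is a [49, 3, 480, 720] float tensor in [0, 1] (variable names are ours,
# and the import path is assumed from the repo layout).
import torch
from models.pipelines import DiffusionAsShaderPipeline

das = DiffusionAsShaderPipeline(gpu_id=0, output_dir="outputs")
video_tensor = torch.rand(49, 3, 480, 720)
pred_tracks, pred_visibility, T_Firsts = das.generate_tracking_spatracker(video_tensor, density=70)
_, tracking_video = das.visualize_tracking_spatracker(
    video_tensor, pred_tracks, pred_visibility, T_Firsts, save_tracking=False
)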
##============= MoGe =============##
|
308 |
+
|
309 |
+
def valid_mask(self, pixels, W, H):
|
310 |
+
"""Check if pixels are within valid image bounds
|
311 |
+
|
312 |
+
Args:
|
313 |
+
pixels (numpy.ndarray): Pixel coordinates of shape [N, 2]
|
314 |
+
W (int): Image width
|
315 |
+
H (int): Image height
|
316 |
+
|
317 |
+
Returns:
|
318 |
+
numpy.ndarray: Boolean mask of valid pixels
|
319 |
+
"""
|
320 |
+
return ((pixels[:, 0] >= 0) & (pixels[:, 0] < W) & (pixels[:, 1] >= 0) & \
|
321 |
+
(pixels[:, 1] < H))
|
322 |
+
|
323 |
+
def sort_points_by_depth(self, points, depths):
|
324 |
+
"""Sort points by depth values
|
325 |
+
|
326 |
+
Args:
|
327 |
+
points (numpy.ndarray): Points array of shape [N, 2]
|
328 |
+
depths (numpy.ndarray): Depth values of shape [N]
|
329 |
+
|
330 |
+
Returns:
|
331 |
+
tuple: (sorted_points, sorted_depths, sort_index)
|
332 |
+
"""
|
333 |
+
# Combine points and depths into a single array for sorting
|
334 |
+
combined = np.hstack((points, depths[:, None])) # Nx3 (points + depth)
|
335 |
+
# Sort by depth (last column) in descending order
|
336 |
+
sort_index = combined[:, -1].argsort()[::-1]
|
337 |
+
sorted_combined = combined[sort_index]
|
338 |
+
# Split back into points and depths
|
339 |
+
sorted_points = sorted_combined[:, :-1]
|
340 |
+
sorted_depths = sorted_combined[:, -1]
|
341 |
+
return sorted_points, sorted_depths, sort_index
|
342 |
+
|
343 |
+
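# `sort_points_by_depth` above orders points far-to-near so that nearer points are
# drawn last and occlude farther ones (painter's algorithm); a tiny check with
# made-up data:
import numpy as np

pts = np.array([[10, 10], [20, 20], [30, 30]])
depths = np.array([1.0, 3.0, 2.0])
# sorted_points, sorted_depths, _ = das.sort_points_by_depth(pts, depths)  # das: a DiffusionAsShaderPipeline
# sorted_depths -> [3.0, 2.0, 1.0]; the depth-1.0 point ends up drawn last (on top).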
def draw_rectangle(self, rgb, coord, side_length, color=(255, 0, 0)):
|
344 |
+
"""Draw a rectangle on the image
|
345 |
+
|
346 |
+
Args:
|
347 |
+
rgb (PIL.Image): Image to draw on
|
348 |
+
coord (tuple): Center coordinates (x, y)
|
349 |
+
side_length (int): Length of rectangle sides
|
350 |
+
color (tuple): RGB color tuple
|
351 |
+
"""
|
352 |
+
draw = ImageDraw.Draw(rgb)
|
353 |
+
# Calculate the bounding box of the rectangle
|
354 |
+
left_up_point = (coord[0] - side_length//2, coord[1] - side_length//2)
|
355 |
+
right_down_point = (coord[0] + side_length//2, coord[1] + side_length//2)
|
356 |
+
color = tuple(list(color))
|
357 |
+
|
358 |
+
draw.rectangle(
|
359 |
+
[left_up_point, right_down_point],
|
360 |
+
fill=tuple(color),
|
361 |
+
outline=tuple(color),
|
362 |
+
)
|
363 |
+
|
364 |
+
def visualize_tracking_moge(self, points, mask, save_tracking=True):
|
365 |
+
"""Visualize tracking results from MoGe model
|
366 |
+
|
367 |
+
Args:
|
368 |
+
points (numpy.ndarray): Points array of shape [T, H, W, 3]
|
369 |
+
mask (numpy.ndarray): Binary mask of shape [H, W]
|
370 |
+
save_tracking (bool): Whether to save tracking video
|
371 |
+
|
372 |
+
Returns:
|
373 |
+
tuple: (tracking_path, tracking_video)
|
374 |
+
- tracking_path (str): Path to saved tracking video, None if save_tracking is False
|
375 |
+
- tracking_video (torch.Tensor): Tracking visualization tensor of shape [T, C, H, W] in range [0,1]
|
376 |
+
"""
|
377 |
+
# Create color array
|
378 |
+
T, H, W, _ = points.shape
|
379 |
+
colors = np.zeros((H, W, 3), dtype=np.uint8)
|
380 |
+
|
381 |
+
# Set R channel - based on x coordinates (smaller on the left)
|
382 |
+
colors[:, :, 0] = np.tile(np.linspace(0, 255, W), (H, 1))
|
383 |
+
|
384 |
+
# Set G channel - based on y coordinates (smaller on the top)
|
385 |
+
colors[:, :, 1] = np.tile(np.linspace(0, 255, H), (W, 1)).T
|
386 |
+
|
387 |
+
# Set B channel - based on depth
|
388 |
+
z_values = points[0, :, :, 2] # get z values
|
389 |
+
inv_z = 1 / z_values # calculate 1/z
|
390 |
+
# Calculate 2% and 98% percentiles
|
391 |
+
p2 = np.percentile(inv_z, 2)
|
392 |
+
p98 = np.percentile(inv_z, 98)
|
393 |
+
# Normalize to [0,1] range
|
394 |
+
normalized_z = np.clip((inv_z - p2) / (p98 - p2), 0, 1)
|
395 |
+
colors[:, :, 2] = (normalized_z * 255).astype(np.uint8)
|
396 |
+
colors = colors.astype(np.uint8)
|
397 |
+
# colors = colors * mask[..., None]
|
398 |
+
# points = points * mask[None, :, :, None]
|
399 |
+
|
400 |
+
points = points.reshape(T, -1, 3)
|
401 |
+
colors = colors.reshape(-1, 3)
|
402 |
+
|
403 |
+
# Initialize list to store frames
|
404 |
+
frames = []
|
405 |
+
|
406 |
+
for i, pts_i in enumerate(tqdm(points)):
|
407 |
+
pixels, depths = pts_i[..., :2], pts_i[..., 2]
|
408 |
+
pixels[..., 0] = pixels[..., 0] * W
|
409 |
+
pixels[..., 1] = pixels[..., 1] * H
|
410 |
+
pixels = pixels.astype(int)
|
411 |
+
|
412 |
+
valid = self.valid_mask(pixels, W, H)
|
413 |
+
frame_rgb = colors[valid]
|
414 |
+
pixels = pixels[valid]
|
415 |
+
depths = depths[valid]
|
416 |
+
|
417 |
+
img = Image.fromarray(np.uint8(np.zeros([H, W, 3])), mode="RGB")
|
418 |
+
sorted_pixels, _, sort_index = self.sort_points_by_depth(pixels, depths)
|
419 |
+
step = 1
|
420 |
+
sorted_pixels = sorted_pixels[::step]
|
421 |
+
sorted_rgb = frame_rgb[sort_index][::step]
|
422 |
+
|
423 |
+
for j in range(sorted_pixels.shape[0]):
|
424 |
+
self.draw_rectangle(
|
425 |
+
img,
|
426 |
+
coord=(sorted_pixels[j, 0], sorted_pixels[j, 1]),
|
427 |
+
side_length=2,
|
428 |
+
color=sorted_rgb[j],
|
429 |
+
)
|
430 |
+
frames.append(np.array(img))
|
431 |
+
|
432 |
+
# Convert frames to video tensor in range [0,1]
|
433 |
+
tracking_video = torch.from_numpy(np.stack(frames)).permute(0, 3, 1, 2).float() / 255.0
|
434 |
+
|
435 |
+
tracking_path = None
|
436 |
+
if save_tracking:
|
437 |
+
try:
|
438 |
+
tracking_path = os.path.join(self.output_dir, "tracking_video_moge.mp4")
|
439 |
+
# Convert back to uint8 for saving
|
440 |
+
uint8_frames = [frame.astype(np.uint8) for frame in frames]
|
441 |
+
clip = ImageSequenceClip(uint8_frames, fps=self.fps)
|
442 |
+
clip.write_videofile(tracking_path, codec="libx264", fps=self.fps, logger=None)
|
443 |
+
print(f"Video saved to {tracking_path}")
|
444 |
+
except Exception as e:
|
445 |
+
print(f"Warning: Failed to save tracking video: {e}")
|
446 |
+
tracking_path = None
|
447 |
+
|
448 |
+
return tracking_path, tracking_video
|
449 |
+
|
450 |
+
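# The MoGe visualisation above encodes each pixel's identity as a colour: R grows with
# x, G grows with y, and B encodes the first frame's inverse depth clipped to its
# 2nd-98th percentiles. A standalone restatement of that colour map (the function name
# is ours), for reference:
import numpy as np

def xy_depth_colors(z_values: np.ndarray) -> np.ndarray:
    H, W = z_values.shape
    colors = np.zeros((H, W, 3), dtype=np.uint8)
    colors[:, :, 0] = np.tile(np.linspace(0, 255, W), (H, 1))        # R <- x position
    colors[:, :, 1] = np.tile(np.linspace(0, 255, H), (W, 1)).T      # G <- y position
    inv_z = 1.0 / z_values
    p2, p98 = np.percentile(inv_z, 2), np.percentile(inv_z, 98)
    colors[:, :, 2] = (np.clip((inv_z - p2) / (p98 - p2), 0, 1) * 255).astype(np.uint8)
    return colors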
def apply_tracking(self, video_tensor, fps=8, tracking_tensor=None, img_cond_tensor=None, prompt=None, checkpoint_path=None):
|
451 |
+
"""Generate final video with motion transfer
|
452 |
+
|
453 |
+
Args:
|
454 |
+
video_tensor (torch.Tensor): Input video tensor [T,C,H,W]
|
455 |
+
fps (float): Input video FPS
|
456 |
+
tracking_tensor (torch.Tensor): Tracking video tensor [T,C,H,W]
|
457 |
+
img_cond_tensor (torch.Tensor): First frame tensor [C,H,W] to use for generation
|
458 |
+
prompt (str): Generation prompt
|
459 |
+
checkpoint_path (str): Path to model checkpoint
|
460 |
+
"""
|
461 |
+
self.fps = fps
|
462 |
+
|
463 |
+
# Use first frame if no image provided
|
464 |
+
if img_cond_tensor is None:
|
465 |
+
img_cond_tensor = video_tensor[0]
|
466 |
+
|
467 |
+
# Generate final video
|
468 |
+
final_output = os.path.join(os.path.abspath(self.output_dir), "result.mp4")
|
469 |
+
self._infer(
|
470 |
+
prompt=prompt,
|
471 |
+
model_path=checkpoint_path,
|
472 |
+
tracking_tensor=tracking_tensor,
|
473 |
+
image_tensor=img_cond_tensor,
|
474 |
+
output_path=final_output,
|
475 |
+
num_inference_steps=50,
|
476 |
+
guidance_scale=6.0,
|
477 |
+
dtype=torch.bfloat16,
|
478 |
+
fps=self.fps
|
479 |
+
)
|
480 |
+
print(f"Final video generated successfully at: {final_output}")
|
481 |
+
|
482 |
+
def _set_object_motion(self, motion_type):
|
483 |
+
"""Set object motion type
|
484 |
+
|
485 |
+
Args:
|
486 |
+
motion_type (str): Motion direction ('up', 'down', 'left', 'right')
|
487 |
+
"""
|
488 |
+
self.object_motion = motion_type
|
489 |
+
|
490 |
+
class FirstFrameRepainter:
|
491 |
+
def __init__(self, gpu_id=0, output_dir='outputs'):
|
492 |
+
"""Initialize FirstFrameRepainter
|
493 |
+
|
494 |
+
Args:
|
495 |
+
gpu_id (int): GPU device ID
|
496 |
+
output_dir (str): Output directory path
|
497 |
+
"""
|
498 |
+
self.device = f"cuda:{gpu_id}"
|
499 |
+
self.output_dir = output_dir
|
500 |
+
self.max_depth = 65.0
|
501 |
+
os.makedirs(output_dir, exist_ok=True)
|
502 |
+
|
503 |
+
def repaint(self, image_tensor, prompt, depth_path=None, method="dav"):
|
504 |
+
"""Repaint first frame using Flux
|
505 |
+
|
506 |
+
Args:
|
507 |
+
image_tensor (torch.Tensor): Input image tensor [C,H,W]
|
508 |
+
prompt (str): Repaint prompt
|
509 |
+
depth_path (str): Path to depth image
|
510 |
+
method (str): depth estimator, "moge" or "dav" or "zoedepth"
|
511 |
+
|
512 |
+
Returns:
|
513 |
+
torch.Tensor: Repainted image tensor [C,H,W]
|
514 |
+
"""
|
515 |
+
print("Loading Flux model...")
|
516 |
+
# Load Flux model
|
517 |
+
flux_pipe = FluxControlPipeline.from_pretrained(
|
518 |
+
"black-forest-labs/FLUX.1-Depth-dev",
|
519 |
+
torch_dtype=torch.bfloat16
|
520 |
+
).to(self.device)
|
521 |
+
|
522 |
+
# Get depth map
|
523 |
+
if depth_path is None:
|
524 |
+
if method == "moge":
|
525 |
+
self.moge_model = MoGeModel.from_pretrained("Ruicheng/moge-vitl").to(self.device)
|
526 |
+
depth_map = self.moge_model.infer(image_tensor.to(self.device))["depth"]
|
527 |
+
depth_map = torch.clamp(depth_map, max=self.max_depth)
|
528 |
+
depth_normalized = 1.0 - (depth_map / self.max_depth)
|
529 |
+
depth_rgb = (depth_normalized * 255).cpu().numpy().astype(np.uint8)
|
530 |
+
control_image = Image.fromarray(depth_rgb).convert("RGB")
|
531 |
+
elif method == "zoedepth":
|
532 |
+
self.depth_preprocessor = DepthPreprocessor.from_pretrained("Intel/zoedepth-nyu-kitti")
|
533 |
+
self.depth_preprocessor.to(self.device)
|
534 |
+
image_np = (image_tensor.permute(1, 2, 0).numpy() * 255).astype(np.uint8)
|
535 |
+
control_image = self.depth_preprocessor(Image.fromarray(image_np))[0].convert("RGB")
|
536 |
+
control_image = control_image.point(lambda x: 255 - x) # the zoedepth depth is inverted
|
537 |
+
else:
|
538 |
+
self.depth_preprocessor = DepthPreprocessor.from_pretrained("depth-anything/Depth-Anything-V2-Large-hf")
|
539 |
+
self.depth_preprocessor.to(self.device)
|
540 |
+
image_np = (image_tensor.permute(1, 2, 0).numpy() * 255).astype(np.uint8)
|
541 |
+
control_image = self.depth_preprocessor(Image.fromarray(image_np))[0].convert("RGB")
|
542 |
+
else:
|
543 |
+
control_image = Image.open(depth_path).convert("RGB")
|
544 |
+
|
545 |
+
try:
|
546 |
+
repainted_image = flux_pipe(
|
547 |
+
prompt=prompt,
|
548 |
+
control_image=control_image,
|
549 |
+
height=480,
|
550 |
+
width=720,
|
551 |
+
num_inference_steps=30,
|
552 |
+
guidance_scale=7.5,
|
553 |
+
).images[0]
|
554 |
+
|
555 |
+
# Save repainted image
|
556 |
+
repainted_image.save(os.path.join(self.output_dir, "temp_repainted.png"))
|
557 |
+
|
558 |
+
# Convert PIL Image to tensor
|
559 |
+
transform = transforms.Compose([
|
560 |
+
transforms.ToTensor()
|
561 |
+
])
|
562 |
+
repainted_tensor = transform(repainted_image)
|
563 |
+
|
564 |
+
return repainted_tensor
|
565 |
+
|
566 |
+
finally:
|
567 |
+
# Clean up GPU memory
|
568 |
+
del flux_pipe
|
569 |
+
if method == "moge":
|
570 |
+
del self.moge_model
|
571 |
+
else:
|
572 |
+
del self.depth_preprocessor
|
573 |
+
torch.cuda.empty_cache()
|
574 |
+
|
575 |
+
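# A minimal usage sketch for FirstFrameRepainter above (the prompt is a placeholder,
# the import path is assumed from the repo layout, and the Flux and depth-estimator
# weights are pulled from the Hub on first use):
import torch
from models.pipelines import FirstFrameRepainter

repainter = FirstFrameRepainter(gpu_id=0, output_dir="outputs")
first_frame = torch.rand(3, 480, 720)   # [C, H, W] in [0, 1]; a real call uses the video's first frame
repainted = repainter.repaint(first_frame, prompt="a hypothetical style prompt", method="dav")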
class CameraMotionGenerator:
|
576 |
+
def __init__(self, motion_type, frame_num=49, H=480, W=720, fx=None, fy=None, fov=55, device='cuda'):
|
577 |
+
self.motion_type = motion_type
|
578 |
+
self.frame_num = frame_num
|
579 |
+
self.fov = fov
|
580 |
+
self.device = device
|
581 |
+
self.W = W
|
582 |
+
self.H = H
|
583 |
+
self.intr = torch.tensor([
|
584 |
+
[0, 0, W / 2],
|
585 |
+
[0, 0, H / 2],
|
586 |
+
[0, 0, 1]
|
587 |
+
], dtype=torch.float32, device=device)
|
588 |
+
# if fx, fy not provided
|
589 |
+
if not fx or not fy:
|
590 |
+
fov_rad = math.radians(fov)
|
591 |
+
fx = fy = (W / 2) / math.tan(fov_rad / 2)
|
592 |
+
|
593 |
+
self.intr[0, 0] = fx
|
594 |
+
self.intr[1, 1] = fy
|
595 |
+
|
596 |
+
def _apply_poses(self, pts, poses):
|
597 |
+
"""
|
598 |
+
Args:
|
599 |
+
pts (torch.Tensor): pointclouds coordinates [T, N, 3]
|
600 |
+
intr (torch.Tensor): camera intrinsics [T, 3, 3]
|
601 |
+
poses (numpy.ndarray): camera poses [T, 4, 4]
|
602 |
+
"""
|
603 |
+
if isinstance(poses, np.ndarray):
|
604 |
+
poses = torch.from_numpy(poses)
|
605 |
+
|
606 |
+
intr = self.intr.unsqueeze(0).repeat(self.frame_num, 1, 1).to(torch.float)
|
607 |
+
T, N, _ = pts.shape
|
608 |
+
ones = torch.ones(T, N, 1, device=self.device, dtype=torch.float)
|
609 |
+
pts_hom = torch.cat([pts[:, :, :2], ones], dim=-1) # (T, N, 3)
|
610 |
+
pts_cam = torch.bmm(pts_hom, torch.linalg.inv(intr).transpose(1, 2)) # (T, N, 3)
|
611 |
+
pts_cam[:,:, :3] *= pts[:, :, 2:3]
|
612 |
+
|
613 |
+
# to homogeneous
|
614 |
+
pts_cam = torch.cat([pts_cam, ones], dim=-1) # (T, N, 4)
|
615 |
+
|
616 |
+
if poses.shape[0] == 1:
|
617 |
+
poses = poses.repeat(T, 1, 1)
|
618 |
+
elif poses.shape[0] != T:
|
619 |
+
raise ValueError(f"Poses length ({poses.shape[0]}) must match sequence length ({T})")
|
620 |
+
|
621 |
+
poses = poses.to(torch.float).to(self.device)
|
622 |
+
pts_world = torch.bmm(pts_cam, poses.transpose(1, 2))[:, :, :3] # (T, N, 3)
|
623 |
+
pts_proj = torch.bmm(pts_world, intr.transpose(1, 2)) # (T, N, 3)
|
624 |
+
pts_proj[:, :, :2] /= pts_proj[:, :, 2:3]
|
625 |
+
|
626 |
+
return pts_proj
|
627 |
+
|
628 |
+
def w2s(self, pts, poses):
|
629 |
+
if isinstance(poses, np.ndarray):
|
630 |
+
poses = torch.from_numpy(poses)
|
631 |
+
assert poses.shape[0] == self.frame_num
|
632 |
+
poses = poses.to(torch.float32).to(self.device)
|
633 |
+
T, N, _ = pts.shape # (T, N, 3)
|
634 |
+
intr = self.intr.unsqueeze(0).repeat(self.frame_num, 1, 1)
|
635 |
+
# Step 1: lift the points to homogeneous coordinates (T, N, 4) by appending a column of ones
|
636 |
+
ones = torch.ones((T, N, 1), device=self.device, dtype=pts.dtype)
|
637 |
+
points_world_h = torch.cat([pts, ones], dim=-1)
|
638 |
+
points_camera_h = torch.bmm(poses, points_world_h.permute(0, 2, 1))
|
639 |
+
points_camera = points_camera_h[:, :3, :].permute(0, 2, 1)
|
640 |
+
|
641 |
+
points_image_h = torch.bmm(points_camera, intr.permute(0, 2, 1))
|
642 |
+
|
643 |
+
uv = points_image_h[:, :, :2] / points_image_h[:, :, 2:3]
|
644 |
+
|
645 |
+
# Extract the depth (Z) and concatenate it with the projected uv coordinates
|
646 |
+
depth = points_camera[:, :, 2:3] # (T, N, 1)
|
647 |
+
uvd = torch.cat([uv, depth], dim=-1) # (T, N, 3)
|
648 |
+
|
649 |
+
return uvd # screen coordinates + depth (T, N, 3)
|
650 |
+
|
651 |
+
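As a quick illustration of what w2s computes per point (a standard pinhole projection after the world-to-camera transform), here is a hand-worked example with example intrinsics and an identity pose; fx, fy, cx, cy stand for the entries of self.intr and the numbers are arbitrary.
import torch
fx = fy = 500.0
cx, cy = 360.0, 240.0
X, Y, Z = 0.2, -0.1, 2.0          # world point (already in the camera frame, identity pose)
u = fx * X / Z + cx               # 410.0
v = fy * Y / Z + cy               # 215.0
uvd = torch.tensor([u, v, Z])     # screen coordinates + depth, as returned by w2s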
def apply_motion_on_pts(self, pts, camera_motion):
|
652 |
+
tracking_pts = self._apply_poses(pts.squeeze(), camera_motion).unsqueeze(0)
|
653 |
+
return tracking_pts
|
654 |
+
|
655 |
+
def set_intr(self, K):
|
656 |
+
if isinstance(K, np.ndarray):
|
657 |
+
K = torch.from_numpy(K)
|
658 |
+
self.intr = K.to(self.device)
|
659 |
+
|
660 |
+
def rot_poses(self, angle, axis='y'):
|
661 |
+
"""Generate a single rotation matrix
|
662 |
+
|
663 |
+
Args:
|
664 |
+
angle (float): Rotation angle in degrees
|
665 |
+
axis (str): Rotation axis ('x', 'y', or 'z')
|
666 |
+
|
667 |
+
Returns:
|
668 |
+
torch.Tensor: Single rotation matrix [4, 4]
|
669 |
+
"""
|
670 |
+
angle_rad = math.radians(angle)
|
671 |
+
cos_theta = torch.cos(torch.tensor(angle_rad))
|
672 |
+
sin_theta = torch.sin(torch.tensor(angle_rad))
|
673 |
+
|
674 |
+
if axis == 'x':
|
675 |
+
rot_mat = torch.tensor([
|
676 |
+
[1, 0, 0, 0],
|
677 |
+
[0, cos_theta, -sin_theta, 0],
|
678 |
+
[0, sin_theta, cos_theta, 0],
|
679 |
+
[0, 0, 0, 1]
|
680 |
+
], dtype=torch.float32)
|
681 |
+
elif axis == 'y':
|
682 |
+
rot_mat = torch.tensor([
|
683 |
+
[cos_theta, 0, sin_theta, 0],
|
684 |
+
[0, 1, 0, 0],
|
685 |
+
[-sin_theta, 0, cos_theta, 0],
|
686 |
+
[0, 0, 0, 1]
|
687 |
+
], dtype=torch.float32)
|
688 |
+
elif axis == 'z':
|
689 |
+
rot_mat = torch.tensor([
|
690 |
+
[cos_theta, -sin_theta, 0, 0],
|
691 |
+
[sin_theta, cos_theta, 0, 0],
|
692 |
+
[0, 0, 1, 0],
|
693 |
+
[0, 0, 0, 1]
|
694 |
+
], dtype=torch.float32)
|
695 |
+
else:
|
696 |
+
raise ValueError("Invalid axis value. Choose 'x', 'y', or 'z'.")
|
697 |
+
|
698 |
+
return rot_mat.to(self.device)
|
699 |
+
|
700 |
+
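A small sanity check for rot_poses, using a hypothetical generator instance `gen`: a 90-degree rotation about z should map the x axis onto the y axis.
import torch
R = gen.rot_poses(90, axis='z')                       # [4, 4] rotation matrix
x_axis = torch.tensor([1.0, 0.0, 0.0, 1.0], device=R.device)
y_axis = R @ x_axis                                   # ~[0, 1, 0, 1]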
def trans_poses(self, dx, dy, dz):
|
701 |
+
"""
|
702 |
+
params:
|
703 |
+
- dx: float, displacement along the x axis.
|
704 |
+
- dy: float, displacement along the y axis.
|
705 |
+
- dz: float, displacement along the z axis.
|
706 |
+
|
707 |
+
ret:
|
708 |
+
- matrices: torch.Tensor
|
709 |
+
"""
|
710 |
+
trans_mats = torch.eye(4).unsqueeze(0).repeat(self.frame_num, 1, 1) # (n, 4, 4)
|
711 |
+
|
712 |
+
delta_x = dx / (self.frame_num - 1)
|
713 |
+
delta_y = dy / (self.frame_num - 1)
|
714 |
+
delta_z = dz / (self.frame_num - 1)
|
715 |
+
|
716 |
+
for i in range(self.frame_num):
|
717 |
+
trans_mats[i, 0, 3] = i * delta_x
|
718 |
+
trans_mats[i, 1, 3] = i * delta_y
|
719 |
+
trans_mats[i, 2, 3] = i * delta_z
|
720 |
+
|
721 |
+
return trans_mats.to(self.device)
|
722 |
+
|
723 |
+
|
724 |
+
def _look_at(self, camera_position, target_position):
|
725 |
+
# look at direction
|
726 |
+
direction = target_position - camera_position
|
727 |
+
direction /= np.linalg.norm(direction)
|
728 |
+
# calculate rotation matrix
|
729 |
+
up = np.array([0, 1, 0])
|
730 |
+
right = np.cross(up, direction)
|
731 |
+
right /= np.linalg.norm(right)
|
732 |
+
up = np.cross(direction, right)
|
733 |
+
rotation_matrix = np.vstack([right, up, direction])
|
734 |
+
rotation_matrix = np.linalg.inv(rotation_matrix)
|
735 |
+
return rotation_matrix
|
736 |
+
|
737 |
+
def spiral_poses(self, radius, forward_ratio = 0.5, backward_ratio = 0.5, rotation_times = 0.1, look_at_times = 0.5):
|
738 |
+
"""Generate spiral camera poses
|
739 |
+
|
740 |
+
Args:
|
741 |
+
radius (float): Base radius of the spiral
|
742 |
+
forward_ratio (float): Scale factor for forward motion
|
743 |
+
backward_ratio (float): Scale factor for backward motion
|
744 |
+
rotation_times (float): Number of rotations to complete
|
745 |
+
look_at_times (float): Scale factor for look-at point distance
|
746 |
+
|
747 |
+
Returns:
|
748 |
+
torch.Tensor: Camera poses of shape [num_frames, 4, 4]
|
749 |
+
"""
|
750 |
+
# Generate spiral trajectory
|
751 |
+
t = np.linspace(0, 1, self.frame_num)
|
752 |
+
r = np.sin(np.pi * t) * radius * rotation_times
|
753 |
+
theta = 2 * np.pi * t
|
754 |
+
|
755 |
+
# Calculate camera positions
|
756 |
+
# Limit y motion for better floor/sky view
|
757 |
+
y = r * np.cos(theta) * 0.3
|
758 |
+
x = r * np.sin(theta)
|
759 |
+
z = -r
|
760 |
+
z[z < 0] *= forward_ratio
|
761 |
+
z[z > 0] *= backward_ratio
|
762 |
+
|
763 |
+
# Set look-at target
|
764 |
+
target_pos = np.array([0, 0, radius * look_at_times])
|
765 |
+
cam_pos = np.vstack([x, y, z]).T
|
766 |
+
cam_poses = []
|
767 |
+
|
768 |
+
for pos in cam_pos:
|
769 |
+
rot_mat = self._look_at(pos, target_pos)
|
770 |
+
trans_mat = np.eye(4)
|
771 |
+
trans_mat[:3, :3] = rot_mat
|
772 |
+
trans_mat[:3, 3] = pos
|
773 |
+
cam_poses.append(trans_mat[None])
|
774 |
+
|
775 |
+
camera_poses = np.concatenate(cam_poses, axis=0)
|
776 |
+
return torch.from_numpy(camera_poses).to(self.device)
|
777 |
+
|
778 |
+
def rot(self, pts, angle, axis):
|
779 |
+
"""
|
780 |
+
pts: torch.Tensor, (T, N, 3)
|
781 |
+
"""
|
782 |
+
rot_mats = self.rot_poses(angle, axis)
|
783 |
+
pts = self.apply_motion_on_pts(pts, rot_mats)
|
784 |
+
return pts
|
785 |
+
|
786 |
+
def trans(self, pts, dx, dy, dz):
|
787 |
+
if pts.shape[-1] != 3:
|
788 |
+
raise ValueError("points should be in the 3d coordinate.")
|
789 |
+
trans_mats = self.trans_poses(dx, dy, dz)
|
790 |
+
pts = self.apply_motion_on_pts(pts, trans_mats)
|
791 |
+
return pts
|
792 |
+
|
793 |
+
def spiral(self, pts, radius):
|
794 |
+
spiral_poses = self.spiral_poses(radius)
|
795 |
+
pts = self.apply_motion_on_pts(pts, spiral_poses)
|
796 |
+
return pts
|
797 |
+
|
798 |
+
def get_default_motion(self):
|
799 |
+
"""Parse motion parameters and generate corresponding motion matrices
|
800 |
+
|
801 |
+
Supported formats:
|
802 |
+
- trans <dx> <dy> <dz> [start_frame] [end_frame]: Translation motion
|
803 |
+
- rot <axis> <angle> [start_frame] [end_frame]: Rotation motion
|
804 |
+
- spiral <radius> [start_frame] [end_frame]: Spiral motion
|
805 |
+
|
806 |
+
Multiple transformations can be combined using semicolon (;) as separator:
|
807 |
+
e.g., "trans 0 0 0.5 0 30; rot x 25 0 30; trans 0.1 0 0 30 48"
|
808 |
+
|
809 |
+
Note:
|
810 |
+
- start_frame and end_frame are optional
|
811 |
+
- frame range: 0-49 (will be clamped to this range)
|
812 |
+
- if not specified, defaults to 0-49
|
813 |
+
- frames after end_frame will maintain the final transformation
|
814 |
+
- for combined transformations, they are applied in sequence
|
815 |
+
|
816 |
+
Returns:
|
817 |
+
torch.Tensor: Motion matrices [num_frames, 4, 4]
|
818 |
+
"""
|
819 |
+
if not isinstance(self.motion_type, str):
|
820 |
+
raise ValueError(f'camera_motion must be a string, but got {type(self.motion_type)}')
|
821 |
+
|
822 |
+
# Split combined transformations
|
823 |
+
transform_sequences = [s.strip() for s in self.motion_type.split(';')]
|
824 |
+
|
825 |
+
# Initialize the final motion matrices
|
826 |
+
final_motion = torch.eye(4, device=self.device).unsqueeze(0).repeat(49, 1, 1)
|
827 |
+
|
828 |
+
# Process each transformation in sequence
|
829 |
+
for transform in transform_sequences:
|
830 |
+
params = transform.lower().split()
|
831 |
+
if not params:
|
832 |
+
continue
|
833 |
+
|
834 |
+
motion_type = params[0]
|
835 |
+
|
836 |
+
# Default frame range
|
837 |
+
start_frame = 0
|
838 |
+
end_frame = 48 # 49 frames in total (0-48)
|
839 |
+
|
840 |
+
if motion_type == 'trans':
|
841 |
+
# Parse translation parameters
|
842 |
+
if len(params) not in [4, 6]:
|
843 |
+
raise ValueError(f"trans motion requires 3 or 5 parameters: 'trans <dx> <dy> <dz>' or 'trans <dx> <dy> <dz> <start_frame> <end_frame>', got: {transform}")
|
844 |
+
|
845 |
+
dx, dy, dz = map(float, params[1:4])
|
846 |
+
|
847 |
+
if len(params) == 6:
|
848 |
+
start_frame = max(0, min(48, int(params[4])))
|
849 |
+
end_frame = max(0, min(48, int(params[5])))
|
850 |
+
if start_frame > end_frame:
|
851 |
+
start_frame, end_frame = end_frame, start_frame
|
852 |
+
|
853 |
+
# Generate current transformation
|
854 |
+
current_motion = torch.eye(4, device=self.device).unsqueeze(0).repeat(49, 1, 1)
|
855 |
+
for frame_idx in range(49):
|
856 |
+
if frame_idx < start_frame:
|
857 |
+
continue
|
858 |
+
elif frame_idx <= end_frame:
|
859 |
+
t = (frame_idx - start_frame) / (end_frame - start_frame)
|
860 |
+
current_motion[frame_idx, :3, 3] = torch.tensor([dx, dy, dz], device=self.device) * t
|
861 |
+
else:
|
862 |
+
current_motion[frame_idx] = current_motion[end_frame]
|
863 |
+
|
864 |
+
# Combine with previous transformations
|
865 |
+
final_motion = torch.matmul(final_motion, current_motion)
|
866 |
+
|
867 |
+
elif motion_type == 'rot':
|
868 |
+
# Parse rotation parameters
|
869 |
+
if len(params) not in [3, 5]:
|
870 |
+
raise ValueError(f"rot motion requires 2 or 4 parameters: 'rot <axis> <angle>' or 'rot <axis> <angle> <start_frame> <end_frame>', got: {transform}")
|
871 |
+
|
872 |
+
axis = params[1]
|
873 |
+
if axis not in ['x', 'y', 'z']:
|
874 |
+
raise ValueError(f"Invalid rotation axis '{axis}', must be 'x', 'y' or 'z'")
|
875 |
+
angle = float(params[2])
|
876 |
+
|
877 |
+
if len(params) == 5:
|
878 |
+
start_frame = max(0, min(48, int(params[3])))
|
879 |
+
end_frame = max(0, min(48, int(params[4])))
|
880 |
+
if start_frame > end_frame:
|
881 |
+
start_frame, end_frame = end_frame, start_frame
|
882 |
+
|
883 |
+
current_motion = torch.eye(4, device=self.device).unsqueeze(0).repeat(49, 1, 1)
|
884 |
+
for frame_idx in range(49):
|
885 |
+
if frame_idx < start_frame:
|
886 |
+
continue
|
887 |
+
elif frame_idx <= end_frame:
|
888 |
+
t = (frame_idx - start_frame) / (end_frame - start_frame)
|
889 |
+
current_angle = angle * t
|
890 |
+
current_motion[frame_idx] = self.rot_poses(current_angle, axis)
|
891 |
+
else:
|
892 |
+
current_motion[frame_idx] = current_motion[end_frame]
|
893 |
+
|
894 |
+
# Combine with previous transformations
|
895 |
+
final_motion = torch.matmul(final_motion, current_motion)
|
896 |
+
|
897 |
+
elif motion_type == 'spiral':
|
898 |
+
# Parse spiral motion parameters
|
899 |
+
if len(params) not in [2, 4]:
|
900 |
+
raise ValueError(f"spiral motion requires 1 or 3 parameters: 'spiral <radius>' or 'spiral <radius> <start_frame> <end_frame>', got: {transform}")
|
901 |
+
|
902 |
+
radius = float(params[1])
|
903 |
+
|
904 |
+
if len(params) == 4:
|
905 |
+
start_frame = max(0, min(48, int(params[2])))
|
906 |
+
end_frame = max(0, min(48, int(params[3])))
|
907 |
+
if start_frame > end_frame:
|
908 |
+
start_frame, end_frame = end_frame, start_frame
|
909 |
+
|
910 |
+
current_motion = torch.eye(4, device=self.device).unsqueeze(0).repeat(49, 1, 1)
|
911 |
+
spiral_motion = self.spiral_poses(radius)
|
912 |
+
for frame_idx in range(49):
|
913 |
+
if frame_idx < start_frame:
|
914 |
+
continue
|
915 |
+
elif frame_idx <= end_frame:
|
916 |
+
t = (frame_idx - start_frame) / (end_frame - start_frame)
|
917 |
+
idx = int(t * (len(spiral_motion) - 1))
|
918 |
+
current_motion[frame_idx] = spiral_motion[idx]
|
919 |
+
else:
|
920 |
+
current_motion[frame_idx] = current_motion[end_frame]
|
921 |
+
|
922 |
+
# Combine with previous transformations
|
923 |
+
final_motion = torch.matmul(final_motion, current_motion)
|
924 |
+
|
925 |
+
else:
|
926 |
+
raise ValueError(f'camera_motion type must be in [trans, spiral, rot], but got {motion_type}')
|
927 |
+
|
928 |
+
return final_motion
|
929 |
+
|
930 |
+
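A short usage sketch for the motion-string grammar documented in get_default_motion; the motion string itself is just an example.
# Build per-frame camera matrices from a combined motion string.
gen = CameraMotionGenerator("rot y 25 0 30; trans 0 0 0.2 30 48")
poses = gen.get_default_motion()   # torch.Tensor of shape [49, 4, 4]
# Frames 0-30 interpolate a 25-degree yaw; after frame 30 the yaw is held
# while the z-translation ramps in, and the transforms are composed in order.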
class ObjectMotionGenerator:
|
931 |
+
def __init__(self, device="cuda:0"):
|
932 |
+
self.device = device
|
933 |
+
self.num_frames = 49
|
934 |
+
|
935 |
+
def _get_points_in_mask(self, pred_tracks, mask):
|
936 |
+
"""Get points that lie within the mask
|
937 |
+
|
938 |
+
Args:
|
939 |
+
pred_tracks (torch.Tensor): Point trajectories [num_frames, num_points, 3]
|
940 |
+
mask (torch.Tensor): Binary mask [H, W]
|
941 |
+
|
942 |
+
Returns:
|
943 |
+
torch.Tensor: Boolean mask for selected points [num_points]
|
944 |
+
"""
|
945 |
+
first_frame_points = pred_tracks[0] # [num_points, 3]
|
946 |
+
xy_points = first_frame_points[:, :2] # [num_points, 2]
|
947 |
+
|
948 |
+
xy_pixels = xy_points.round().long()
|
949 |
+
xy_pixels[:, 0].clamp_(0, mask.shape[1] - 1)
|
950 |
+
xy_pixels[:, 1].clamp_(0, mask.shape[0] - 1)
|
951 |
+
|
952 |
+
points_in_mask = mask[xy_pixels[:, 1], xy_pixels[:, 0]]
|
953 |
+
|
954 |
+
return points_in_mask
|
955 |
+
|
956 |
+
def apply_motion(self, pred_tracks, mask, motion_type, distance, num_frames=49, tracking_method="spatracker"):
|
957 |
+
|
958 |
+
self.num_frames = num_frames
|
959 |
+
pred_tracks = pred_tracks.to(self.device).float()
|
960 |
+
mask = mask.to(self.device)
|
961 |
+
|
962 |
+
template = {
|
963 |
+
'up': ('trans', torch.tensor([0, -1, 0])),
|
964 |
+
'down': ('trans', torch.tensor([0, 1, 0])),
|
965 |
+
'left': ('trans', torch.tensor([-1, 0, 0])),
|
966 |
+
'right': ('trans', torch.tensor([1, 0, 0])),
|
967 |
+
'front': ('trans', torch.tensor([0, 0, 1])),
|
968 |
+
'back': ('trans', torch.tensor([0, 0, -1])),
|
969 |
+
'rot': ('rot', None) # rotate around y axis
|
970 |
+
}
|
971 |
+
|
972 |
+
if motion_type not in template:
|
973 |
+
raise ValueError(f"unknown motion type: {motion_type}")
|
974 |
+
|
975 |
+
motion_type, base_vec = template[motion_type]
|
976 |
+
if base_vec is not None:
|
977 |
+
base_vec = base_vec.to(self.device) * distance
|
978 |
+
|
979 |
+
if tracking_method == "moge":
|
980 |
+
T, H, W, _ = pred_tracks.shape
|
981 |
+
valid_selected = ~torch.any(torch.isnan(pred_tracks[0]), dim=2) & mask
|
982 |
+
points = pred_tracks[0][valid_selected].reshape(-1, 3)
|
983 |
+
else:
|
984 |
+
points_in_mask = self._get_points_in_mask(pred_tracks, mask)
|
985 |
+
points = pred_tracks[0, points_in_mask]
|
986 |
+
|
987 |
+
center = points.mean(dim=0)
|
988 |
+
|
989 |
+
motions = []
|
990 |
+
for frame_idx in range(num_frames):
|
991 |
+
t = frame_idx / (num_frames - 1)
|
992 |
+
current_motion = torch.eye(4, device=self.device)
|
993 |
+
current_motion[:3, 3] = -center
|
994 |
+
motion_mat = torch.eye(4, device=self.device)
|
995 |
+
if motion_type == 'trans':
|
996 |
+
motion_mat[:3, 3] = base_vec * t
|
997 |
+
else: # 'rot'
|
998 |
+
angle_rad = torch.deg2rad(torch.tensor(distance * t, device=self.device))
|
999 |
+
cos_t = torch.cos(angle_rad)
|
1000 |
+
sin_t = torch.sin(angle_rad)
|
1001 |
+
motion_mat[0, 0] = cos_t
|
1002 |
+
motion_mat[0, 2] = sin_t
|
1003 |
+
motion_mat[2, 0] = -sin_t
|
1004 |
+
motion_mat[2, 2] = cos_t
|
1005 |
+
|
1006 |
+
current_motion = motion_mat @ current_motion
|
1007 |
+
current_motion[:3, 3] += center
|
1008 |
+
motions.append(current_motion)
|
1009 |
+
|
1010 |
+
motions = torch.stack(motions) # [num_frames, 4, 4]
|
1011 |
+
|
1012 |
+
if tracking_method == "moge":
|
1013 |
+
modified_tracks = pred_tracks.clone().reshape(T, -1, 3)
|
1014 |
+
valid_selected = valid_selected.reshape([-1])
|
1015 |
+
|
1016 |
+
for frame_idx in range(self.num_frames):
|
1017 |
+
motion_mat = motions[frame_idx]
|
1018 |
+
if W > 1:
|
1019 |
+
motion_mat = motion_mat.clone()
|
1020 |
+
motion_mat[0, 3] /= W
|
1021 |
+
motion_mat[1, 3] /= H
|
1022 |
+
points = modified_tracks[frame_idx, valid_selected]
|
1023 |
+
points_homo = torch.cat([points, torch.ones_like(points[:, :1])], dim=1)
|
1024 |
+
transformed_points = torch.matmul(points_homo, motion_mat.T)
|
1025 |
+
modified_tracks[frame_idx, valid_selected] = transformed_points[:, :3]
|
1026 |
+
|
1027 |
+
return modified_tracks.reshape(T, H, W, 3)
|
1028 |
+
|
1029 |
+
else:
|
1030 |
+
points_in_mask = self._get_points_in_mask(pred_tracks, mask)
|
1031 |
+
modified_tracks = pred_tracks.clone()
|
1032 |
+
|
1033 |
+
for frame_idx in range(pred_tracks.shape[0]):
|
1034 |
+
motion_mat = motions[frame_idx]
|
1035 |
+
points = modified_tracks[frame_idx, points_in_mask]
|
1036 |
+
points_homo = torch.cat([points, torch.ones_like(points[:, :1])], dim=1)
|
1037 |
+
transformed_points = torch.matmul(points_homo, motion_mat.T)
|
1038 |
+
modified_tracks[frame_idx, points_in_mask] = transformed_points[:, :3]
|
1039 |
+
|
1040 |
+
return modified_tracks
|
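A usage sketch for ObjectMotionGenerator.apply_motion; pred_tracks and mask are hypothetical inputs (e.g. tracker output and a segmentation of the object to edit).
import torch
gen = ObjectMotionGenerator(device="cuda:0")
edited_tracks = gen.apply_motion(
    pred_tracks=pred_tracks,      # [49, N, 3] trajectories from the tracker
    mask=mask,                    # [H, W] bool mask of the object to move
    motion_type="up",             # one of: up/down/left/right/front/back/rot
    distance=40.0,                # translation offset in track units, or degrees for "rot"
    num_frames=49,
    tracking_method="spatracker",
)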
models/spatracker/__init__.py
ADDED
@@ -0,0 +1,5 @@
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
models/spatracker/models/__init__.py
ADDED
@@ -0,0 +1,5 @@
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
models/spatracker/models/build_spatracker.py
ADDED
@@ -0,0 +1,51 @@
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import torch
|
8 |
+
|
9 |
+
from models.spatracker.models.core.spatracker.spatracker import SpaTracker
|
10 |
+
|
11 |
+
|
12 |
+
def build_spatracker(
|
13 |
+
checkpoint: str,
|
14 |
+
seq_length: int = 8,
|
15 |
+
):
|
16 |
+
model_name = checkpoint.split("/")[-1].split(".")[0]
|
17 |
+
return build_spatracker_from_cfg(checkpoint=checkpoint, seq_length=seq_length)
|
18 |
+
|
19 |
+
|
20 |
+
|
21 |
+
# model used to produce the results in the paper
|
22 |
+
def build_spatracker_from_cfg(checkpoint=None, seq_length=8):
|
23 |
+
return _build_spatracker(
|
24 |
+
stride=4,
|
25 |
+
sequence_len=seq_length,
|
26 |
+
checkpoint=checkpoint,
|
27 |
+
)
|
28 |
+
|
29 |
+
|
30 |
+
def _build_spatracker(
|
31 |
+
stride,
|
32 |
+
sequence_len,
|
33 |
+
checkpoint=None,
|
34 |
+
):
|
35 |
+
spatracker = SpaTracker(
|
36 |
+
stride=stride,
|
37 |
+
S=sequence_len,
|
38 |
+
add_space_attn=True,
|
39 |
+
space_depth=6,
|
40 |
+
time_depth=6,
|
41 |
+
)
|
42 |
+
if checkpoint is not None:
|
43 |
+
with open(checkpoint, "rb") as f:
|
44 |
+
state_dict = torch.load(f, map_location="cpu", weights_only=True)
|
45 |
+
if "model" in state_dict:
|
46 |
+
model_paras = spatracker.state_dict()
|
47 |
+
paras_dict = {k: v for k,v in state_dict["model"].items() if k in spatracker.state_dict()}
|
48 |
+
model_paras.update(paras_dict)
|
49 |
+
state_dict = model_paras
|
50 |
+
spatracker.load_state_dict(state_dict)
|
51 |
+
return spatracker
|
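A minimal sketch of loading the tracker with the builder above; the checkpoint path is hypothetical, and the loader keeps only parameter names that exist in the current model, so partially matching checkpoints still load.
from models.spatracker.models.build_spatracker import build_spatracker
model = build_spatracker("checkpoints/spaT_final.pth", seq_length=8)
model = model.eval().cuda()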
models/spatracker/models/core/__init__.py
ADDED
@@ -0,0 +1,5 @@
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
models/spatracker/models/core/embeddings.py
ADDED
@@ -0,0 +1,250 @@
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import numpy as np
|
9 |
+
|
10 |
+
def get_3d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
|
11 |
+
"""
|
12 |
+
grid_size: int of the grid height and width
|
13 |
+
return:
|
14 |
+
pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
|
15 |
+
"""
|
16 |
+
if isinstance(grid_size, tuple):
|
17 |
+
grid_size_h, grid_size_w = grid_size
|
18 |
+
else:
|
19 |
+
grid_size_h = grid_size_w = grid_size
|
20 |
+
grid_h = np.arange(grid_size_h, dtype=np.float32)
|
21 |
+
grid_w = np.arange(grid_size_w, dtype=np.float32)
|
22 |
+
grid = np.meshgrid(grid_w, grid_h) # here w goes first
|
23 |
+
grid = np.stack(grid, axis=0)
|
24 |
+
|
25 |
+
grid = grid.reshape([2, 1, grid_size_h, grid_size_w])
|
26 |
+
pos_embed = get_3d_sincos_pos_embed_from_grid(embed_dim, grid)
|
27 |
+
if cls_token and extra_tokens > 0:
|
28 |
+
pos_embed = np.concatenate(
|
29 |
+
[np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0
|
30 |
+
)
|
31 |
+
return pos_embed
|
32 |
+
|
33 |
+
|
34 |
+
def get_3d_sincos_pos_embed_from_grid(embed_dim, grid):
|
35 |
+
assert embed_dim % 3 == 0
|
36 |
+
|
37 |
+
# use half of dimensions to encode grid_h
|
38 |
+
B, S, N, _ = grid.shape
|
39 |
+
gridx = grid[..., 0].view(B*S*N).detach().cpu().numpy()
|
40 |
+
gridy = grid[..., 1].view(B*S*N).detach().cpu().numpy()
|
41 |
+
gridz = grid[..., 2].view(B*S*N).detach().cpu().numpy()
|
42 |
+
|
43 |
+
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 3, gridx) # (N, D/3)
|
44 |
+
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 3, gridy) # (N, D/3)
|
45 |
+
emb_z = get_1d_sincos_pos_embed_from_grid(embed_dim // 3, gridz) # (N, D/3)
|
46 |
+
|
47 |
+
|
48 |
+
emb = np.concatenate([emb_h, emb_w, emb_z], axis=1) # (N, D)
|
49 |
+
emb = torch.from_numpy(emb).to(grid.device)
|
50 |
+
return emb.view(B, S, N, embed_dim)
|
51 |
+
|
52 |
+
|
53 |
+
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
|
54 |
+
"""
|
55 |
+
grid_size: int of the grid height and width
|
56 |
+
return:
|
57 |
+
pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
|
58 |
+
"""
|
59 |
+
if isinstance(grid_size, tuple):
|
60 |
+
grid_size_h, grid_size_w = grid_size
|
61 |
+
else:
|
62 |
+
grid_size_h = grid_size_w = grid_size
|
63 |
+
grid_h = np.arange(grid_size_h, dtype=np.float32)
|
64 |
+
grid_w = np.arange(grid_size_w, dtype=np.float32)
|
65 |
+
grid = np.meshgrid(grid_w, grid_h) # here w goes first
|
66 |
+
grid = np.stack(grid, axis=0)
|
67 |
+
|
68 |
+
grid = grid.reshape([2, 1, grid_size_h, grid_size_w])
|
69 |
+
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
|
70 |
+
if cls_token and extra_tokens > 0:
|
71 |
+
pos_embed = np.concatenate(
|
72 |
+
[np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0
|
73 |
+
)
|
74 |
+
return pos_embed
|
75 |
+
|
76 |
+
|
77 |
+
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
|
78 |
+
assert embed_dim % 2 == 0
|
79 |
+
|
80 |
+
# use half of dimensions to encode grid_h
|
81 |
+
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
|
82 |
+
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
|
83 |
+
|
84 |
+
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
|
85 |
+
return emb
|
86 |
+
|
87 |
+
|
88 |
+
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
|
89 |
+
"""
|
90 |
+
embed_dim: output dimension for each position
|
91 |
+
pos: a list of positions to be encoded: size (M,)
|
92 |
+
out: (M, D)
|
93 |
+
"""
|
94 |
+
assert embed_dim % 2 == 0
|
95 |
+
omega = np.arange(embed_dim // 2, dtype=np.float64)
|
96 |
+
omega /= embed_dim / 2.0
|
97 |
+
omega = 1.0 / 10000 ** omega # (D/2,)
|
98 |
+
|
99 |
+
pos = pos.reshape(-1) # (M,)
|
100 |
+
out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
|
101 |
+
|
102 |
+
emb_sin = np.sin(out) # (M, D/2)
|
103 |
+
emb_cos = np.cos(out) # (M, D/2)
|
104 |
+
|
105 |
+
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
|
106 |
+
return emb
|
107 |
+
|
108 |
+
|
109 |
+
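A quick shape check for the 1-D sin/cos embedding defined above: M positions are mapped to vectors of size embed_dim (half sine terms, half cosine terms).
import numpy as np
pos = np.arange(16)                                  # M = 16 positions
emb = get_1d_sincos_pos_embed_from_grid(64, pos)     # -> (16, 64)
assert emb.shape == (16, 64)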
def get_2d_embedding(xy, C, cat_coords=True):
|
110 |
+
B, N, D = xy.shape
|
111 |
+
assert D == 2
|
112 |
+
|
113 |
+
x = xy[:, :, 0:1]
|
114 |
+
y = xy[:, :, 1:2]
|
115 |
+
div_term = (
|
116 |
+
torch.arange(0, C, 2, device=xy.device, dtype=torch.float32) * (1000.0 / C)
|
117 |
+
).reshape(1, 1, int(C / 2))
|
118 |
+
|
119 |
+
pe_x = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32)
|
120 |
+
pe_y = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32)
|
121 |
+
|
122 |
+
pe_x[:, :, 0::2] = torch.sin(x * div_term)
|
123 |
+
pe_x[:, :, 1::2] = torch.cos(x * div_term)
|
124 |
+
|
125 |
+
pe_y[:, :, 0::2] = torch.sin(y * div_term)
|
126 |
+
pe_y[:, :, 1::2] = torch.cos(y * div_term)
|
127 |
+
|
128 |
+
pe = torch.cat([pe_x, pe_y], dim=2) # B, N, C*3
|
129 |
+
if cat_coords:
|
130 |
+
pe = torch.cat([xy, pe], dim=2) # B, N, C*2+2
|
131 |
+
return pe
|
132 |
+
|
133 |
+
|
134 |
+
def get_3d_embedding(xyz, C, cat_coords=True):
|
135 |
+
B, N, D = xyz.shape
|
136 |
+
assert D == 3
|
137 |
+
|
138 |
+
x = xyz[:, :, 0:1]
|
139 |
+
y = xyz[:, :, 1:2]
|
140 |
+
z = xyz[:, :, 2:3]
|
141 |
+
div_term = (
|
142 |
+
torch.arange(0, C, 2, device=xyz.device, dtype=torch.float32) * (1000.0 / C)
|
143 |
+
).reshape(1, 1, int(C / 2))
|
144 |
+
|
145 |
+
pe_x = torch.zeros(B, N, C, device=xyz.device, dtype=torch.float32)
|
146 |
+
pe_y = torch.zeros(B, N, C, device=xyz.device, dtype=torch.float32)
|
147 |
+
pe_z = torch.zeros(B, N, C, device=xyz.device, dtype=torch.float32)
|
148 |
+
|
149 |
+
pe_x[:, :, 0::2] = torch.sin(x * div_term)
|
150 |
+
pe_x[:, :, 1::2] = torch.cos(x * div_term)
|
151 |
+
|
152 |
+
pe_y[:, :, 0::2] = torch.sin(y * div_term)
|
153 |
+
pe_y[:, :, 1::2] = torch.cos(y * div_term)
|
154 |
+
|
155 |
+
pe_z[:, :, 0::2] = torch.sin(z * div_term)
|
156 |
+
pe_z[:, :, 1::2] = torch.cos(z * div_term)
|
157 |
+
|
158 |
+
pe = torch.cat([pe_x, pe_y, pe_z], dim=2) # B, N, C*3
|
159 |
+
if cat_coords:
|
160 |
+
pe = torch.cat([pe, xyz], dim=2) # B, N, C*3+3
|
161 |
+
return pe
|
162 |
+
|
163 |
+
|
164 |
+
def get_4d_embedding(xyzw, C, cat_coords=True):
|
165 |
+
B, N, D = xyzw.shape
|
166 |
+
assert D == 4
|
167 |
+
|
168 |
+
x = xyzw[:, :, 0:1]
|
169 |
+
y = xyzw[:, :, 1:2]
|
170 |
+
z = xyzw[:, :, 2:3]
|
171 |
+
w = xyzw[:, :, 3:4]
|
172 |
+
div_term = (
|
173 |
+
torch.arange(0, C, 2, device=xyzw.device, dtype=torch.float32) * (1000.0 / C)
|
174 |
+
).reshape(1, 1, int(C / 2))
|
175 |
+
|
176 |
+
pe_x = torch.zeros(B, N, C, device=xyzw.device, dtype=torch.float32)
|
177 |
+
pe_y = torch.zeros(B, N, C, device=xyzw.device, dtype=torch.float32)
|
178 |
+
pe_z = torch.zeros(B, N, C, device=xyzw.device, dtype=torch.float32)
|
179 |
+
pe_w = torch.zeros(B, N, C, device=xyzw.device, dtype=torch.float32)
|
180 |
+
|
181 |
+
pe_x[:, :, 0::2] = torch.sin(x * div_term)
|
182 |
+
pe_x[:, :, 1::2] = torch.cos(x * div_term)
|
183 |
+
|
184 |
+
pe_y[:, :, 0::2] = torch.sin(y * div_term)
|
185 |
+
pe_y[:, :, 1::2] = torch.cos(y * div_term)
|
186 |
+
|
187 |
+
pe_z[:, :, 0::2] = torch.sin(z * div_term)
|
188 |
+
pe_z[:, :, 1::2] = torch.cos(z * div_term)
|
189 |
+
|
190 |
+
pe_w[:, :, 0::2] = torch.sin(w * div_term)
|
191 |
+
pe_w[:, :, 1::2] = torch.cos(w * div_term)
|
192 |
+
|
193 |
+
pe = torch.cat([pe_x, pe_y, pe_z, pe_w], dim=2) # B, N, C*4
|
194 |
+
if cat_coords:
|
195 |
+
pe = torch.cat([pe, xyzw], dim=2) # B, N, C*4+4
|
196 |
+
return pe
|
197 |
+
|
198 |
+
import torch.nn as nn
|
199 |
+
class Embedder_Fourier(nn.Module):
|
200 |
+
def __init__(self, input_dim, max_freq_log2, N_freqs,
|
201 |
+
log_sampling=True, include_input=True,
|
202 |
+
periodic_fns=(torch.sin, torch.cos)):
|
203 |
+
'''
|
204 |
+
:param input_dim: dimension of input to be embedded
|
205 |
+
:param max_freq_log2: log2 of max freq; min freq is 1 by default
|
206 |
+
:param N_freqs: number of frequency bands
|
207 |
+
:param log_sampling: if True, frequency bands are linearly sampled in log-space
|
208 |
+
:param include_input: if True, raw input is included in the embedding
|
209 |
+
:param periodic_fns: periodic functions used to embed input
|
210 |
+
'''
|
211 |
+
super(Embedder_Fourier, self).__init__()
|
212 |
+
|
213 |
+
self.input_dim = input_dim
|
214 |
+
self.include_input = include_input
|
215 |
+
self.periodic_fns = periodic_fns
|
216 |
+
|
217 |
+
self.out_dim = 0
|
218 |
+
if self.include_input:
|
219 |
+
self.out_dim += self.input_dim
|
220 |
+
|
221 |
+
self.out_dim += self.input_dim * N_freqs * len(self.periodic_fns)
|
222 |
+
|
223 |
+
if log_sampling:
|
224 |
+
self.freq_bands = 2. ** torch.linspace(0., max_freq_log2, N_freqs)
|
225 |
+
else:
|
226 |
+
self.freq_bands = torch.linspace(
|
227 |
+
2. ** 0., 2. ** max_freq_log2, N_freqs)
|
228 |
+
|
229 |
+
self.freq_bands = self.freq_bands.numpy().tolist()
|
230 |
+
|
231 |
+
def forward(self,
|
232 |
+
input: torch.Tensor,
|
233 |
+
rescale: float = 1.0):
|
234 |
+
'''
|
235 |
+
:param input: tensor of shape [..., self.input_dim]
|
236 |
+
:return: tensor of shape [..., self.out_dim]
|
237 |
+
'''
|
238 |
+
assert (input.shape[-1] == self.input_dim)
|
239 |
+
out = []
|
240 |
+
if self.include_input:
|
241 |
+
out.append(input/rescale)
|
242 |
+
|
243 |
+
for i in range(len(self.freq_bands)):
|
244 |
+
freq = self.freq_bands[i]
|
245 |
+
for p_fn in self.periodic_fns:
|
246 |
+
out.append(p_fn(input * freq))
|
247 |
+
out = torch.cat(out, dim=-1)
|
248 |
+
|
249 |
+
assert (out.shape[-1] == self.out_dim)
|
250 |
+
return out
|
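A worked example of Embedder_Fourier's output dimensionality, following the out_dim formula in __init__ (include_input adds the raw input channels, each frequency band adds one sin and one cos copy).
import torch
embedder = Embedder_Fourier(input_dim=3, max_freq_log2=3, N_freqs=4)
x = torch.rand(8, 3)
y = embedder(x)
# out_dim = 3 + 3 * 4 * 2 = 27 channels
assert y.shape == (8, 27)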
models/spatracker/models/core/model_utils.py
ADDED
@@ -0,0 +1,477 @@
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import torch.nn.functional as F
|
9 |
+
from easydict import EasyDict as edict
|
10 |
+
from sklearn.decomposition import PCA
|
11 |
+
import matplotlib.pyplot as plt
|
12 |
+
|
13 |
+
EPS = 1e-6
|
14 |
+
|
15 |
+
def nearest_sample2d(im, x, y, return_inbounds=False):
|
16 |
+
# x and y are each B, N
|
17 |
+
# output is B, C, N
|
18 |
+
if len(im.shape) == 5:
|
19 |
+
B, N, C, H, W = list(im.shape)
|
20 |
+
else:
|
21 |
+
B, C, H, W = list(im.shape)
|
22 |
+
N = list(x.shape)[1]
|
23 |
+
|
24 |
+
x = x.float()
|
25 |
+
y = y.float()
|
26 |
+
H_f = torch.tensor(H, dtype=torch.float32)
|
27 |
+
W_f = torch.tensor(W, dtype=torch.float32)
|
28 |
+
|
29 |
+
# inbound_mask = (x>-0.5).float()*(y>-0.5).float()*(x<W_f+0.5).float()*(y<H_f+0.5).float()
|
30 |
+
|
31 |
+
max_y = (H_f - 1).int()
|
32 |
+
max_x = (W_f - 1).int()
|
33 |
+
|
34 |
+
x0 = torch.floor(x).int()
|
35 |
+
x1 = x0 + 1
|
36 |
+
y0 = torch.floor(y).int()
|
37 |
+
y1 = y0 + 1
|
38 |
+
|
39 |
+
x0_clip = torch.clamp(x0, 0, max_x)
|
40 |
+
x1_clip = torch.clamp(x1, 0, max_x)
|
41 |
+
y0_clip = torch.clamp(y0, 0, max_y)
|
42 |
+
y1_clip = torch.clamp(y1, 0, max_y)
|
43 |
+
dim2 = W
|
44 |
+
dim1 = W * H
|
45 |
+
|
46 |
+
base = torch.arange(0, B, dtype=torch.int64, device=x.device) * dim1
|
47 |
+
base = torch.reshape(base, [B, 1]).repeat([1, N])
|
48 |
+
|
49 |
+
base_y0 = base + y0_clip * dim2
|
50 |
+
base_y1 = base + y1_clip * dim2
|
51 |
+
|
52 |
+
idx_y0_x0 = base_y0 + x0_clip
|
53 |
+
idx_y0_x1 = base_y0 + x1_clip
|
54 |
+
idx_y1_x0 = base_y1 + x0_clip
|
55 |
+
idx_y1_x1 = base_y1 + x1_clip
|
56 |
+
|
57 |
+
# use the indices to lookup pixels in the flat image
|
58 |
+
# im is B x C x H x W
|
59 |
+
# move C out to last dim
|
60 |
+
if len(im.shape) == 5:
|
61 |
+
im_flat = (im.permute(0, 3, 4, 1, 2)).reshape(B * H * W, N, C)
|
62 |
+
i_y0_x0 = torch.diagonal(im_flat[idx_y0_x0.long()], dim1=1, dim2=2).permute(
|
63 |
+
0, 2, 1
|
64 |
+
)
|
65 |
+
i_y0_x1 = torch.diagonal(im_flat[idx_y0_x1.long()], dim1=1, dim2=2).permute(
|
66 |
+
0, 2, 1
|
67 |
+
)
|
68 |
+
i_y1_x0 = torch.diagonal(im_flat[idx_y1_x0.long()], dim1=1, dim2=2).permute(
|
69 |
+
0, 2, 1
|
70 |
+
)
|
71 |
+
i_y1_x1 = torch.diagonal(im_flat[idx_y1_x1.long()], dim1=1, dim2=2).permute(
|
72 |
+
0, 2, 1
|
73 |
+
)
|
74 |
+
else:
|
75 |
+
im_flat = (im.permute(0, 2, 3, 1)).reshape(B * H * W, C)
|
76 |
+
i_y0_x0 = im_flat[idx_y0_x0.long()]
|
77 |
+
i_y0_x1 = im_flat[idx_y0_x1.long()]
|
78 |
+
i_y1_x0 = im_flat[idx_y1_x0.long()]
|
79 |
+
i_y1_x1 = im_flat[idx_y1_x1.long()]
|
80 |
+
|
81 |
+
# Finally calculate interpolated values.
|
82 |
+
x0_f = x0.float()
|
83 |
+
x1_f = x1.float()
|
84 |
+
y0_f = y0.float()
|
85 |
+
y1_f = y1.float()
|
86 |
+
|
87 |
+
w_y0_x0 = ((x1_f - x) * (y1_f - y)).unsqueeze(2)
|
88 |
+
w_y0_x1 = ((x - x0_f) * (y1_f - y)).unsqueeze(2)
|
89 |
+
w_y1_x0 = ((x1_f - x) * (y - y0_f)).unsqueeze(2)
|
90 |
+
w_y1_x1 = ((x - x0_f) * (y - y0_f)).unsqueeze(2)
|
91 |
+
|
92 |
+
# w_yi_xo is B * N * 1
|
93 |
+
max_idx = torch.cat([w_y0_x0, w_y0_x1, w_y1_x0, w_y1_x1], dim=-1).max(dim=-1)[1]
|
94 |
+
output = torch.stack([i_y0_x0, i_y0_x1, i_y1_x0, i_y1_x1], dim=-1).gather(-1, max_idx[...,None,None].repeat(1,1,C,1)).squeeze(-1)
|
95 |
+
|
96 |
+
# output is B*N x C
|
97 |
+
output = output.view(B, -1, C)
|
98 |
+
output = output.permute(0, 2, 1)
|
99 |
+
# output is B x C x N
|
100 |
+
|
101 |
+
if return_inbounds:
|
102 |
+
x_valid = (x > -0.5).byte() & (x < float(W_f - 0.5)).byte()
|
103 |
+
y_valid = (y > -0.5).byte() & (y < float(H_f - 0.5)).byte()
|
104 |
+
inbounds = (x_valid & y_valid).float()
|
105 |
+
inbounds = inbounds.reshape(
|
106 |
+
B, N
|
107 |
+
) # something seems wrong here for B>1; i'm getting an error here (or downstream if i put -1)
|
108 |
+
return output, inbounds
|
109 |
+
|
110 |
+
return output # B, C, N
|
111 |
+
|
112 |
+
def smart_cat(tensor1, tensor2, dim):
|
113 |
+
if tensor1 is None:
|
114 |
+
return tensor2
|
115 |
+
return torch.cat([tensor1, tensor2], dim=dim)
|
116 |
+
|
117 |
+
|
118 |
+
def normalize_single(d):
|
119 |
+
# d is a whatever shape torch tensor
|
120 |
+
dmin = torch.min(d)
|
121 |
+
dmax = torch.max(d)
|
122 |
+
d = (d - dmin) / (EPS + (dmax - dmin))
|
123 |
+
return d
|
124 |
+
|
125 |
+
|
126 |
+
def normalize(d):
|
127 |
+
# d is B x whatever. normalize within each element of the batch
|
128 |
+
out = torch.zeros(d.size())
|
129 |
+
if d.is_cuda:
|
130 |
+
out = out.cuda()
|
131 |
+
B = list(d.size())[0]
|
132 |
+
for b in list(range(B)):
|
133 |
+
out[b] = normalize_single(d[b])
|
134 |
+
return out
|
135 |
+
|
136 |
+
|
137 |
+
def meshgrid2d(B, Y, X, stack=False, norm=False, device="cuda"):
|
138 |
+
# returns a meshgrid sized B x Y x X
|
139 |
+
|
140 |
+
grid_y = torch.linspace(0.0, Y - 1, Y, device=torch.device(device))
|
141 |
+
grid_y = torch.reshape(grid_y, [1, Y, 1])
|
142 |
+
grid_y = grid_y.repeat(B, 1, X)
|
143 |
+
|
144 |
+
grid_x = torch.linspace(0.0, X - 1, X, device=torch.device(device))
|
145 |
+
grid_x = torch.reshape(grid_x, [1, 1, X])
|
146 |
+
grid_x = grid_x.repeat(B, Y, 1)
|
147 |
+
|
148 |
+
if stack:
|
149 |
+
# note we stack in xy order
|
150 |
+
# (see https://pytorch.org/docs/stable/nn.functional.html#torch.nn.functional.grid_sample)
|
151 |
+
grid = torch.stack([grid_x, grid_y], dim=-1)
|
152 |
+
return grid
|
153 |
+
else:
|
154 |
+
return grid_y, grid_x
|
155 |
+
|
156 |
+
|
157 |
+
def reduce_masked_mean(x, mask, dim=None, keepdim=False):
|
158 |
+
# x and mask are the same shape, or at least broadcastably so < actually it's safer if you disallow broadcasting
|
159 |
+
# returns shape-1
|
160 |
+
# axis can be a list of axes
|
161 |
+
for (a, b) in zip(x.size(), mask.size()):
|
162 |
+
assert a == b # some shape mismatch!
|
163 |
+
prod = x * mask
|
164 |
+
if dim is None:
|
165 |
+
numer = torch.sum(prod)
|
166 |
+
denom = EPS + torch.sum(mask)
|
167 |
+
else:
|
168 |
+
numer = torch.sum(prod, dim=dim, keepdim=keepdim)
|
169 |
+
denom = EPS + torch.sum(mask, dim=dim, keepdim=keepdim)
|
170 |
+
|
171 |
+
mean = numer / denom
|
172 |
+
return mean
|
173 |
+
|
174 |
+
|
175 |
+
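A tiny numeric example for reduce_masked_mean: only entries where the mask is 1 contribute to the mean.
import torch
x = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
mask = torch.tensor([[1.0, 1.0, 0.0, 0.0]])
m = reduce_masked_mean(x, mask)   # (1*1 + 2*1) / (2 + EPS) ~= 1.5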
def bilinear_sample2d(im, x, y, return_inbounds=False):
|
176 |
+
# x and y are each B, N
|
177 |
+
# output is B, C, N
|
178 |
+
if len(im.shape) == 5:
|
179 |
+
B, N, C, H, W = list(im.shape)
|
180 |
+
else:
|
181 |
+
B, C, H, W = list(im.shape)
|
182 |
+
N = list(x.shape)[1]
|
183 |
+
|
184 |
+
x = x.float()
|
185 |
+
y = y.float()
|
186 |
+
H_f = torch.tensor(H, dtype=torch.float32)
|
187 |
+
W_f = torch.tensor(W, dtype=torch.float32)
|
188 |
+
|
189 |
+
# inbound_mask = (x>-0.5).float()*(y>-0.5).float()*(x<W_f+0.5).float()*(y<H_f+0.5).float()
|
190 |
+
|
191 |
+
max_y = (H_f - 1).int()
|
192 |
+
max_x = (W_f - 1).int()
|
193 |
+
|
194 |
+
x0 = torch.floor(x).int()
|
195 |
+
x1 = x0 + 1
|
196 |
+
y0 = torch.floor(y).int()
|
197 |
+
y1 = y0 + 1
|
198 |
+
|
199 |
+
x0_clip = torch.clamp(x0, 0, max_x)
|
200 |
+
x1_clip = torch.clamp(x1, 0, max_x)
|
201 |
+
y0_clip = torch.clamp(y0, 0, max_y)
|
202 |
+
y1_clip = torch.clamp(y1, 0, max_y)
|
203 |
+
dim2 = W
|
204 |
+
dim1 = W * H
|
205 |
+
|
206 |
+
base = torch.arange(0, B, dtype=torch.int64, device=x.device) * dim1
|
207 |
+
base = torch.reshape(base, [B, 1]).repeat([1, N])
|
208 |
+
|
209 |
+
base_y0 = base + y0_clip * dim2
|
210 |
+
base_y1 = base + y1_clip * dim2
|
211 |
+
|
212 |
+
idx_y0_x0 = base_y0 + x0_clip
|
213 |
+
idx_y0_x1 = base_y0 + x1_clip
|
214 |
+
idx_y1_x0 = base_y1 + x0_clip
|
215 |
+
idx_y1_x1 = base_y1 + x1_clip
|
216 |
+
|
217 |
+
# use the indices to lookup pixels in the flat image
|
218 |
+
# im is B x C x H x W
|
219 |
+
# move C out to last dim
|
220 |
+
if len(im.shape) == 5:
|
221 |
+
im_flat = (im.permute(0, 3, 4, 1, 2)).reshape(B * H * W, N, C)
|
222 |
+
i_y0_x0 = torch.diagonal(im_flat[idx_y0_x0.long()], dim1=1, dim2=2).permute(
|
223 |
+
0, 2, 1
|
224 |
+
)
|
225 |
+
i_y0_x1 = torch.diagonal(im_flat[idx_y0_x1.long()], dim1=1, dim2=2).permute(
|
226 |
+
0, 2, 1
|
227 |
+
)
|
228 |
+
i_y1_x0 = torch.diagonal(im_flat[idx_y1_x0.long()], dim1=1, dim2=2).permute(
|
229 |
+
0, 2, 1
|
230 |
+
)
|
231 |
+
i_y1_x1 = torch.diagonal(im_flat[idx_y1_x1.long()], dim1=1, dim2=2).permute(
|
232 |
+
0, 2, 1
|
233 |
+
)
|
234 |
+
else:
|
235 |
+
im_flat = (im.permute(0, 2, 3, 1)).reshape(B * H * W, C)
|
236 |
+
i_y0_x0 = im_flat[idx_y0_x0.long()]
|
237 |
+
i_y0_x1 = im_flat[idx_y0_x1.long()]
|
238 |
+
i_y1_x0 = im_flat[idx_y1_x0.long()]
|
239 |
+
i_y1_x1 = im_flat[idx_y1_x1.long()]
|
240 |
+
|
241 |
+
# Finally calculate interpolated values.
|
242 |
+
x0_f = x0.float()
|
243 |
+
x1_f = x1.float()
|
244 |
+
y0_f = y0.float()
|
245 |
+
y1_f = y1.float()
|
246 |
+
|
247 |
+
w_y0_x0 = ((x1_f - x) * (y1_f - y)).unsqueeze(2)
|
248 |
+
w_y0_x1 = ((x - x0_f) * (y1_f - y)).unsqueeze(2)
|
249 |
+
w_y1_x0 = ((x1_f - x) * (y - y0_f)).unsqueeze(2)
|
250 |
+
w_y1_x1 = ((x - x0_f) * (y - y0_f)).unsqueeze(2)
|
251 |
+
|
252 |
+
output = (
|
253 |
+
w_y0_x0 * i_y0_x0 + w_y0_x1 * i_y0_x1 + w_y1_x0 * i_y1_x0 + w_y1_x1 * i_y1_x1
|
254 |
+
)
|
255 |
+
# output is B*N x C
|
256 |
+
output = output.view(B, -1, C)
|
257 |
+
output = output.permute(0, 2, 1)
|
258 |
+
# output is B x C x N
|
259 |
+
|
260 |
+
if return_inbounds:
|
261 |
+
x_valid = (x > -0.5).byte() & (x < float(W_f - 0.5)).byte()
|
262 |
+
y_valid = (y > -0.5).byte() & (y < float(H_f - 0.5)).byte()
|
263 |
+
inbounds = (x_valid & y_valid).float()
|
264 |
+
inbounds = inbounds.reshape(
|
265 |
+
B, N
|
266 |
+
) # something seems wrong here for B>1; i'm getting an error here (or downstream if i put -1)
|
267 |
+
return output, inbounds
|
268 |
+
|
269 |
+
return output # B, C, N
|
270 |
+
|
271 |
+
|
272 |
+
def procrustes_analysis(X0,X1,Weight): # [B,N,3]
|
273 |
+
# translation
|
274 |
+
t0 = X0.mean(dim=1,keepdim=True)
|
275 |
+
t1 = X1.mean(dim=1,keepdim=True)
|
276 |
+
X0c = X0-t0
|
277 |
+
X1c = X1-t1
|
278 |
+
# scale
|
279 |
+
# s0 = (X0c**2).sum(dim=-1).mean().sqrt()
|
280 |
+
# s1 = (X1c**2).sum(dim=-1).mean().sqrt()
|
281 |
+
# X0cs = X0c/s0
|
282 |
+
# X1cs = X1c/s1
|
283 |
+
# rotation (use double for SVD, float loses precision)
|
284 |
+
U,_,V = (X0c.t()@X1c).double().svd(some=True)
|
285 |
+
R = ([email protected]()).float()
|
286 |
+
if R.det()<0: R[2] *= -1
|
287 |
+
# align X1 to X0: X1to0 = (X1-t1)/@R.t()+t0
|
288 |
+
se3 = edict(t0=t0[0],t1=t1[0],R=R)
|
289 |
+
|
290 |
+
return se3
|
291 |
+
|
292 |
+
def bilinear_sampler(input, coords, align_corners=True, padding_mode="border"):
|
293 |
+
r"""Sample a tensor using bilinear interpolation
|
294 |
+
|
295 |
+
`bilinear_sampler(input, coords)` samples a tensor :attr:`input` at
|
296 |
+
coordinates :attr:`coords` using bilinear interpolation. It is the same
|
297 |
+
as `torch.nn.functional.grid_sample()` but with a different coordinate
|
298 |
+
convention.
|
299 |
+
|
300 |
+
The input tensor is assumed to be of shape :math:`(B, C, H, W)`, where
|
301 |
+
:math:`B` is the batch size, :math:`C` is the number of channels,
|
302 |
+
:math:`H` is the height of the image, and :math:`W` is the width of the
|
303 |
+
image. The tensor :attr:`coords` of shape :math:`(B, H_o, W_o, 2)` is
|
304 |
+
interpreted as an array of 2D point coordinates :math:`(x_i,y_i)`.
|
305 |
+
|
306 |
+
Alternatively, the input tensor can be of size :math:`(B, C, T, H, W)`,
|
307 |
+
in which case sample points are triplets :math:`(t_i,x_i,y_i)`. Note
|
308 |
+
that in this case the order of the components is slightly different
|
309 |
+
from `grid_sample()`, which would expect :math:`(x_i,y_i,t_i)`.
|
310 |
+
|
311 |
+
If `align_corners` is `True`, the coordinate :math:`x` is assumed to be
|
312 |
+
in the range :math:`[0,W-1]`, with 0 corresponding to the center of the
|
313 |
+
left-most image pixel :math:`W-1` to the center of the right-most
|
314 |
+
pixel.
|
315 |
+
|
316 |
+
If `align_corners` is `False`, the coordinate :math:`x` is assumed to
|
317 |
+
be in the range :math:`[0,W]`, with 0 corresponding to the left edge of
|
318 |
+
the left-most pixel :math:`W` to the right edge of the right-most
|
319 |
+
pixel.
|
320 |
+
|
321 |
+
Similar conventions apply to the :math:`y` for the range
|
322 |
+
:math:`[0,H-1]` and :math:`[0,H]` and to :math:`t` for the range
|
323 |
+
:math:`[0,T-1]` and :math:`[0,T]`.
|
324 |
+
|
325 |
+
Args:
|
326 |
+
input (Tensor): batch of input images.
|
327 |
+
coords (Tensor): batch of coordinates.
|
328 |
+
align_corners (bool, optional): Coordinate convention. Defaults to `True`.
|
329 |
+
padding_mode (str, optional): Padding mode. Defaults to `"border"`.
|
330 |
+
|
331 |
+
Returns:
|
332 |
+
Tensor: sampled points.
|
333 |
+
"""
|
334 |
+
|
335 |
+
sizes = input.shape[2:]
|
336 |
+
|
337 |
+
assert len(sizes) in [2, 3]
|
338 |
+
|
339 |
+
if len(sizes) == 3:
|
340 |
+
# t x y -> x y t to match dimensions T H W in grid_sample
|
341 |
+
coords = coords[..., [1, 2, 0]]
|
342 |
+
|
343 |
+
if align_corners:
|
344 |
+
coords = coords * torch.tensor(
|
345 |
+
[2 / max(size - 1, 1) for size in reversed(sizes)], device=coords.device
|
346 |
+
)
|
347 |
+
else:
|
348 |
+
coords = coords * torch.tensor([2 / size for size in reversed(sizes)], device=coords.device)
|
349 |
+
|
350 |
+
coords -= 1
|
351 |
+
|
352 |
+
return F.grid_sample(input, coords, align_corners=align_corners, padding_mode=padding_mode)
|
353 |
+
|
354 |
+
|
355 |
+
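A small check of the align_corners=True convention described in the bilinear_sampler docstring: pixel centres sit at integer coordinates in [0, W-1] x [0, H-1], so sampling at (x, y) = (1, 0) returns the value of column 1 in row 0.
import torch
img = torch.arange(12, dtype=torch.float32).reshape(1, 1, 3, 4)  # B, C, H, W
coords = torch.tensor([[[[1.0, 0.0]]]])                          # B, H_o, W_o, 2
out = bilinear_sampler(img, coords)                              # -> value 1.0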
def sample_features4d(input, coords):
|
356 |
+
r"""Sample spatial features
|
357 |
+
|
358 |
+
`sample_features4d(input, coords)` samples the spatial features
|
359 |
+
:attr:`input` represented by a 4D tensor :math:`(B, C, H, W)`.
|
360 |
+
|
361 |
+
The field is sampled at coordinates :attr:`coords` using bilinear
|
362 |
+
interpolation. :attr:`coords` is assumed to be of shape :math:`(B, R,
|
363 |
+
3)`, where each sample has the format :math:`(x_i, y_i)`. This uses the
|
364 |
+
same convention as :func:`bilinear_sampler` with `align_corners=True`.
|
365 |
+
|
366 |
+
The output tensor has one feature per point, and has shape :math:`(B,
|
367 |
+
R, C)`.
|
368 |
+
|
369 |
+
Args:
|
370 |
+
input (Tensor): spatial features.
|
371 |
+
coords (Tensor): points.
|
372 |
+
|
373 |
+
Returns:
|
374 |
+
Tensor: sampled features.
|
375 |
+
"""
|
376 |
+
|
377 |
+
B, _, _, _ = input.shape
|
378 |
+
|
379 |
+
# B R 2 -> B R 1 2
|
380 |
+
coords = coords.unsqueeze(2)
|
381 |
+
|
382 |
+
# B C R 1
|
383 |
+
feats = bilinear_sampler(input, coords)
|
384 |
+
|
385 |
+
return feats.permute(0, 2, 1, 3).view(
|
386 |
+
B, -1, feats.shape[1] * feats.shape[3]
|
387 |
+
) # B C R 1 -> B R C
|
388 |
+
|
389 |
+
|
390 |
+
def sample_features5d(input, coords):
|
391 |
+
r"""Sample spatio-temporal features
|
392 |
+
|
393 |
+
`sample_features5d(input, coords)` works in the same way as
|
394 |
+
:func:`sample_features4d` but for spatio-temporal features and points:
|
395 |
+
:attr:`input` is a 5D tensor :math:`(B, T, C, H, W)`, :attr:`coords` is
|
396 |
+
a :math:`(B, R1, R2, 3)` tensor of spatio-temporal point :math:`(t_i,
|
397 |
+
x_i, y_i)`. The output tensor has shape :math:`(B, R1, R2, C)`.
|
398 |
+
|
399 |
+
Args:
|
400 |
+
input (Tensor): spatio-temporal features.
|
401 |
+
coords (Tensor): spatio-temporal points.
|
402 |
+
|
403 |
+
Returns:
|
404 |
+
Tensor: sampled features.
|
405 |
+
"""
|
406 |
+
|
407 |
+
B, T, _, _, _ = input.shape
|
408 |
+
|
409 |
+
# B T C H W -> B C T H W
|
410 |
+
input = input.permute(0, 2, 1, 3, 4)
|
411 |
+
|
412 |
+
# B R1 R2 3 -> B R1 R2 1 3
|
413 |
+
coords = coords.unsqueeze(3)
|
414 |
+
|
415 |
+
# B C R1 R2 1
|
416 |
+
feats = bilinear_sampler(input, coords)
|
417 |
+
|
418 |
+
return feats.permute(0, 2, 3, 1, 4).view(
|
419 |
+
B, feats.shape[2], feats.shape[3], feats.shape[1]
|
420 |
+
) # B C R1 R2 1 -> B R1 R2 C
|
421 |
+
|
422 |
+
def vis_PCA(fmaps, save_dir):
|
423 |
+
"""
|
424 |
+
visualize the PCA of the feature maps
|
425 |
+
args:
|
426 |
+
fmaps: feature maps 1 C H W
|
427 |
+
save_dir: the directory to save the PCA visualization
|
428 |
+
"""
|
429 |
+
|
430 |
+
pca = PCA(n_components=3)
|
431 |
+
fmap_vis = fmaps[0,...]
|
432 |
+
fmap_vnorm = (
|
433 |
+
(fmap_vis-fmap_vis.min())/
|
434 |
+
(fmap_vis.max()-fmap_vis.min()))
|
435 |
+
H_vis, W_vis = fmap_vis.shape[1:]
|
436 |
+
fmap_vnorm = fmap_vnorm.reshape(fmap_vnorm.shape[0],
|
437 |
+
-1).permute(1,0)
|
438 |
+
fmap_pca = pca.fit_transform(fmap_vnorm.detach().cpu().numpy())
|
439 |
+
pca = fmap_pca.reshape(H_vis,W_vis,3)
|
440 |
+
plt.imsave(save_dir,
|
441 |
+
(
|
442 |
+
(pca-pca.min())/
|
443 |
+
(pca.max()-pca.min())
|
444 |
+
))
|
445 |
+
|
446 |
+
|
447 |
+
# debug=False
|
448 |
+
# if debug==True:
|
449 |
+
# pcd_idx = 60
|
450 |
+
# vis_PCA(fmapYZ[0,:1], "./yz.png")
|
451 |
+
# vis_PCA(fmapXZ[0,:1], "./xz.png")
|
452 |
+
# vis_PCA(fmaps[0,:1], "./xy.png")
|
453 |
+
# vis_PCA(fmaps[0,-1:], "./xy_.png")
|
454 |
+
# fxy_q = fxy[0,0,pcd_idx:pcd_idx+1, :, None, None]
|
455 |
+
# fyz_q = fyz[0,0,pcd_idx:pcd_idx+1, :, None, None]
|
456 |
+
# fxz_q = fxz[0,0,pcd_idx:pcd_idx+1, :, None, None]
|
457 |
+
# corr_map = (fxy_q*fmaps[0,-1:]).sum(dim=1)
|
458 |
+
# corr_map_yz = (fyz_q*fmapYZ[0,-1:]).sum(dim=1)
|
459 |
+
# corr_map_xz = (fxz_q*fmapXZ[0,-1:]).sum(dim=1)
|
460 |
+
# coord_last = coords[0,-1,pcd_idx:pcd_idx+1]
|
461 |
+
# coord_last_neigh = coords[0,-1, self.neigh_indx[pcd_idx]]
|
462 |
+
# depth_last = depths_dnG[-1,0]
|
463 |
+
# abs_res = (depth_last-coord_last[-1,-1]).abs()
|
464 |
+
# abs_res = (abs_res - abs_res.min())/(abs_res.max()-abs_res.min())
|
465 |
+
# res_dp = torch.exp(-abs_res)
|
466 |
+
# enhance_corr = res_dp*corr_map
|
467 |
+
# plt.imsave("./res.png", res_dp.detach().cpu().numpy())
|
468 |
+
# plt.imsave("./enhance_corr.png", enhance_corr[0].detach().cpu().numpy())
|
469 |
+
# plt.imsave("./corr_map.png", corr_map[0].detach().cpu().numpy())
|
470 |
+
# plt.imsave("./corr_map_yz.png", corr_map_yz[0].detach().cpu().numpy())
|
471 |
+
# plt.imsave("./corr_map_xz.png", corr_map_xz[0].detach().cpu().numpy())
|
472 |
+
# img_feat = cv2.imread("./xy.png")
|
473 |
+
# cv2.circle(img_feat, (int(coord_last[0,0]), int(coord_last[0,1])), 2, (0, 0, 255), -1)
|
474 |
+
# for p_i in coord_last_neigh:
|
475 |
+
# cv2.circle(img_feat, (int(p_i[0]), int(p_i[1])), 1, (0, 255, 0), -1)
|
476 |
+
# cv2.imwrite("./xy_coord.png", img_feat)
|
477 |
+
# import ipdb; ipdb.set_trace()
|
models/spatracker/models/core/spatracker/__init__.py
ADDED
@@ -0,0 +1,5 @@
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
models/spatracker/models/core/spatracker/blocks.py
ADDED
@@ -0,0 +1,999 @@
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
import torch.nn.functional as F
|
10 |
+
from torch.cuda.amp import autocast
|
11 |
+
from einops import rearrange
|
12 |
+
import collections
|
13 |
+
from functools import partial
|
14 |
+
from itertools import repeat
|
15 |
+
import torchvision.models as tvm
import numpy as np  # np.pi is used in Lie.SO3_to_so3 further down in this file
|
16 |
+
|
17 |
+
from models.spatracker.models.core.spatracker.vit.encoder import ImageEncoderViT as vitEnc
|
18 |
+
from models.spatracker.models.core.spatracker.dpt.models import DPTEncoder
|
19 |
+
from models.spatracker.models.core.spatracker.loftr import LocalFeatureTransformer
|
20 |
+
# from models.monoD.depth_anything.dpt import DPTHeadEnc, DPTHead
|
21 |
+
|
22 |
+
# From PyTorch internals
|
23 |
+
def _ntuple(n):
|
24 |
+
def parse(x):
|
25 |
+
if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
|
26 |
+
return tuple(x)
|
27 |
+
return tuple(repeat(x, n))
|
28 |
+
|
29 |
+
return parse
|
30 |
+
|
31 |
+
|
32 |
+
def exists(val):
|
33 |
+
return val is not None
|
34 |
+
|
35 |
+
|
36 |
+
def default(val, d):
|
37 |
+
return val if exists(val) else d
|
38 |
+
|
39 |
+
|
40 |
+
to_2tuple = _ntuple(2)
|
41 |
+
|
42 |
+
|
43 |
+
class Mlp(nn.Module):
|
44 |
+
"""MLP as used in Vision Transformer, MLP-Mixer and related networks"""
|
45 |
+
|
46 |
+
def __init__(
|
47 |
+
self,
|
48 |
+
in_features,
|
49 |
+
hidden_features=None,
|
50 |
+
out_features=None,
|
51 |
+
act_layer=nn.GELU,
|
52 |
+
norm_layer=None,
|
53 |
+
bias=True,
|
54 |
+
drop=0.0,
|
55 |
+
use_conv=False,
|
56 |
+
):
|
57 |
+
super().__init__()
|
58 |
+
out_features = out_features or in_features
|
59 |
+
hidden_features = hidden_features or in_features
|
60 |
+
bias = to_2tuple(bias)
|
61 |
+
drop_probs = to_2tuple(drop)
|
62 |
+
linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
|
63 |
+
|
64 |
+
self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
|
65 |
+
self.act = act_layer()
|
66 |
+
self.drop1 = nn.Dropout(drop_probs[0])
|
67 |
+
self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
|
68 |
+
self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
|
69 |
+
self.drop2 = nn.Dropout(drop_probs[1])
|
70 |
+
|
71 |
+
def forward(self, x):
|
72 |
+
x = self.fc1(x)
|
73 |
+
x = self.act(x)
|
74 |
+
x = self.drop1(x)
|
75 |
+
x = self.fc2(x)
|
76 |
+
x = self.drop2(x)
|
77 |
+
return x
|
78 |
+
|
79 |
+
class Attention(nn.Module):
|
80 |
+
def __init__(self, query_dim, context_dim=None,
|
81 |
+
num_heads=8, dim_head=48, qkv_bias=False, flash=False):
|
82 |
+
super().__init__()
|
83 |
+
inner_dim = self.inner_dim = dim_head * num_heads
|
84 |
+
context_dim = default(context_dim, query_dim)
|
85 |
+
self.scale = dim_head**-0.5
|
86 |
+
self.heads = num_heads
|
87 |
+
self.flash = flash
|
88 |
+
|
89 |
+
self.qkv = nn.Linear(query_dim, inner_dim*3, bias=qkv_bias)
|
90 |
+
self.proj = nn.Linear(inner_dim, query_dim)
|
91 |
+
|
92 |
+
def forward(self, x, context=None, attn_bias=None):
|
93 |
+
B, N1, _ = x.shape
|
94 |
+
C = self.inner_dim
|
95 |
+
h = self.heads
|
96 |
+
# q = self.to_q(x).reshape(B, N1, h, C // h).permute(0, 2, 1, 3)
|
97 |
+
# k, v = self.to_kv(context).chunk(2, dim=-1)
|
98 |
+
# context = default(context, x)
|
99 |
+
|
100 |
+
qkv = self.qkv(x).reshape(B, N1, 3, h, C // h)
|
101 |
+
q, k, v = qkv[:,:, 0], qkv[:,:, 1], qkv[:,:, 2]
|
102 |
+
N2 = x.shape[1]
|
103 |
+
|
104 |
+
k = k.reshape(B, N2, h, C // h).permute(0, 2, 1, 3)
|
105 |
+
v = v.reshape(B, N2, h, C // h).permute(0, 2, 1, 3)
|
106 |
+
q = q.reshape(B, N1, h, C // h).permute(0, 2, 1, 3)
|
107 |
+
if self.flash==False:
|
108 |
+
sim = (q @ k.transpose(-2, -1)) * self.scale
|
109 |
+
if attn_bias is not None:
|
110 |
+
sim = sim + attn_bias
|
111 |
+
attn = sim.softmax(dim=-1)
|
112 |
+
x = (attn @ v).transpose(1, 2).reshape(B, N1, C)
|
113 |
+
else:
|
114 |
+
input_args = [x.half().contiguous() for x in [q, k, v]]
|
115 |
+
x = F.scaled_dot_product_attention(*input_args).permute(0,2,1,3).reshape(B,N1,-1) # type: ignore
|
116 |
+
|
117 |
+
# return self.to_out(x.float())
|
118 |
+
return self.proj(x.float())
|
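A minimal usage sketch of the Attention block above (tensor sizes are illustrative; the import path follows this commit's file layout, and flash=False keeps it on the plain softmax path):

import torch
from models.spatracker.models.core.spatracker.blocks import Attention  # path per this commit

attn = Attention(query_dim=128, num_heads=8, dim_head=48, flash=False)
x = torch.randn(2, 50, 128)          # (B, tokens, query_dim)
y = attn(x)                          # (2, 50, 128); qkv runs in 8*48=384 dims internally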
119 |
+
|
120 |
+
class ResidualBlock(nn.Module):
|
121 |
+
def __init__(self, in_planes, planes, norm_fn="group", stride=1):
|
122 |
+
super(ResidualBlock, self).__init__()
|
123 |
+
|
124 |
+
self.conv1 = nn.Conv2d(
|
125 |
+
in_planes,
|
126 |
+
planes,
|
127 |
+
kernel_size=3,
|
128 |
+
padding=1,
|
129 |
+
stride=stride,
|
130 |
+
padding_mode="zeros",
|
131 |
+
)
|
132 |
+
self.conv2 = nn.Conv2d(
|
133 |
+
planes, planes, kernel_size=3, padding=1, padding_mode="zeros"
|
134 |
+
)
|
135 |
+
self.relu = nn.ReLU(inplace=True)
|
136 |
+
|
137 |
+
num_groups = planes // 8
|
138 |
+
|
139 |
+
if norm_fn == "group":
|
140 |
+
self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
|
141 |
+
self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
|
142 |
+
if not stride == 1:
|
143 |
+
self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
|
144 |
+
|
145 |
+
elif norm_fn == "batch":
|
146 |
+
self.norm1 = nn.BatchNorm2d(planes)
|
147 |
+
self.norm2 = nn.BatchNorm2d(planes)
|
148 |
+
if not stride == 1:
|
149 |
+
self.norm3 = nn.BatchNorm2d(planes)
|
150 |
+
|
151 |
+
elif norm_fn == "instance":
|
152 |
+
self.norm1 = nn.InstanceNorm2d(planes)
|
153 |
+
self.norm2 = nn.InstanceNorm2d(planes)
|
154 |
+
if not stride == 1:
|
155 |
+
self.norm3 = nn.InstanceNorm2d(planes)
|
156 |
+
|
157 |
+
elif norm_fn == "none":
|
158 |
+
self.norm1 = nn.Sequential()
|
159 |
+
self.norm2 = nn.Sequential()
|
160 |
+
if not stride == 1:
|
161 |
+
self.norm3 = nn.Sequential()
|
162 |
+
|
163 |
+
if stride == 1:
|
164 |
+
self.downsample = None
|
165 |
+
|
166 |
+
else:
|
167 |
+
self.downsample = nn.Sequential(
|
168 |
+
nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3
|
169 |
+
)
|
170 |
+
|
171 |
+
def forward(self, x):
|
172 |
+
y = x
|
173 |
+
y = self.relu(self.norm1(self.conv1(y)))
|
174 |
+
y = self.relu(self.norm2(self.conv2(y)))
|
175 |
+
|
176 |
+
if self.downsample is not None:
|
177 |
+
x = self.downsample(x)
|
178 |
+
|
179 |
+
return self.relu(x + y)
|
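A small sketch of ResidualBlock used as a strided downsampling unit (illustrative sizes, same import-path assumption as above):

import torch
from models.spatracker.models.core.spatracker.blocks import ResidualBlock

block = ResidualBlock(64, 96, norm_fn="group", stride=2)
x = torch.randn(1, 64, 64, 64)
print(block(x).shape)                # torch.Size([1, 96, 32, 32]): residual path plus 1x1 downsample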
180 |
+
|
181 |
+
|
182 |
+
class BasicEncoder(nn.Module):
|
183 |
+
def __init__(
|
184 |
+
self, input_dim=3, output_dim=128, stride=8, norm_fn="batch", dropout=0.0,
|
185 |
+
Embed3D=False
|
186 |
+
):
|
187 |
+
super(BasicEncoder, self).__init__()
|
188 |
+
self.stride = stride
|
189 |
+
self.norm_fn = norm_fn
|
190 |
+
self.in_planes = 64
|
191 |
+
|
192 |
+
if self.norm_fn == "group":
|
193 |
+
self.norm1 = nn.GroupNorm(num_groups=8, num_channels=self.in_planes)
|
194 |
+
self.norm2 = nn.GroupNorm(num_groups=8, num_channels=output_dim * 2)
|
195 |
+
|
196 |
+
elif self.norm_fn == "batch":
|
197 |
+
self.norm1 = nn.BatchNorm2d(self.in_planes)
|
198 |
+
self.norm2 = nn.BatchNorm2d(output_dim * 2)
|
199 |
+
|
200 |
+
elif self.norm_fn == "instance":
|
201 |
+
self.norm1 = nn.InstanceNorm2d(self.in_planes)
|
202 |
+
self.norm2 = nn.InstanceNorm2d(output_dim * 2)
|
203 |
+
|
204 |
+
elif self.norm_fn == "none":
|
205 |
+
self.norm1 = nn.Sequential()
self.norm2 = nn.Sequential()  # forward() always applies self.norm2, so it must exist for norm_fn="none" too
|
206 |
+
|
207 |
+
self.conv1 = nn.Conv2d(
|
208 |
+
input_dim,
|
209 |
+
self.in_planes,
|
210 |
+
kernel_size=7,
|
211 |
+
stride=2,
|
212 |
+
padding=3,
|
213 |
+
padding_mode="zeros",
|
214 |
+
)
|
215 |
+
self.relu1 = nn.ReLU(inplace=True)
|
216 |
+
|
217 |
+
self.shallow = False
|
218 |
+
if self.shallow:
|
219 |
+
self.layer1 = self._make_layer(64, stride=1)
|
220 |
+
self.layer2 = self._make_layer(96, stride=2)
|
221 |
+
self.layer3 = self._make_layer(128, stride=2)
|
222 |
+
self.conv2 = nn.Conv2d(128 + 96 + 64, output_dim, kernel_size=1)
|
223 |
+
else:
|
224 |
+
if Embed3D:
|
225 |
+
self.conv_fuse = nn.Conv2d(64+63,
|
226 |
+
self.in_planes, kernel_size=3, padding=1)
|
227 |
+
self.layer1 = self._make_layer(64, stride=1)
|
228 |
+
self.layer2 = self._make_layer(96, stride=2)
|
229 |
+
self.layer3 = self._make_layer(128, stride=2)
|
230 |
+
self.layer4 = self._make_layer(128, stride=2)
|
231 |
+
self.conv2 = nn.Conv2d(
|
232 |
+
128 + 128 + 96 + 64,
|
233 |
+
output_dim * 2,
|
234 |
+
kernel_size=3,
|
235 |
+
padding=1,
|
236 |
+
padding_mode="zeros",
|
237 |
+
)
|
238 |
+
self.relu2 = nn.ReLU(inplace=True)
|
239 |
+
self.conv3 = nn.Conv2d(output_dim * 2, output_dim, kernel_size=1)
|
240 |
+
|
241 |
+
self.dropout = None
|
242 |
+
if dropout > 0:
|
243 |
+
self.dropout = nn.Dropout2d(p=dropout)
|
244 |
+
|
245 |
+
for m in self.modules():
|
246 |
+
if isinstance(m, nn.Conv2d):
|
247 |
+
nn.init.kaiming_normal_(m.weight, mode="fan_out",
|
248 |
+
nonlinearity="relu")
|
249 |
+
elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
|
250 |
+
if m.weight is not None:
|
251 |
+
nn.init.constant_(m.weight, 1)
|
252 |
+
if m.bias is not None:
|
253 |
+
nn.init.constant_(m.bias, 0)
|
254 |
+
|
255 |
+
def _make_layer(self, dim, stride=1):
|
256 |
+
layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
|
257 |
+
layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
|
258 |
+
layers = (layer1, layer2)
|
259 |
+
|
260 |
+
self.in_planes = dim
|
261 |
+
return nn.Sequential(*layers)
|
262 |
+
|
263 |
+
def forward(self, x, feat_PE=None):
|
264 |
+
_, _, H, W = x.shape
|
265 |
+
|
266 |
+
x = self.conv1(x)
|
267 |
+
x = self.norm1(x)
|
268 |
+
x = self.relu1(x)
|
269 |
+
|
270 |
+
if self.shallow:
|
271 |
+
a = self.layer1(x)
|
272 |
+
b = self.layer2(a)
|
273 |
+
c = self.layer3(b)
|
274 |
+
a = F.interpolate(
|
275 |
+
a,
|
276 |
+
(H // self.stride, W // self.stride),
|
277 |
+
mode="bilinear",
|
278 |
+
align_corners=True,
|
279 |
+
)
|
280 |
+
b = F.interpolate(
|
281 |
+
b,
|
282 |
+
(H // self.stride, W // self.stride),
|
283 |
+
mode="bilinear",
|
284 |
+
align_corners=True,
|
285 |
+
)
|
286 |
+
c = F.interpolate(
|
287 |
+
c,
|
288 |
+
(H // self.stride, W // self.stride),
|
289 |
+
mode="bilinear",
|
290 |
+
align_corners=True,
|
291 |
+
)
|
292 |
+
x = self.conv2(torch.cat([a, b, c], dim=1))
|
293 |
+
else:
|
294 |
+
if feat_PE is not None:
|
295 |
+
x = self.conv_fuse(torch.cat([x, feat_PE], dim=1))
|
296 |
+
a = self.layer1(x)
|
297 |
+
else:
|
298 |
+
a = self.layer1(x)
|
299 |
+
b = self.layer2(a)
|
300 |
+
c = self.layer3(b)
|
301 |
+
d = self.layer4(c)
|
302 |
+
a = F.interpolate(
|
303 |
+
a,
|
304 |
+
(H // self.stride, W // self.stride),
|
305 |
+
mode="bilinear",
|
306 |
+
align_corners=True,
|
307 |
+
)
|
308 |
+
b = F.interpolate(
|
309 |
+
b,
|
310 |
+
(H // self.stride, W // self.stride),
|
311 |
+
mode="bilinear",
|
312 |
+
align_corners=True,
|
313 |
+
)
|
314 |
+
c = F.interpolate(
|
315 |
+
c,
|
316 |
+
(H // self.stride, W // self.stride),
|
317 |
+
mode="bilinear",
|
318 |
+
align_corners=True,
|
319 |
+
)
|
320 |
+
d = F.interpolate(
|
321 |
+
d,
|
322 |
+
(H // self.stride, W // self.stride),
|
323 |
+
mode="bilinear",
|
324 |
+
align_corners=True,
|
325 |
+
)
|
326 |
+
x = self.conv2(torch.cat([a, b, c, d], dim=1))
|
327 |
+
x = self.norm2(x)
|
328 |
+
x = self.relu2(x)
|
329 |
+
x = self.conv3(x)
|
330 |
+
|
331 |
+
if self.training and self.dropout is not None:
|
332 |
+
x = self.dropout(x)
|
333 |
+
return x
|
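BasicEncoder fuses four pyramid levels back into a single stride-8 feature map; a quick shape check (illustrative sizes):

import torch
from models.spatracker.models.core.spatracker.blocks import BasicEncoder

enc = BasicEncoder(input_dim=3, output_dim=128, stride=8, norm_fn="instance")
frames = torch.randn(2, 3, 256, 256)     # treated as (T, C, H, W)
print(enc(frames).shape)                 # torch.Size([2, 128, 32, 32])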
334 |
+
|
335 |
+
class VitEncoder(nn.Module):
|
336 |
+
def __init__(self, input_dim=4, output_dim=128, stride=4):
|
337 |
+
super(VitEncoder, self).__init__()
|
338 |
+
self.vit = vitEnc(img_size=512,
|
339 |
+
depth=6, num_heads=8, in_chans=input_dim,
|
340 |
+
out_chans=output_dim,embed_dim=384).cuda()
|
341 |
+
self.stride = stride
|
342 |
+
def forward(self, x):
|
343 |
+
T, C, H, W = x.shape
|
344 |
+
x_resize = F.interpolate(x.view(-1, C, H, W), size=(512, 512),
|
345 |
+
mode='bilinear', align_corners=False)
|
346 |
+
x_resize = self.vit(x_resize)
|
347 |
+
x = F.interpolate(x_resize, size=(H//self.stride, W//self.stride),
|
348 |
+
mode='bilinear', align_corners=False)
|
349 |
+
return x
|
350 |
+
|
351 |
+
class DPTEnc(nn.Module):
|
352 |
+
def __init__(self, input_dim=3, output_dim=128, stride=2):
|
353 |
+
super(DPTEnc, self).__init__()
|
354 |
+
self.dpt = DPTEncoder()
|
355 |
+
self.stride = stride
|
356 |
+
def forward(self, x):
|
357 |
+
T, C, H, W = x.shape
|
358 |
+
x = (x-0.5)/0.5
|
359 |
+
x_resize = F.interpolate(x.view(-1, C, H, W), size=(384, 384),
|
360 |
+
mode='bilinear', align_corners=False)
|
361 |
+
x_resize = self.dpt(x_resize)
|
362 |
+
x = F.interpolate(x_resize, size=(H//self.stride, W//self.stride),
|
363 |
+
mode='bilinear', align_corners=False)
|
364 |
+
return x
|
365 |
+
|
366 |
+
# class DPT_DINOv2(nn.Module):
|
367 |
+
# def __init__(self, encoder='vits', features=64, out_channels=[48, 96, 192, 384],
|
368 |
+
# use_bn=True, use_clstoken=False, localhub=True, stride=2, enc_only=True):
|
369 |
+
# super(DPT_DINOv2, self).__init__()
|
370 |
+
# self.stride = stride
|
371 |
+
# self.enc_only = enc_only
|
372 |
+
# assert encoder in ['vits', 'vitb', 'vitl']
|
373 |
+
|
374 |
+
# if localhub:
|
375 |
+
# self.pretrained = torch.hub.load('models/torchhub/facebookresearch_dinov2_main', 'dinov2_{:}14'.format(encoder), source='local', pretrained=False)
|
376 |
+
# else:
|
377 |
+
# self.pretrained = torch.hub.load('facebookresearch/dinov2', 'dinov2_{:}14'.format(encoder))
|
378 |
+
|
379 |
+
# state_dict = torch.load("models/monoD/zoeDepth/ckpts/dinov2_vits14_pretrain.pth")
|
380 |
+
# self.pretrained.load_state_dict(state_dict, strict=True)
|
381 |
+
# self.pretrained.requires_grad_(False)
|
382 |
+
# dim = self.pretrained.blocks[0].attn.qkv.in_features
|
383 |
+
# if enc_only == True:
|
384 |
+
# out_channels=[128, 128, 128, 128]
|
385 |
+
|
386 |
+
# self.DPThead = DPTHeadEnc(1, dim, features, use_bn, out_channels=out_channels, use_clstoken=use_clstoken)
|
387 |
+
|
388 |
+
|
389 |
+
# def forward(self, x):
|
390 |
+
# mean_ = torch.tensor([0.485, 0.456, 0.406],
|
391 |
+
# device=x.device).view(1, 3, 1, 1)
|
392 |
+
# std_ = torch.tensor([0.229, 0.224, 0.225],
|
393 |
+
# device=x.device).view(1, 3, 1, 1)
|
394 |
+
# x = (x+1)/2
|
395 |
+
# x = (x - mean_)/std_
|
396 |
+
# h, w = x.shape[-2:]
|
397 |
+
# h_re, w_re = 560, 560
|
398 |
+
# x_resize = F.interpolate(x, size=(h_re, w_re),
|
399 |
+
# mode='bilinear', align_corners=False)
|
400 |
+
# with torch.no_grad():
|
401 |
+
# features = self.pretrained.get_intermediate_layers(x_resize, 4, return_class_token=True)
|
402 |
+
# patch_h, patch_w = h_re // 14, w_re // 14
|
403 |
+
# feat = self.DPThead(features, patch_h, patch_w, self.enc_only)
|
404 |
+
# feat = F.interpolate(feat, size=(h//self.stride, w//self.stride), mode="bilinear", align_corners=True)
|
405 |
+
|
406 |
+
# return feat
|
407 |
+
|
408 |
+
|
409 |
+
class VGG19(nn.Module):
|
410 |
+
def __init__(self, pretrained=False, amp = False, amp_dtype = torch.float16) -> None:
|
411 |
+
super().__init__()
|
412 |
+
self.layers = nn.ModuleList(tvm.vgg19_bn(pretrained=pretrained).features[:40])
|
413 |
+
self.amp = amp
|
414 |
+
self.amp_dtype = amp_dtype
|
415 |
+
|
416 |
+
def forward(self, x, **kwargs):
|
417 |
+
with torch.autocast("cuda", enabled=self.amp, dtype = self.amp_dtype):
|
418 |
+
feats = {}
|
419 |
+
scale = 1
|
420 |
+
for layer in self.layers:
|
421 |
+
if isinstance(layer, nn.MaxPool2d):
|
422 |
+
feats[scale] = x
|
423 |
+
scale = scale*2
|
424 |
+
x = layer(x)
|
425 |
+
return feats
|
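The VGG19 wrapper returns a dict of feature maps keyed by their downsampling scale, recorded just before each max-pool. A sketch, with pretrained=False so no weights are downloaded:

import torch
from models.spatracker.models.core.spatracker.blocks import VGG19

vgg = VGG19(pretrained=False, amp=False)
feats = vgg(torch.randn(1, 3, 224, 224))
print(sorted(feats.keys()))              # [1, 2, 4, 8]
print(feats[8].shape)                    # torch.Size([1, 512, 28, 28])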
426 |
+
|
427 |
+
class CNNandDinov2(nn.Module):
|
428 |
+
def __init__(self, cnn_kwargs = None, amp = True, amp_dtype = torch.float16):
|
429 |
+
super().__init__()
|
430 |
+
# in case the Internet connection is not stable, please load the DINOv2 locally
|
431 |
+
self.dinov2_vitl14 = torch.hub.load('models/torchhub/facebookresearch_dinov2_main',
|
432 |
+
'dinov2_{:}14'.format("vitl"), source='local', pretrained=False)
|
433 |
+
|
434 |
+
state_dict = torch.load("models/monoD/zoeDepth/ckpts/dinov2_vitl14_pretrain.pth")
|
435 |
+
self.dinov2_vitl14.load_state_dict(state_dict, strict=True)
|
436 |
+
|
437 |
+
|
438 |
+
cnn_kwargs = cnn_kwargs if cnn_kwargs is not None else {}
|
439 |
+
self.cnn = VGG19(**cnn_kwargs)
|
440 |
+
self.amp = amp
|
441 |
+
self.amp_dtype = amp_dtype
|
442 |
+
if self.amp:
|
443 |
+
self.dinov2_vitl14 = self.dinov2_vitl14.to(self.amp_dtype)
|
444 |
+
self.dinov2_vitl14 = [self.dinov2_vitl14] # ugly hack to not show parameters to DDP
|
445 |
+
|
446 |
+
|
447 |
+
def train(self, mode: bool = True):
|
448 |
+
return self.cnn.train(mode)
|
449 |
+
|
450 |
+
def forward(self, x, upsample = False):
|
451 |
+
B,C,H,W = x.shape
|
452 |
+
feature_pyramid = self.cnn(x)
|
453 |
+
|
454 |
+
if not upsample:
|
455 |
+
with torch.no_grad():
|
456 |
+
if self.dinov2_vitl14[0].device != x.device:
|
457 |
+
self.dinov2_vitl14[0] = self.dinov2_vitl14[0].to(x.device).to(self.amp_dtype)
|
458 |
+
dinov2_features_16 = self.dinov2_vitl14[0].forward_features(x.to(self.amp_dtype))
|
459 |
+
features_16 = dinov2_features_16['x_norm_patchtokens'].permute(0,2,1).reshape(B,1024,H//14, W//14)
|
460 |
+
del dinov2_features_16
|
461 |
+
feature_pyramid[16] = features_16
|
462 |
+
return feature_pyramid
|
463 |
+
|
464 |
+
class Dinov2(nn.Module):
|
465 |
+
def __init__(self, amp = True, amp_dtype = torch.float16):
|
466 |
+
super().__init__()
|
467 |
+
# in case the Internet connection is not stable, please load the DINOv2 locally
|
468 |
+
self.dinov2_vitl14 = torch.hub.load('models/torchhub/facebookresearch_dinov2_main',
|
469 |
+
'dinov2_{:}14'.format("vitl"), source='local', pretrained=False)
|
470 |
+
|
471 |
+
state_dict = torch.load("models/monoD/zoeDepth/ckpts/dinov2_vitl14_pretrain.pth")
|
472 |
+
self.dinov2_vitl14.load_state_dict(state_dict, strict=True)
|
473 |
+
|
474 |
+
self.amp = amp
|
475 |
+
self.amp_dtype = amp_dtype
|
476 |
+
if self.amp:
|
477 |
+
self.dinov2_vitl14 = self.dinov2_vitl14.to(self.amp_dtype)
|
478 |
+
|
479 |
+
def forward(self, x, upsample = False):
|
480 |
+
B,C,H,W = x.shape
|
481 |
+
mean_ = torch.tensor([0.485, 0.456, 0.406],
|
482 |
+
device=x.device).view(1, 3, 1, 1)
|
483 |
+
std_ = torch.tensor([0.229, 0.224, 0.225],
|
484 |
+
device=x.device).view(1, 3, 1, 1)
|
485 |
+
x = (x+1)/2
|
486 |
+
x = (x - mean_)/std_
|
487 |
+
h_re, w_re = 560, 560
|
488 |
+
x_resize = F.interpolate(x, size=(h_re, w_re),
|
489 |
+
mode='bilinear', align_corners=True)
|
490 |
+
if not upsample:
|
491 |
+
with torch.no_grad():
|
492 |
+
dinov2_features_16 = self.dinov2_vitl14.forward_features(x_resize.to(self.amp_dtype))
|
493 |
+
features_16 = dinov2_features_16['x_norm_patchtokens'].permute(0,2,1).reshape(B,1024,h_re//14, w_re//14)
|
494 |
+
del dinov2_features_16
|
495 |
+
features_16 = F.interpolate(features_16, size=(H//8, W//8), mode="bilinear", align_corners=True)
|
496 |
+
return features_16
|
497 |
+
|
498 |
+
class AttnBlock(nn.Module):
|
499 |
+
"""
|
500 |
+
A Transformer block with pre-LayerNorm self-attention and an MLP (adapted from the DiT block; the adaLN-Zero conditioning of the original is not used here).
|
501 |
+
"""
|
502 |
+
|
503 |
+
def __init__(self, hidden_size, num_heads, mlp_ratio=4.0,
|
504 |
+
flash=False, **block_kwargs):
|
505 |
+
super().__init__()
|
506 |
+
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
507 |
+
self.flash=flash
|
508 |
+
|
509 |
+
self.attn = Attention(
|
510 |
+
hidden_size, num_heads=num_heads, qkv_bias=True, flash=flash,
|
511 |
+
**block_kwargs
|
512 |
+
)
|
513 |
+
|
514 |
+
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
515 |
+
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
516 |
+
approx_gelu = lambda: nn.GELU(approximate="tanh")
|
517 |
+
self.mlp = Mlp(
|
518 |
+
in_features=hidden_size,
|
519 |
+
hidden_features=mlp_hidden_dim,
|
520 |
+
act_layer=approx_gelu,
|
521 |
+
drop=0,
|
522 |
+
)
|
523 |
+
def forward(self, x):
|
524 |
+
x = x + self.attn(self.norm1(x))
|
525 |
+
x = x + self.mlp(self.norm2(x))
|
526 |
+
return x
|
527 |
+
|
528 |
+
class CrossAttnBlock(nn.Module):
|
529 |
+
def __init__(self, hidden_size, context_dim, num_heads=1, mlp_ratio=4.0,
|
530 |
+
flash=True, **block_kwargs):
|
531 |
+
super().__init__()
|
532 |
+
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
533 |
+
self.norm_context = nn.LayerNorm(hidden_size)
|
534 |
+
|
535 |
+
self.cross_attn = Attention(
|
536 |
+
hidden_size, context_dim=context_dim,
|
537 |
+
num_heads=num_heads, qkv_bias=True, **block_kwargs, flash=flash
|
538 |
+
|
539 |
+
)
|
540 |
+
|
541 |
+
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
542 |
+
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
543 |
+
approx_gelu = lambda: nn.GELU(approximate="tanh")
|
544 |
+
self.mlp = Mlp(
|
545 |
+
in_features=hidden_size,
|
546 |
+
hidden_features=mlp_hidden_dim,
|
547 |
+
act_layer=approx_gelu,
|
548 |
+
drop=0,
|
549 |
+
)
|
550 |
+
|
551 |
+
def forward(self, x, context):
|
552 |
+
with autocast():
|
553 |
+
x = x + self.cross_attn(
|
554 |
+
self.norm1(x), self.norm_context(context)
|
555 |
+
)
|
556 |
+
x = x + self.mlp(self.norm2(x))
|
557 |
+
return x
|
558 |
+
|
559 |
+
|
560 |
+
def bilinear_sampler(img, coords, mode="bilinear", mask=False):
|
561 |
+
"""Wrapper for grid_sample, uses pixel coordinates"""
|
562 |
+
H, W = img.shape[-2:]
|
563 |
+
xgrid, ygrid = coords.split([1, 1], dim=-1)
|
564 |
+
# go to 0,1 then 0,2 then -1,1
|
565 |
+
xgrid = 2 * xgrid / (W - 1) - 1
|
566 |
+
ygrid = 2 * ygrid / (H - 1) - 1
|
567 |
+
|
568 |
+
grid = torch.cat([xgrid, ygrid], dim=-1)
|
569 |
+
img = F.grid_sample(img, grid, align_corners=True)
|
570 |
+
|
571 |
+
if mask:
|
572 |
+
mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
|
573 |
+
return img, mask.float()
|
574 |
+
|
575 |
+
return img
|
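bilinear_sampler expects pixel coordinates laid out like a grid_sample grid, i.e. (B, H_out, W_out, 2) in (x, y) order; a sketch that reads one integer pixel back out:

import torch
from models.spatracker.models.core.spatracker.blocks import bilinear_sampler

feat = torch.randn(4, 8, 32, 32)                          # (B, C, H, W)
xy = torch.tensor([[[[5.0, 7.0]]]]).expand(4, 1, 1, 2)    # pixel x=5, y=7 for every batch item
out = bilinear_sampler(feat, xy)                          # (4, 8, 1, 1)
print(torch.allclose(out[:, :, 0, 0], feat[:, :, 7, 5], atol=1e-5))  # True up to interpolation rounding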
576 |
+
|
577 |
+
|
578 |
+
class CorrBlock:
|
579 |
+
def __init__(self, fmaps, num_levels=4, radius=4, depths_dnG=None):
|
580 |
+
B, S, C, H_prev, W_prev = fmaps.shape
|
581 |
+
self.S, self.C, self.H, self.W = S, C, H_prev, W_prev
|
582 |
+
|
583 |
+
self.num_levels = num_levels
|
584 |
+
self.radius = radius
|
585 |
+
self.fmaps_pyramid = []
|
586 |
+
self.depth_pyramid = []
|
587 |
+
self.fmaps_pyramid.append(fmaps)
|
588 |
+
if depths_dnG is not None:
|
589 |
+
self.depth_pyramid.append(depths_dnG)
|
590 |
+
for i in range(self.num_levels - 1):
|
591 |
+
if depths_dnG is not None:
|
592 |
+
depths_dnG_ = depths_dnG.reshape(B * S, 1, H_prev, W_prev)
|
593 |
+
depths_dnG_ = F.avg_pool2d(depths_dnG_, 2, stride=2)
|
594 |
+
_, _, H, W = depths_dnG_.shape
|
595 |
+
depths_dnG = depths_dnG_.reshape(B, S, 1, H, W)
|
596 |
+
self.depth_pyramid.append(depths_dnG)
|
597 |
+
fmaps_ = fmaps.reshape(B * S, C, H_prev, W_prev)
|
598 |
+
fmaps_ = F.avg_pool2d(fmaps_, 2, stride=2)
|
599 |
+
_, _, H, W = fmaps_.shape
|
600 |
+
fmaps = fmaps_.reshape(B, S, C, H, W)
|
601 |
+
H_prev = H
|
602 |
+
W_prev = W
|
603 |
+
self.fmaps_pyramid.append(fmaps)
|
604 |
+
|
605 |
+
def sample(self, coords):
|
606 |
+
r = self.radius
|
607 |
+
B, S, N, D = coords.shape
|
608 |
+
assert D == 2
|
609 |
+
|
610 |
+
H, W = self.H, self.W
|
611 |
+
out_pyramid = []
|
612 |
+
for i in range(self.num_levels):
|
613 |
+
corrs = self.corrs_pyramid[i] # B, S, N, H, W
|
614 |
+
_, _, _, H, W = corrs.shape
|
615 |
+
|
616 |
+
dx = torch.linspace(-r, r, 2 * r + 1)
|
617 |
+
dy = torch.linspace(-r, r, 2 * r + 1)
|
618 |
+
delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), axis=-1).to(
|
619 |
+
coords.device
|
620 |
+
)
|
621 |
+
centroid_lvl = coords.reshape(B * S * N, 1, 1, 2) / 2 ** i
|
622 |
+
delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2)
|
623 |
+
coords_lvl = centroid_lvl + delta_lvl
|
624 |
+
corrs = bilinear_sampler(corrs.reshape(B * S * N, 1, H, W), coords_lvl)
|
625 |
+
corrs = corrs.view(B, S, N, -1)
|
626 |
+
out_pyramid.append(corrs)
|
627 |
+
|
628 |
+
out = torch.cat(out_pyramid, dim=-1) # B, S, N, LRR*2
|
629 |
+
return out.contiguous().float()
|
630 |
+
|
631 |
+
def corr(self, targets):
|
632 |
+
B, S, N, C = targets.shape
|
633 |
+
assert C == self.C
|
634 |
+
assert S == self.S
|
635 |
+
|
636 |
+
fmap1 = targets
|
637 |
+
|
638 |
+
self.corrs_pyramid = []
|
639 |
+
for fmaps in self.fmaps_pyramid:
|
640 |
+
_, _, _, H, W = fmaps.shape
|
641 |
+
fmap2s = fmaps.view(B, S, C, H * W)
|
642 |
+
corrs = torch.matmul(fmap1, fmap2s)
|
643 |
+
corrs = corrs.view(B, S, N, H, W)
|
644 |
+
corrs = corrs / torch.sqrt(torch.tensor(C).float())
|
645 |
+
self.corrs_pyramid.append(corrs)
|
646 |
+
|
647 |
+
def corr_sample(self, targets, coords, coords_dp=None):
|
648 |
+
B, S, N, C = targets.shape
|
649 |
+
r = self.radius
|
650 |
+
Dim_c = (2*r+1)**2
|
651 |
+
assert C == self.C
|
652 |
+
assert S == self.S
|
653 |
+
|
654 |
+
out_pyramid = []
|
655 |
+
out_pyramid_dp = []
|
656 |
+
for i in range(self.num_levels):
|
657 |
+
dx = torch.linspace(-r, r, 2 * r + 1)
|
658 |
+
dy = torch.linspace(-r, r, 2 * r + 1)
|
659 |
+
delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), axis=-1).to(
|
660 |
+
coords.device
|
661 |
+
)
|
662 |
+
centroid_lvl = coords.reshape(B * S * N, 1, 1, 2) / 2 ** i
|
663 |
+
delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2)
|
664 |
+
coords_lvl = centroid_lvl + delta_lvl
|
665 |
+
fmaps = self.fmaps_pyramid[i]
|
666 |
+
_, _, _, H, W = fmaps.shape
|
667 |
+
fmap2s = fmaps.view(B*S, C, H, W)
|
668 |
+
if len(self.depth_pyramid)>0:
|
669 |
+
depths_dnG_i = self.depth_pyramid[i]
|
670 |
+
depths_dnG_i = depths_dnG_i.view(B*S, 1, H, W)
|
671 |
+
dnG_sample = bilinear_sampler(depths_dnG_i, coords_lvl.view(B*S,1,N*Dim_c,2))
|
672 |
+
dp_corrs = (dnG_sample.view(B*S,N,-1) - coords_dp[0]).abs()/coords_dp[0]
|
673 |
+
out_pyramid_dp.append(dp_corrs)
|
674 |
+
fmap2s_sample = bilinear_sampler(fmap2s, coords_lvl.view(B*S,1,N*Dim_c,2))
|
675 |
+
fmap2s_sample = fmap2s_sample.permute(0, 3, 1, 2) # B*S, N*Dim_c, C, -1
|
676 |
+
corrs = torch.matmul(targets.reshape(B*S*N, 1, -1), fmap2s_sample.reshape(B*S*N, Dim_c, -1).permute(0, 2, 1))
|
677 |
+
corrs = corrs / torch.sqrt(torch.tensor(C).float())
|
678 |
+
corrs = corrs.view(B, S, N, -1)
|
679 |
+
out_pyramid.append(corrs)
|
680 |
+
|
681 |
+
out = torch.cat(out_pyramid, dim=-1) # B, S, N, LRR*2
|
682 |
+
if len(self.depth_pyramid)>0:
|
683 |
+
out_dp = torch.cat(out_pyramid_dp, dim=-1)
|
684 |
+
self.fcorrD = out_dp.contiguous().float()
|
685 |
+
else:
|
686 |
+
self.fcorrD = torch.zeros_like(out).contiguous().float()
|
687 |
+
return out.contiguous().float()
|
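Typical CorrBlock flow: build the feature pyramid once, compute per-track correlation volumes with corr(), then read local cost patches with sample(). A shape sketch (illustrative sizes):

import torch
from models.spatracker.models.core.spatracker.blocks import CorrBlock

B, S, N, C, H, W = 1, 4, 16, 32, 64, 64
fmaps = torch.randn(B, S, C, H, W)                        # per-frame feature maps
targets = torch.randn(B, S, N, C)                         # one query feature per track and frame
coords = torch.rand(B, S, N, 2) * torch.tensor([W - 1.0, H - 1.0])

cb = CorrBlock(fmaps, num_levels=4, radius=3)
cb.corr(targets)                                          # fills cb.corrs_pyramid
print(cb.sample(coords).shape)                            # torch.Size([1, 4, 16, 196]) = 4 levels * 7*7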
688 |
+
|
689 |
+
|
690 |
+
class EUpdateFormer(nn.Module):
|
691 |
+
"""
|
692 |
+
Transformer model that updates track estimates.
|
693 |
+
"""
|
694 |
+
|
695 |
+
def __init__(
|
696 |
+
self,
|
697 |
+
space_depth=12,
|
698 |
+
time_depth=12,
|
699 |
+
input_dim=320,
|
700 |
+
hidden_size=384,
|
701 |
+
num_heads=8,
|
702 |
+
output_dim=130,
|
703 |
+
mlp_ratio=4.0,
|
704 |
+
vq_depth=3,
|
705 |
+
add_space_attn=True,
|
706 |
+
add_time_attn=True,
|
707 |
+
flash=True
|
708 |
+
):
|
709 |
+
super().__init__()
|
710 |
+
self.out_channels = 2
|
711 |
+
self.num_heads = num_heads
|
712 |
+
self.hidden_size = hidden_size
|
713 |
+
self.add_space_attn = add_space_attn
|
714 |
+
self.input_transform = torch.nn.Linear(input_dim, hidden_size, bias=True)
|
715 |
+
self.flash = flash
|
716 |
+
self.flow_head = nn.Sequential(
|
717 |
+
nn.Linear(hidden_size, output_dim, bias=True),
|
718 |
+
nn.ReLU(inplace=True),
|
719 |
+
nn.Linear(output_dim, output_dim, bias=True),
|
720 |
+
nn.ReLU(inplace=True),
|
721 |
+
nn.Linear(output_dim, output_dim, bias=True)
|
722 |
+
)
|
723 |
+
|
724 |
+
cross_attn_kwargs = {
|
725 |
+
"d_model": 384,
|
726 |
+
"nhead": 4,
|
727 |
+
"layer_names": ['self', 'cross'] * 3,
|
728 |
+
}
|
729 |
+
self.gnn = LocalFeatureTransformer(cross_attn_kwargs)
|
730 |
+
|
731 |
+
# Attention Modules in the temporal dimension
|
732 |
+
self.time_blocks = nn.ModuleList(
|
733 |
+
[
|
734 |
+
AttnBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, flash=flash) if add_time_attn else nn.Identity()
|
735 |
+
for _ in range(time_depth)
|
736 |
+
]
|
737 |
+
)
|
738 |
+
|
739 |
+
if add_space_attn:
|
740 |
+
self.space_blocks = nn.ModuleList(
|
741 |
+
[
|
742 |
+
AttnBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, flash=flash)
|
743 |
+
for _ in range(space_depth)
|
744 |
+
]
|
745 |
+
)
|
746 |
+
assert len(self.time_blocks) >= len(self.space_blocks)
|
747 |
+
|
748 |
+
# Placeholder for the rigid transformation
|
749 |
+
self.RigidProj = nn.Linear(self.hidden_size, 128, bias=True)
|
750 |
+
self.Proj = nn.Linear(self.hidden_size, 128, bias=True)
|
751 |
+
|
752 |
+
self.se3_dec = nn.Linear(384, 3, bias=True)
|
753 |
+
self.initialize_weights()
|
754 |
+
|
755 |
+
def initialize_weights(self):
|
756 |
+
def _basic_init(module):
|
757 |
+
if isinstance(module, nn.Linear):
|
758 |
+
torch.nn.init.xavier_uniform_(module.weight)
|
759 |
+
if module.bias is not None:
|
760 |
+
nn.init.constant_(module.bias, 0)
|
761 |
+
|
762 |
+
self.apply(_basic_init)
|
763 |
+
|
764 |
+
def forward(self, input_tensor, se3_feature):
|
765 |
+
""" Updating with Transformer
|
766 |
+
|
767 |
+
Args:
|
768 |
+
input_tensor: B, N, T, C
|
769 |
+
arap_embed: B, N, T, C
|
770 |
+
"""
|
771 |
+
B, N, T, C = input_tensor.shape
|
772 |
+
x = self.input_transform(input_tensor)
|
773 |
+
tokens = x
|
774 |
+
K = 0
|
775 |
+
j = 0
|
776 |
+
for i in range(len(self.time_blocks)):
|
777 |
+
tokens_time = rearrange(tokens, "b n t c -> (b n) t c", b=B, t=T, n=N+K)
|
778 |
+
tokens_time = self.time_blocks[i](tokens_time)
|
779 |
+
tokens = rearrange(tokens_time, "(b n) t c -> b n t c ", b=B, t=T, n=N+K)
|
780 |
+
if self.add_space_attn and (
|
781 |
+
i % (len(self.time_blocks) // len(self.space_blocks)) == 0
|
782 |
+
):
|
783 |
+
tokens_space = rearrange(tokens, "b n t c -> (b t) n c ", b=B, t=T, n=N)
|
784 |
+
tokens_space = self.space_blocks[j](tokens_space)
|
785 |
+
tokens = rearrange(tokens_space, "(b t) n c -> b n t c ", b=B, t=T, n=N)
|
786 |
+
j += 1
|
787 |
+
|
788 |
+
B, N, S, _ = tokens.shape
|
789 |
+
feat0, feat1 = self.gnn(tokens.view(B*N*S, -1)[None,...], se3_feature[None, ...])
|
790 |
+
|
791 |
+
so3 = F.tanh(self.se3_dec(feat0.view(B*N*S, -1)[None,...].view(B, N, S, -1))/100)
|
792 |
+
flow = self.flow_head(feat0.view(B,N,S,-1))
|
793 |
+
|
794 |
+
return flow, _, _, feat1, so3
|
795 |
+
|
796 |
+
|
797 |
+
class FusionFormer(nn.Module):
|
798 |
+
"""
|
799 |
+
Fuse the feature tracks info with the low rank motion tokens
|
800 |
+
"""
|
801 |
+
def __init__(
|
802 |
+
self,
|
803 |
+
d_model=64,
|
804 |
+
nhead=8,
|
805 |
+
attn_iters=4,
|
806 |
+
mlp_ratio=4.0,
|
807 |
+
flash=False,
|
808 |
+
input_dim=35,
|
809 |
+
output_dim=384+3,
|
810 |
+
):
|
811 |
+
super().__init__()
|
812 |
+
self.flash = flash
|
813 |
+
self.in_proj = nn.ModuleList(
|
814 |
+
[
|
815 |
+
nn.Linear(input_dim, d_model)
|
816 |
+
for _ in range(2)
|
817 |
+
]
|
818 |
+
)
|
819 |
+
self.out_proj = nn.Linear(d_model, output_dim, bias=True)
|
820 |
+
self.time_blocks = nn.ModuleList(
|
821 |
+
[
|
822 |
+
CrossAttnBlock(d_model, d_model, nhead, mlp_ratio=mlp_ratio)
|
823 |
+
for _ in range(attn_iters)
|
824 |
+
]
|
825 |
+
)
|
826 |
+
self.space_blocks = nn.ModuleList(
|
827 |
+
[
|
828 |
+
AttnBlock(d_model, nhead, mlp_ratio=mlp_ratio, flash=self.flash)
|
829 |
+
for _ in range(attn_iters)
|
830 |
+
]
|
831 |
+
)
|
832 |
+
|
833 |
+
self.initialize_weights()
|
834 |
+
|
835 |
+
def initialize_weights(self):
|
836 |
+
def _basic_init(module):
|
837 |
+
if isinstance(module, nn.Linear):
|
838 |
+
torch.nn.init.xavier_uniform_(module.weight)
|
839 |
+
if module.bias is not None:
|
840 |
+
nn.init.constant_(module.bias, 0)
|
841 |
+
self.apply(_basic_init)
|
842 |
+
self.out_proj.weight.data.fill_(0)
|
843 |
+
self.out_proj.bias.data.fill_(0)
|
844 |
+
|
845 |
+
def forward(self, x, token_cls):
|
846 |
+
""" Fuse the feature tracks info with the low rank motion tokens
|
847 |
+
|
848 |
+
Args:
|
849 |
+
x: B, S, N, C
|
850 |
+
Traj_whole: B T N C
|
851 |
+
|
852 |
+
"""
|
853 |
+
B, S, N, C = x.shape
|
854 |
+
_, T, _, _ = token_cls.shape
|
855 |
+
x = self.in_proj[0](x)
|
856 |
+
token_cls = self.in_proj[1](token_cls)
|
857 |
+
token_cls = rearrange(token_cls, 'b t n c -> (b n) t c')
|
858 |
+
|
859 |
+
for i in range(len(self.space_blocks)):
|
860 |
+
x = rearrange(x, 'b s n c -> (b n) s c')
|
861 |
+
x = self.time_blocks[i](x, token_cls)
|
862 |
+
x = self.space_blocks[i](x.permute(1,0,2))
|
863 |
+
x = rearrange(x, '(b s) n c -> b s n c', b=B, s=S, n=N)
|
864 |
+
|
865 |
+
x = self.out_proj(x)
|
866 |
+
delta_xyz = x[..., :3]
|
867 |
+
feat_traj = x[..., 3:]
|
868 |
+
return delta_xyz, feat_traj
|
869 |
+
|
870 |
+
class Lie():
|
871 |
+
"""
|
872 |
+
Lie algebra for SO(3) and SE(3) operations in PyTorch
|
873 |
+
"""
|
874 |
+
|
875 |
+
def so3_to_SO3(self,w): # [...,3]
|
876 |
+
wx = self.skew_symmetric(w)
|
877 |
+
theta = w.norm(dim=-1)[...,None,None]
|
878 |
+
I = torch.eye(3,device=w.device,dtype=torch.float32)
|
879 |
+
A = self.taylor_A(theta)
|
880 |
+
B = self.taylor_B(theta)
|
881 |
+
R = I+A*wx+B*wx@wx
|
882 |
+
return R
|
883 |
+
|
884 |
+
def SO3_to_so3(self,R,eps=1e-7): # [...,3,3]
|
885 |
+
trace = R[...,0,0]+R[...,1,1]+R[...,2,2]
|
886 |
+
theta = ((trace-1)/2).clamp(-1+eps,1-eps).acos_()[...,None,None]%np.pi # ln(R) will explode if theta==pi
|
887 |
+
lnR = 1/(2*self.taylor_A(theta)+1e-8)*(R-R.transpose(-2,-1)) # FIXME: wei-chiu finds it weird
|
888 |
+
w0,w1,w2 = lnR[...,2,1],lnR[...,0,2],lnR[...,1,0]
|
889 |
+
w = torch.stack([w0,w1,w2],dim=-1)
|
890 |
+
return w
|
891 |
+
|
892 |
+
def se3_to_SE3(self,wu): # [...,3]
|
893 |
+
w,u = wu.split([3,3],dim=-1)
|
894 |
+
wx = self.skew_symmetric(w)
|
895 |
+
theta = w.norm(dim=-1)[...,None,None]
|
896 |
+
I = torch.eye(3,device=w.device,dtype=torch.float32)
|
897 |
+
A = self.taylor_A(theta)
|
898 |
+
B = self.taylor_B(theta)
|
899 |
+
C = self.taylor_C(theta)
|
900 |
+
R = I+A*wx+B*wx@wx
|
901 |
+
V = I+B*wx+C*wx@wx
|
902 |
+
Rt = torch.cat([R,(V@u[...,None])],dim=-1)
|
903 |
+
return Rt
|
904 |
+
|
905 |
+
def SE3_to_se3(self,Rt,eps=1e-8): # [...,3,4]
|
906 |
+
R,t = Rt.split([3,1],dim=-1)
|
907 |
+
w = self.SO3_to_so3(R)
|
908 |
+
wx = self.skew_symmetric(w)
|
909 |
+
theta = w.norm(dim=-1)[...,None,None]
|
910 |
+
I = torch.eye(3,device=w.device,dtype=torch.float32)
|
911 |
+
A = self.taylor_A(theta)
|
912 |
+
B = self.taylor_B(theta)
|
913 |
+
invV = I-0.5*wx+(1-A/(2*B))/(theta**2+eps)*wx@wx
|
914 |
+
u = (invV@t)[...,0]
|
915 |
+
wu = torch.cat([w,u],dim=-1)
|
916 |
+
return wu
|
917 |
+
|
918 |
+
def skew_symmetric(self,w):
|
919 |
+
w0,w1,w2 = w.unbind(dim=-1)
|
920 |
+
O = torch.zeros_like(w0)
|
921 |
+
wx = torch.stack([torch.stack([O,-w2,w1],dim=-1),
|
922 |
+
torch.stack([w2,O,-w0],dim=-1),
|
923 |
+
torch.stack([-w1,w0,O],dim=-1)],dim=-2)
|
924 |
+
return wx
|
925 |
+
|
926 |
+
def taylor_A(self,x,nth=10):
|
927 |
+
# Taylor expansion of sin(x)/x
|
928 |
+
ans = torch.zeros_like(x)
|
929 |
+
denom = 1.
|
930 |
+
for i in range(nth+1):
|
931 |
+
if i>0: denom *= (2*i)*(2*i+1)
|
932 |
+
ans = ans+(-1)**i*x**(2*i)/denom
|
933 |
+
return ans
|
934 |
+
def taylor_B(self,x,nth=10):
|
935 |
+
# Taylor expansion of (1-cos(x))/x**2
|
936 |
+
ans = torch.zeros_like(x)
|
937 |
+
denom = 1.
|
938 |
+
for i in range(nth+1):
|
939 |
+
denom *= (2*i+1)*(2*i+2)
|
940 |
+
ans = ans+(-1)**i*x**(2*i)/denom
|
941 |
+
return ans
|
942 |
+
def taylor_C(self,x,nth=10):
|
943 |
+
# Taylor expansion of (x-sin(x))/x**3
|
944 |
+
ans = torch.zeros_like(x)
|
945 |
+
denom = 1.
|
946 |
+
for i in range(nth+1):
|
947 |
+
denom *= (2*i+2)*(2*i+3)
|
948 |
+
ans = ans+(-1)**i*x**(2*i)/denom
|
949 |
+
return ans
|
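The Lie helper implements exp/log maps for SO(3) and SE(3); for small motions SE3_to_se3 should invert se3_to_SE3. A round-trip sketch:

import torch
from models.spatracker.models.core.spatracker.blocks import Lie

lie = Lie()
wu = 0.1 * torch.randn(5, 6)              # batched se(3) vectors [w | u]
Rt = lie.se3_to_SE3(wu)                   # (5, 3, 4) rotation | translation
wu_rec = lie.SE3_to_se3(Rt)               # log map back
print(torch.allclose(wu, wu_rec, atol=1e-4))   # expected True for small rotations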
950 |
+
|
951 |
+
|
952 |
+
|
953 |
+
def pix2cam(coords,
|
954 |
+
intr):
|
955 |
+
"""
|
956 |
+
Args:
|
957 |
+
coords: [B, T, N, 3]
|
958 |
+
intr: [B, T, 3, 3]
|
959 |
+
"""
|
960 |
+
coords=coords.detach()
|
961 |
+
B, S, N, _, = coords.shape
|
962 |
+
xy_src = coords.reshape(B*S*N, 3)
|
963 |
+
intr = intr[:, :, None, ...].repeat(1, 1, N, 1, 1).reshape(B*S*N, 3, 3)
|
964 |
+
xy_src = torch.cat([xy_src[..., :2], torch.ones_like(xy_src[..., :1])], dim=-1)
|
965 |
+
xyz_src = (torch.inverse(intr)@xy_src[...,None])[...,0]
|
966 |
+
dp_pred = coords[..., 2]
|
967 |
+
xyz_src_ = (xyz_src*(dp_pred.reshape(B*S*N, 1)))
|
968 |
+
xyz_src_ = xyz_src_.reshape(B, S, N, 3)
|
969 |
+
return xyz_src_
|
970 |
+
|
971 |
+
def cam2pix(coords,
|
972 |
+
intr):
|
973 |
+
"""
|
974 |
+
Args:
|
975 |
+
coords: [B, T, N, 3]
|
976 |
+
intr: [B, T, 3, 3]
|
977 |
+
"""
|
978 |
+
coords=coords.detach()
|
979 |
+
B, S, N, _, = coords.shape
|
980 |
+
xy_src = coords.reshape(B*S*N, 3).clone()
|
981 |
+
intr = intr[:, :, None, ...].repeat(1, 1, N, 1, 1).reshape(B*S*N, 3, 3)
|
982 |
+
xy_src = xy_src / (xy_src[..., 2:]+1e-5)
|
983 |
+
xyz_src = (intr@xy_src[...,None])[...,0]
|
984 |
+
dp_pred = coords[..., 2]
|
985 |
+
xyz_src[...,2] *= dp_pred.reshape(B*S*N)
|
986 |
+
xyz_src = xyz_src.reshape(B, S, N, 3)
|
987 |
+
return xyz_src
|
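pix2cam lifts (u, v, depth) pixels into camera space with the inverse intrinsics, and cam2pix projects camera-space points back to (u, v, depth), so the two should round-trip. A sketch with a synthetic pinhole intrinsic matrix:

import torch
from models.spatracker.models.core.spatracker.blocks import pix2cam, cam2pix

B, T, N = 1, 2, 8
K = torch.tensor([[500., 0., 320.], [0., 500., 240.], [0., 0., 1.]])
intr = K.view(1, 1, 3, 3).expand(B, T, 3, 3)
uvd = torch.cat([torch.rand(B, T, N, 2) * 400 + 100,       # pixel coordinates
                 torch.rand(B, T, N, 1) * 5 + 1], dim=-1)  # depth in (1, 6)

xyz = pix2cam(uvd, intr)                  # (B, T, N, 3) camera-space points
uvd_rec = cam2pix(xyz, intr)              # back to (u, v, depth)
print(torch.allclose(uvd, uvd_rec, atol=1e-2))   # True up to the 1e-5 z-regularizer in cam2pix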
988 |
+
|
989 |
+
def edgeMat(traj3d):
|
990 |
+
"""
|
991 |
+
Args:
|
992 |
+
traj3d: [B, T, N, 3]
|
993 |
+
"""
|
994 |
+
B, T, N, _ = traj3d.shape
|
995 |
+
traj3d = traj3d
|
996 |
+
traj3d = traj3d.view(B, T, N, 3)
|
997 |
+
traj3d = traj3d[..., None, :] - traj3d[..., None, :, :] # B, T, N, N, 3
|
998 |
+
edgeMat = traj3d.norm(dim=-1) # B, T, N, N
|
999 |
+
return edgeMat
|
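edgeMat turns 3-D tracks into a per-frame matrix of pairwise point distances, which is symmetric with a zero diagonal, e.g.:

import torch
from models.spatracker.models.core.spatracker.blocks import edgeMat

traj3d = torch.randn(1, 4, 6, 3)          # (B, T, N, 3)
E = edgeMat(traj3d)                       # (1, 4, 6, 6)
print(torch.allclose(E, E.transpose(-1, -2)), E.diagonal(dim1=-2, dim2=-1).abs().max())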
models/spatracker/models/core/spatracker/dpt/__init__.py
ADDED
File without changes
|
models/spatracker/models/core/spatracker/dpt/base_model.py
ADDED
@@ -0,0 +1,16 @@
+import torch
+
+
+class BaseModel(torch.nn.Module):
+    def load(self, path):
+        """Load model from file.
+
+        Args:
+            path (str): file path
+        """
+        parameters = torch.load(path, map_location=torch.device("cpu"))
+
+        if "optimizer" in parameters:
+            parameters = parameters["model"]
+
+        self.load_state_dict(parameters)
models/spatracker/models/core/spatracker/dpt/blocks.py
ADDED
@@ -0,0 +1,394 @@
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
from models.spatracker.models.core.spatracker.dpt.vit import (
|
5 |
+
_make_pretrained_vitb_rn50_384,
|
6 |
+
_make_pretrained_vitl16_384,
|
7 |
+
_make_pretrained_vitb16_384,
|
8 |
+
forward_vit,
|
9 |
+
_make_pretrained_vit_tiny
|
10 |
+
)
|
11 |
+
|
12 |
+
|
13 |
+
def _make_encoder(
|
14 |
+
backbone,
|
15 |
+
features,
|
16 |
+
use_pretrained,
|
17 |
+
groups=1,
|
18 |
+
expand=False,
|
19 |
+
exportable=True,
|
20 |
+
hooks=None,
|
21 |
+
use_vit_only=False,
|
22 |
+
use_readout="ignore",
|
23 |
+
enable_attention_hooks=False,
|
24 |
+
):
|
25 |
+
if backbone == "vitl16_384":
|
26 |
+
pretrained = _make_pretrained_vitl16_384(
|
27 |
+
use_pretrained,
|
28 |
+
hooks=hooks,
|
29 |
+
use_readout=use_readout,
|
30 |
+
enable_attention_hooks=enable_attention_hooks,
|
31 |
+
)
|
32 |
+
scratch = _make_scratch(
|
33 |
+
[256, 512, 1024, 1024], features, groups=groups, expand=expand
|
34 |
+
) # ViT-L/16 - 85.0% Top1 (backbone)
|
35 |
+
elif backbone == "vitb_rn50_384":
|
36 |
+
pretrained = _make_pretrained_vitb_rn50_384(
|
37 |
+
use_pretrained,
|
38 |
+
hooks=hooks,
|
39 |
+
use_vit_only=use_vit_only,
|
40 |
+
use_readout=use_readout,
|
41 |
+
enable_attention_hooks=enable_attention_hooks,
|
42 |
+
)
|
43 |
+
scratch = _make_scratch(
|
44 |
+
[256, 512, 768, 768], features, groups=groups, expand=expand
|
45 |
+
) # ViT-H/16 - 85.0% Top1 (backbone)
|
46 |
+
elif backbone == "vitb16_384":
|
47 |
+
pretrained = _make_pretrained_vitb16_384(
|
48 |
+
use_pretrained,
|
49 |
+
hooks=hooks,
|
50 |
+
use_readout=use_readout,
|
51 |
+
enable_attention_hooks=enable_attention_hooks,
|
52 |
+
)
|
53 |
+
scratch = _make_scratch(
|
54 |
+
[96, 192, 384, 768], features, groups=groups, expand=expand
|
55 |
+
) # ViT-B/16 - 84.6% Top1 (backbone)
|
56 |
+
elif backbone == "resnext101_wsl":
|
57 |
+
pretrained = _make_pretrained_resnext101_wsl(use_pretrained)
|
58 |
+
scratch = _make_scratch(
|
59 |
+
[256, 512, 1024, 2048], features, groups=groups, expand=expand
|
60 |
+
) # efficientnet_lite3
|
61 |
+
elif backbone == "vit_tiny_r_s16_p8_384":
|
62 |
+
pretrained = _make_pretrained_vit_tiny(
|
63 |
+
use_pretrained,
|
64 |
+
hooks=hooks,
|
65 |
+
use_readout=use_readout,
|
66 |
+
enable_attention_hooks=enable_attention_hooks,
|
67 |
+
)
|
68 |
+
scratch = _make_scratch(
|
69 |
+
[96, 192, 384, 768], features, groups=groups, expand=expand
|
70 |
+
)
|
71 |
+
else:
|
72 |
+
print(f"Backbone '{backbone}' not implemented")
|
73 |
+
assert False
|
74 |
+
|
75 |
+
return pretrained, scratch
|
76 |
+
|
77 |
+
|
78 |
+
def _make_scratch(in_shape, out_shape, groups=1, expand=False):
|
79 |
+
scratch = nn.Module()
|
80 |
+
|
81 |
+
out_shape1 = out_shape
|
82 |
+
out_shape2 = out_shape
|
83 |
+
out_shape3 = out_shape
|
84 |
+
out_shape4 = out_shape
|
85 |
+
if expand == True:
|
86 |
+
out_shape1 = out_shape
|
87 |
+
out_shape2 = out_shape * 2
|
88 |
+
out_shape3 = out_shape * 4
|
89 |
+
out_shape4 = out_shape * 8
|
90 |
+
|
91 |
+
scratch.layer1_rn = nn.Conv2d(
|
92 |
+
in_shape[0],
|
93 |
+
out_shape1,
|
94 |
+
kernel_size=3,
|
95 |
+
stride=1,
|
96 |
+
padding=1,
|
97 |
+
bias=False,
|
98 |
+
groups=groups,
|
99 |
+
)
|
100 |
+
scratch.layer2_rn = nn.Conv2d(
|
101 |
+
in_shape[1],
|
102 |
+
out_shape2,
|
103 |
+
kernel_size=3,
|
104 |
+
stride=1,
|
105 |
+
padding=1,
|
106 |
+
bias=False,
|
107 |
+
groups=groups,
|
108 |
+
)
|
109 |
+
scratch.layer3_rn = nn.Conv2d(
|
110 |
+
in_shape[2],
|
111 |
+
out_shape3,
|
112 |
+
kernel_size=3,
|
113 |
+
stride=1,
|
114 |
+
padding=1,
|
115 |
+
bias=False,
|
116 |
+
groups=groups,
|
117 |
+
)
|
118 |
+
scratch.layer4_rn = nn.Conv2d(
|
119 |
+
in_shape[3],
|
120 |
+
out_shape4,
|
121 |
+
kernel_size=3,
|
122 |
+
stride=1,
|
123 |
+
padding=1,
|
124 |
+
bias=False,
|
125 |
+
groups=groups,
|
126 |
+
)
|
127 |
+
|
128 |
+
return scratch
|
129 |
+
|
130 |
+
|
131 |
+
def _make_resnet_backbone(resnet):
|
132 |
+
pretrained = nn.Module()
|
133 |
+
pretrained.layer1 = nn.Sequential(
|
134 |
+
resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1
|
135 |
+
)
|
136 |
+
|
137 |
+
pretrained.layer2 = resnet.layer2
|
138 |
+
pretrained.layer3 = resnet.layer3
|
139 |
+
pretrained.layer4 = resnet.layer4
|
140 |
+
|
141 |
+
return pretrained
|
142 |
+
|
143 |
+
|
144 |
+
def _make_pretrained_resnext101_wsl(use_pretrained):
|
145 |
+
resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl")
|
146 |
+
return _make_resnet_backbone(resnet)
|
147 |
+
|
148 |
+
|
149 |
+
class Interpolate(nn.Module):
|
150 |
+
"""Interpolation module."""
|
151 |
+
|
152 |
+
def __init__(self, scale_factor, mode, align_corners=False):
|
153 |
+
"""Init.
|
154 |
+
|
155 |
+
Args:
|
156 |
+
scale_factor (float): scaling
|
157 |
+
mode (str): interpolation mode
|
158 |
+
"""
|
159 |
+
super(Interpolate, self).__init__()
|
160 |
+
|
161 |
+
self.interp = nn.functional.interpolate
|
162 |
+
self.scale_factor = scale_factor
|
163 |
+
self.mode = mode
|
164 |
+
self.align_corners = align_corners
|
165 |
+
|
166 |
+
def forward(self, x):
|
167 |
+
"""Forward pass.
|
168 |
+
|
169 |
+
Args:
|
170 |
+
x (tensor): input
|
171 |
+
|
172 |
+
Returns:
|
173 |
+
tensor: interpolated data
|
174 |
+
"""
|
175 |
+
|
176 |
+
x = self.interp(
|
177 |
+
x,
|
178 |
+
scale_factor=self.scale_factor,
|
179 |
+
mode=self.mode,
|
180 |
+
align_corners=self.align_corners,
|
181 |
+
)
|
182 |
+
|
183 |
+
return x
|
184 |
+
|
185 |
+
|
186 |
+
class ResidualConvUnit(nn.Module):
|
187 |
+
"""Residual convolution module."""
|
188 |
+
|
189 |
+
def __init__(self, features):
|
190 |
+
"""Init.
|
191 |
+
|
192 |
+
Args:
|
193 |
+
features (int): number of features
|
194 |
+
"""
|
195 |
+
super().__init__()
|
196 |
+
|
197 |
+
self.conv1 = nn.Conv2d(
|
198 |
+
features, features, kernel_size=3, stride=1, padding=1, bias=True
|
199 |
+
)
|
200 |
+
|
201 |
+
self.conv2 = nn.Conv2d(
|
202 |
+
features, features, kernel_size=3, stride=1, padding=1, bias=True
|
203 |
+
)
|
204 |
+
|
205 |
+
self.relu = nn.ReLU(inplace=True)
|
206 |
+
|
207 |
+
def forward(self, x):
|
208 |
+
"""Forward pass.
|
209 |
+
|
210 |
+
Args:
|
211 |
+
x (tensor): input
|
212 |
+
|
213 |
+
Returns:
|
214 |
+
tensor: output
|
215 |
+
"""
|
216 |
+
out = self.relu(x)
|
217 |
+
out = self.conv1(out)
|
218 |
+
out = self.relu(out)
|
219 |
+
out = self.conv2(out)
|
220 |
+
|
221 |
+
return out + x
|
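ResidualConvUnit is a shape-preserving conv-conv residual refinement; a quick check (import path as listed for this file in the commit):

import torch
from models.spatracker.models.core.spatracker.dpt.blocks import ResidualConvUnit

rcu = ResidualConvUnit(32)
x = torch.randn(1, 32, 40, 40)
print(rcu(x).shape)                       # torch.Size([1, 32, 40, 40])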
222 |
+
|
223 |
+
|
224 |
+
class FeatureFusionBlock(nn.Module):
|
225 |
+
"""Feature fusion block."""
|
226 |
+
|
227 |
+
def __init__(self, features):
|
228 |
+
"""Init.
|
229 |
+
|
230 |
+
Args:
|
231 |
+
features (int): number of features
|
232 |
+
"""
|
233 |
+
super(FeatureFusionBlock, self).__init__()
|
234 |
+
|
235 |
+
self.resConfUnit1 = ResidualConvUnit(features)
|
236 |
+
self.resConfUnit2 = ResidualConvUnit(features)
|
237 |
+
|
238 |
+
def forward(self, *xs):
|
239 |
+
"""Forward pass.
|
240 |
+
|
241 |
+
Returns:
|
242 |
+
tensor: output
|
243 |
+
"""
|
244 |
+
output = xs[0]
|
245 |
+
|
246 |
+
if len(xs) == 2:
|
247 |
+
output += self.resConfUnit1(xs[1])
|
248 |
+
|
249 |
+
output = self.resConfUnit2(output)
|
250 |
+
|
251 |
+
output = nn.functional.interpolate(
|
252 |
+
output, scale_factor=2, mode="bilinear", align_corners=True
|
253 |
+
)
|
254 |
+
|
255 |
+
return output
|
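FeatureFusionBlock adds an optional refined skip input to the decoder path and upsamples the result by 2x; a sketch:

import torch
from models.spatracker.models.core.spatracker.dpt.blocks import FeatureFusionBlock

ffb = FeatureFusionBlock(32)
deeper = torch.randn(1, 32, 20, 20)       # output of the deeper refinenet stage
skip = torch.randn(1, 32, 20, 20)         # encoder skip connection at the same resolution
print(ffb(deeper, skip).shape)            # torch.Size([1, 32, 40, 40])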
256 |
+
|
257 |
+
|
258 |
+
class ResidualConvUnit_custom(nn.Module):
|
259 |
+
"""Residual convolution module."""
|
260 |
+
|
261 |
+
def __init__(self, features, activation, bn):
|
262 |
+
"""Init.
|
263 |
+
|
264 |
+
Args:
|
265 |
+
features (int): number of features
|
266 |
+
"""
|
267 |
+
super().__init__()
|
268 |
+
|
269 |
+
self.bn = bn
|
270 |
+
|
271 |
+
self.groups = 1
|
272 |
+
|
273 |
+
self.conv1 = nn.Conv2d(
|
274 |
+
features,
|
275 |
+
features,
|
276 |
+
kernel_size=3,
|
277 |
+
stride=1,
|
278 |
+
padding=1,
|
279 |
+
bias=not self.bn,
|
280 |
+
groups=self.groups,
|
281 |
+
)
|
282 |
+
|
283 |
+
self.conv2 = nn.Conv2d(
|
284 |
+
features,
|
285 |
+
features,
|
286 |
+
kernel_size=3,
|
287 |
+
stride=1,
|
288 |
+
padding=1,
|
289 |
+
bias=not self.bn,
|
290 |
+
groups=self.groups,
|
291 |
+
)
|
292 |
+
|
293 |
+
if self.bn == True:
|
294 |
+
self.bn1 = nn.BatchNorm2d(features)
|
295 |
+
self.bn2 = nn.BatchNorm2d(features)
|
296 |
+
|
297 |
+
self.activation = activation
|
298 |
+
|
299 |
+
self.skip_add = nn.quantized.FloatFunctional()
|
300 |
+
|
301 |
+
def forward(self, x):
|
302 |
+
"""Forward pass.
|
303 |
+
|
304 |
+
Args:
|
305 |
+
x (tensor): input
|
306 |
+
|
307 |
+
Returns:
|
308 |
+
tensor: output
|
309 |
+
"""
|
310 |
+
|
311 |
+
out = self.activation(x)
|
312 |
+
out = self.conv1(out)
|
313 |
+
if self.bn == True:
|
314 |
+
out = self.bn1(out)
|
315 |
+
|
316 |
+
out = self.activation(out)
|
317 |
+
out = self.conv2(out)
|
318 |
+
if self.bn == True:
|
319 |
+
out = self.bn2(out)
|
320 |
+
|
321 |
+
if self.groups > 1:
|
322 |
+
out = self.conv_merge(out)
|
323 |
+
|
324 |
+
return self.skip_add.add(out, x)
|
325 |
+
|
326 |
+
# return out + x
|
327 |
+
|
328 |
+
|
329 |
+
class FeatureFusionBlock_custom(nn.Module):
|
330 |
+
"""Feature fusion block."""
|
331 |
+
|
332 |
+
def __init__(
|
333 |
+
self,
|
334 |
+
features,
|
335 |
+
activation,
|
336 |
+
deconv=False,
|
337 |
+
bn=False,
|
338 |
+
expand=False,
|
339 |
+
align_corners=True,
|
340 |
+
):
|
341 |
+
"""Init.
|
342 |
+
|
343 |
+
Args:
|
344 |
+
features (int): number of features
|
345 |
+
"""
|
346 |
+
super(FeatureFusionBlock_custom, self).__init__()
|
347 |
+
|
348 |
+
self.deconv = deconv
|
349 |
+
self.align_corners = align_corners
|
350 |
+
|
351 |
+
self.groups = 1
|
352 |
+
|
353 |
+
self.expand = expand
|
354 |
+
out_features = features
|
355 |
+
if self.expand == True:
|
356 |
+
out_features = features // 2
|
357 |
+
|
358 |
+
self.out_conv = nn.Conv2d(
|
359 |
+
features,
|
360 |
+
out_features,
|
361 |
+
kernel_size=1,
|
362 |
+
stride=1,
|
363 |
+
padding=0,
|
364 |
+
bias=True,
|
365 |
+
groups=1,
|
366 |
+
)
|
367 |
+
|
368 |
+
self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn)
|
369 |
+
self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn)
|
370 |
+
|
371 |
+
self.skip_add = nn.quantized.FloatFunctional()
|
372 |
+
|
373 |
+
def forward(self, *xs):
|
374 |
+
"""Forward pass.
|
375 |
+
|
376 |
+
Returns:
|
377 |
+
tensor: output
|
378 |
+
"""
|
379 |
+
output = xs[0]
|
380 |
+
|
381 |
+
if len(xs) == 2:
|
382 |
+
res = self.resConfUnit1(xs[1])
|
383 |
+
output = self.skip_add.add(output, res)
|
384 |
+
# output += res
|
385 |
+
|
386 |
+
output = self.resConfUnit2(output)
|
387 |
+
|
388 |
+
output = nn.functional.interpolate(
|
389 |
+
output, scale_factor=2, mode="bilinear", align_corners=self.align_corners
|
390 |
+
)
|
391 |
+
|
392 |
+
output = self.out_conv(output)
|
393 |
+
|
394 |
+
return output
|
models/spatracker/models/core/spatracker/dpt/midas_net.py
ADDED
@@ -0,0 +1,77 @@
1 |
+
"""MidashNet: Network for monocular depth estimation trained by mixing several datasets.
|
2 |
+
This file contains code that is adapted from
|
3 |
+
https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
|
4 |
+
"""
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
|
8 |
+
from models.spatracker.models.core.spatracker.dpt.base_model import BaseModel
|
9 |
+
from models.spatracker.models.core.spatracker.dpt.blocks import FeatureFusionBlock, Interpolate, _make_encoder
|
10 |
+
|
11 |
+
|
12 |
+
class MidasNet_large(BaseModel):
|
13 |
+
"""Network for monocular depth estimation."""
|
14 |
+
|
15 |
+
def __init__(self, path=None, features=256, non_negative=True):
|
16 |
+
"""Init.
|
17 |
+
|
18 |
+
Args:
|
19 |
+
path (str, optional): Path to saved model. Defaults to None.
|
20 |
+
features (int, optional): Number of features. Defaults to 256.
|
21 |
+
backbone (str, optional): Backbone network for encoder. Defaults to resnet50
|
22 |
+
"""
|
23 |
+
print("Loading weights: ", path)
|
24 |
+
|
25 |
+
super(MidasNet_large, self).__init__()
|
26 |
+
|
27 |
+
use_pretrained = False if path is None else True
|
28 |
+
|
29 |
+
self.pretrained, self.scratch = _make_encoder(
|
30 |
+
backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained
|
31 |
+
)
|
32 |
+
|
33 |
+
self.scratch.refinenet4 = FeatureFusionBlock(features)
|
34 |
+
self.scratch.refinenet3 = FeatureFusionBlock(features)
|
35 |
+
self.scratch.refinenet2 = FeatureFusionBlock(features)
|
36 |
+
self.scratch.refinenet1 = FeatureFusionBlock(features)
|
37 |
+
|
38 |
+
self.scratch.output_conv = nn.Sequential(
|
39 |
+
nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1),
|
40 |
+
Interpolate(scale_factor=2, mode="bilinear"),
|
41 |
+
nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1),
|
42 |
+
nn.ReLU(True),
|
43 |
+
nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
|
44 |
+
nn.ReLU(True) if non_negative else nn.Identity(),
|
45 |
+
)
|
46 |
+
|
47 |
+
if path:
|
48 |
+
self.load(path)
|
49 |
+
|
50 |
+
def forward(self, x):
|
51 |
+
"""Forward pass.
|
52 |
+
|
53 |
+
Args:
|
54 |
+
x (tensor): input data (image)
|
55 |
+
|
56 |
+
Returns:
|
57 |
+
tensor: depth
|
58 |
+
"""
|
59 |
+
|
60 |
+
layer_1 = self.pretrained.layer1(x)
|
61 |
+
layer_2 = self.pretrained.layer2(layer_1)
|
62 |
+
layer_3 = self.pretrained.layer3(layer_2)
|
63 |
+
layer_4 = self.pretrained.layer4(layer_3)
|
64 |
+
|
65 |
+
layer_1_rn = self.scratch.layer1_rn(layer_1)
|
66 |
+
layer_2_rn = self.scratch.layer2_rn(layer_2)
|
67 |
+
layer_3_rn = self.scratch.layer3_rn(layer_3)
|
68 |
+
layer_4_rn = self.scratch.layer4_rn(layer_4)
|
69 |
+
|
70 |
+
path_4 = self.scratch.refinenet4(layer_4_rn)
|
71 |
+
path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
|
72 |
+
path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
|
73 |
+
path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
|
74 |
+
|
75 |
+
out = self.scratch.output_conv(path_1)
|
76 |
+
|
77 |
+
return torch.squeeze(out, dim=1)
|
models/spatracker/models/core/spatracker/dpt/models.py
ADDED
@@ -0,0 +1,231 @@
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
from models.spatracker.models.core.spatracker.dpt.base_model import BaseModel
|
6 |
+
from models.spatracker.models.core.spatracker.dpt.blocks import (
|
7 |
+
FeatureFusionBlock,
|
8 |
+
FeatureFusionBlock_custom,
|
9 |
+
Interpolate,
|
10 |
+
_make_encoder,
|
11 |
+
forward_vit,
|
12 |
+
)
|
13 |
+
|
14 |
+
|
15 |
+
def _make_fusion_block(features, use_bn):
|
16 |
+
return FeatureFusionBlock_custom(
|
17 |
+
features,
|
18 |
+
nn.ReLU(False),
|
19 |
+
deconv=False,
|
20 |
+
bn=use_bn,
|
21 |
+
expand=False,
|
22 |
+
align_corners=True,
|
23 |
+
)
|
24 |
+
|
25 |
+
|
26 |
+
class DPT(BaseModel):
|
27 |
+
def __init__(
|
28 |
+
self,
|
29 |
+
head,
|
30 |
+
features=256,
|
31 |
+
backbone="vitb_rn50_384",
|
32 |
+
readout="project",
|
33 |
+
channels_last=False,
|
34 |
+
use_bn=True,
|
35 |
+
enable_attention_hooks=False,
|
36 |
+
):
|
37 |
+
|
38 |
+
super(DPT, self).__init__()
|
39 |
+
|
40 |
+
self.channels_last = channels_last
|
41 |
+
|
42 |
+
hooks = {
|
43 |
+
"vitb_rn50_384": [0, 1, 8, 11],
|
44 |
+
"vitb16_384": [2, 5, 8, 11],
|
45 |
+
"vitl16_384": [5, 11, 17, 23],
|
46 |
+
"vit_tiny_r_s16_p8_384": [0, 1, 2, 3],
|
47 |
+
}
|
48 |
+
|
49 |
+
# Instantiate backbone and reassemble blocks
|
50 |
+
self.pretrained, self.scratch = _make_encoder(
|
51 |
+
backbone,
|
52 |
+
features,
|
53 |
+
False, # use_pretrained: set to True to initialize the backbone from ImageNet weights
|
54 |
+
groups=1,
|
55 |
+
expand=False,
|
56 |
+
exportable=False,
|
57 |
+
hooks=hooks[backbone],
|
58 |
+
use_readout=readout,
|
59 |
+
enable_attention_hooks=enable_attention_hooks,
|
60 |
+
)
|
61 |
+
|
62 |
+
self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
|
63 |
+
self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
|
64 |
+
self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
|
65 |
+
self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
|
66 |
+
|
67 |
+
self.scratch.output_conv = head
|
68 |
+
|
69 |
+
self.proj_out = nn.Sequential(
|
70 |
+
nn.Conv2d(
|
71 |
+
256+512+384+384,
|
72 |
+
256,
|
73 |
+
kernel_size=3,
|
74 |
+
padding=1,
|
75 |
+
padding_mode="zeros",
|
76 |
+
),
|
77 |
+
nn.BatchNorm2d(128 * 2),
|
78 |
+
nn.ReLU(True),
|
79 |
+
nn.Conv2d(
|
80 |
+
128 * 2,
|
81 |
+
128,
|
82 |
+
kernel_size=3,
|
83 |
+
padding=1,
|
84 |
+
padding_mode="zeros",
|
85 |
+
)
|
86 |
+
)
|
87 |
+
|
88 |
+
|
89 |
+
def forward(self, x, only_enc=False):
|
90 |
+
if self.channels_last == True:
|
91 |
+
x.contiguous(memory_format=torch.channels_last)
|
92 |
+
if only_enc:
|
93 |
+
layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x)
|
94 |
+
a = (layer_1)
|
95 |
+
b = (
|
96 |
+
F.interpolate(
|
97 |
+
layer_2,
|
98 |
+
scale_factor=2,
|
99 |
+
mode="bilinear",
|
100 |
+
align_corners=True,
|
101 |
+
)
|
102 |
+
)
|
103 |
+
c = (
|
104 |
+
F.interpolate(
|
105 |
+
layer_3,
|
106 |
+
scale_factor=8,
|
107 |
+
mode="bilinear",
|
108 |
+
align_corners=True,
|
109 |
+
)
|
110 |
+
)
|
111 |
+
d = (
|
112 |
+
F.interpolate(
|
113 |
+
layer_4,
|
114 |
+
scale_factor=16,
|
115 |
+
mode="bilinear",
|
116 |
+
align_corners=True,
|
117 |
+
)
|
118 |
+
)
|
119 |
+
x = self.proj_out(torch.cat([a, b, c, d], dim=1))
|
120 |
+
return x
|
121 |
+
else:
|
122 |
+
layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x)
|
123 |
+
|
124 |
+
layer_1_rn = self.scratch.layer1_rn(layer_1)
|
125 |
+
layer_2_rn = self.scratch.layer2_rn(layer_2)
|
126 |
+
layer_3_rn = self.scratch.layer3_rn(layer_3)
|
127 |
+
layer_4_rn = self.scratch.layer4_rn(layer_4)
|
128 |
+
|
129 |
+
path_4 = self.scratch.refinenet4(layer_4_rn)
|
130 |
+
path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
|
131 |
+
path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
|
132 |
+
path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
|
133 |
+
|
134 |
+
_,_,H_out,W_out = path_1.size()
|
135 |
+
path_2_up = F.interpolate(path_2, size=(H_out,W_out), mode="bilinear", align_corners=True)
|
136 |
+
path_3_up = F.interpolate(path_3, size=(H_out,W_out), mode="bilinear", align_corners=True)
|
137 |
+
path_4_up = F.interpolate(path_4, size=(H_out,W_out), mode="bilinear", align_corners=True)
|
138 |
+
|
139 |
+
out = self.scratch.output_conv(path_1+path_2_up+path_3_up+path_4_up)
|
140 |
+
|
141 |
+
return out
|
142 |
+
|
143 |
+
|
144 |
+
class DPTDepthModel(DPT):
|
145 |
+
def __init__(
|
146 |
+
self, path=None, non_negative=True, scale=1.0, shift=0.0, invert=False, **kwargs
|
147 |
+
):
|
148 |
+
features = kwargs["features"] if "features" in kwargs else 256
|
149 |
+
|
150 |
+
self.scale = scale
|
151 |
+
self.shift = shift
|
152 |
+
self.invert = invert
|
153 |
+
|
154 |
+
head = nn.Sequential(
|
155 |
+
nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
|
156 |
+
Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
|
157 |
+
nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
|
158 |
+
nn.ReLU(True),
|
159 |
+
nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
|
160 |
+
nn.ReLU(True) if non_negative else nn.Identity(),
|
161 |
+
nn.Identity(),
|
162 |
+
)
|
163 |
+
|
164 |
+
super().__init__(head, **kwargs)
|
165 |
+
|
166 |
+
if path is not None:
|
167 |
+
self.load(path)
|
168 |
+
|
169 |
+
def forward(self, x):
|
170 |
+
inv_depth = super().forward(x).squeeze(dim=1)
|
171 |
+
|
172 |
+
if self.invert:
|
173 |
+
depth = self.scale * inv_depth + self.shift
|
174 |
+
depth[depth < 1e-8] = 1e-8
|
175 |
+
depth = 1.0 / depth
|
176 |
+
return depth
|
177 |
+
else:
|
178 |
+
return inv_depth
|
179 |
+
|
180 |
+
class DPTEncoder(DPT):
|
181 |
+
def __init__(
|
182 |
+
self, path=None, non_negative=True, scale=1.0, shift=0.0, invert=False, **kwargs
|
183 |
+
):
|
184 |
+
features = kwargs["features"] if "features" in kwargs else 256
|
185 |
+
|
186 |
+
self.scale = scale
|
187 |
+
self.shift = shift
|
188 |
+
|
189 |
+
head = nn.Sequential(
|
190 |
+
nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1),
|
191 |
+
)
|
192 |
+
|
193 |
+
super().__init__(head, **kwargs)
|
194 |
+
|
195 |
+
if path is not None:
|
196 |
+
self.load(path)
|
197 |
+
|
198 |
+
def forward(self, x):
|
199 |
+
features = super().forward(x, only_enc=True).squeeze(dim=1)
|
200 |
+
|
201 |
+
return features
|
202 |
+
|
203 |
+
|
204 |
+
class DPTSegmentationModel(DPT):
|
205 |
+
def __init__(self, num_classes, path=None, **kwargs):
|
206 |
+
|
207 |
+
features = kwargs["features"] if "features" in kwargs else 256
|
208 |
+
|
209 |
+
kwargs["use_bn"] = True
|
210 |
+
|
211 |
+
head = nn.Sequential(
|
212 |
+
nn.Conv2d(features, features, kernel_size=3, padding=1, bias=False),
|
213 |
+
nn.BatchNorm2d(features),
|
214 |
+
nn.ReLU(True),
|
215 |
+
nn.Dropout(0.1, False),
|
216 |
+
nn.Conv2d(features, num_classes, kernel_size=1),
|
217 |
+
Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
|
218 |
+
)
|
219 |
+
|
220 |
+
super().__init__(head, **kwargs)
|
221 |
+
|
222 |
+
self.auxlayer = nn.Sequential(
|
223 |
+
nn.Conv2d(features, features, kernel_size=3, padding=1, bias=False),
|
224 |
+
nn.BatchNorm2d(features),
|
225 |
+
nn.ReLU(True),
|
226 |
+
nn.Dropout(0.1, False),
|
227 |
+
nn.Conv2d(features, num_classes, kernel_size=1),
|
228 |
+
)
|
229 |
+
|
230 |
+
if path is not None:
|
231 |
+
self.load(path)
|
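DPTDepthModel.forward returns inverse depth; with invert=True it rescales the prediction and takes the reciprocal. A minimal numeric sketch of that conversion, with hypothetical scale and shift values and a random tensor standing in for the network output:

import torch

inv_depth = torch.rand(1, 384, 384)    # stand-in for the head output
scale, shift = 0.5, 0.1                # hypothetical calibration constants
depth = scale * inv_depth + shift
depth = depth.clamp(min=1e-8)          # same guard as depth[depth < 1e-8] = 1e-8 above
depth = 1.0 / depth                    # reciprocal of inverse depth gives depth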
models/spatracker/models/core/spatracker/dpt/transforms.py
ADDED
@@ -0,0 +1,231 @@
import numpy as np
import cv2
import math


def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA):
    """Resize the sample to ensure the given size. Keeps aspect ratio.

    Args:
        sample (dict): sample
        size (tuple): image size

    Returns:
        tuple: new size
    """
    shape = list(sample["disparity"].shape)

    if shape[0] >= size[0] and shape[1] >= size[1]:
        return sample

    scale = [0, 0]
    scale[0] = size[0] / shape[0]
    scale[1] = size[1] / shape[1]

    scale = max(scale)

    shape[0] = math.ceil(scale * shape[0])
    shape[1] = math.ceil(scale * shape[1])

    # resize
    sample["image"] = cv2.resize(
        sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method
    )

    sample["disparity"] = cv2.resize(
        sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST
    )
    sample["mask"] = cv2.resize(
        sample["mask"].astype(np.float32),
        tuple(shape[::-1]),
        interpolation=cv2.INTER_NEAREST,
    )
    sample["mask"] = sample["mask"].astype(bool)

    return tuple(shape)


class Resize(object):
    """Resize sample to given size (width, height)."""

    def __init__(
        self,
        width,
        height,
        resize_target=True,
        keep_aspect_ratio=False,
        ensure_multiple_of=1,
        resize_method="lower_bound",
        image_interpolation_method=cv2.INTER_AREA,
    ):
        """Init.

        Args:
            width (int): desired output width
            height (int): desired output height
            resize_target (bool, optional):
                True: Resize the full sample (image, mask, target).
                False: Resize image only.
                Defaults to True.
            keep_aspect_ratio (bool, optional):
                True: Keep the aspect ratio of the input sample.
                Output sample might not have the given width and height, and
                resize behaviour depends on the parameter 'resize_method'.
                Defaults to False.
            ensure_multiple_of (int, optional):
                Output width and height is constrained to be multiple of this parameter.
                Defaults to 1.
            resize_method (str, optional):
                "lower_bound": Output will be at least as large as the given size.
                "upper_bound": Output will be at most as large as the given size. (Output size might be smaller than given size.)
                "minimal": Scale as little as possible. (Output size might be smaller than given size.)
                Defaults to "lower_bound".
        """
        self.__width = width
        self.__height = height

        self.__resize_target = resize_target
        self.__keep_aspect_ratio = keep_aspect_ratio
        self.__multiple_of = ensure_multiple_of
        self.__resize_method = resize_method
        self.__image_interpolation_method = image_interpolation_method

    def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
        y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)

        if max_val is not None and y > max_val:
            y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)

        if y < min_val:
            y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)

        return y

    def get_size(self, width, height):
        # determine new height and width
        scale_height = self.__height / height
        scale_width = self.__width / width

        if self.__keep_aspect_ratio:
            if self.__resize_method == "lower_bound":
                # scale such that output size is lower bound
                if scale_width > scale_height:
                    # fit width
                    scale_height = scale_width
                else:
                    # fit height
                    scale_width = scale_height
            elif self.__resize_method == "upper_bound":
                # scale such that output size is upper bound
                if scale_width < scale_height:
                    # fit width
                    scale_height = scale_width
                else:
                    # fit height
                    scale_width = scale_height
            elif self.__resize_method == "minimal":
                # scale as little as possible
                if abs(1 - scale_width) < abs(1 - scale_height):
                    # fit width
                    scale_height = scale_width
                else:
                    # fit height
                    scale_width = scale_height
            else:
                raise ValueError(
                    f"resize_method {self.__resize_method} not implemented"
                )

        if self.__resize_method == "lower_bound":
            new_height = self.constrain_to_multiple_of(
                scale_height * height, min_val=self.__height
            )
            new_width = self.constrain_to_multiple_of(
                scale_width * width, min_val=self.__width
            )
        elif self.__resize_method == "upper_bound":
            new_height = self.constrain_to_multiple_of(
                scale_height * height, max_val=self.__height
            )
            new_width = self.constrain_to_multiple_of(
                scale_width * width, max_val=self.__width
            )
        elif self.__resize_method == "minimal":
            new_height = self.constrain_to_multiple_of(scale_height * height)
            new_width = self.constrain_to_multiple_of(scale_width * width)
        else:
            raise ValueError(f"resize_method {self.__resize_method} not implemented")

        return (new_width, new_height)

    def __call__(self, sample):
        width, height = self.get_size(
            sample["image"].shape[1], sample["image"].shape[0]
        )

        # resize sample
        sample["image"] = cv2.resize(
            sample["image"],
            (width, height),
            interpolation=self.__image_interpolation_method,
        )

        if self.__resize_target:
            if "disparity" in sample:
                sample["disparity"] = cv2.resize(
                    sample["disparity"],
                    (width, height),
                    interpolation=cv2.INTER_NEAREST,
                )

            if "depth" in sample:
                sample["depth"] = cv2.resize(
                    sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST
                )

            sample["mask"] = cv2.resize(
                sample["mask"].astype(np.float32),
                (width, height),
                interpolation=cv2.INTER_NEAREST,
            )
            sample["mask"] = sample["mask"].astype(bool)

        return sample


class NormalizeImage(object):
    """Normalize image by given mean and std."""

    def __init__(self, mean, std):
        self.__mean = mean
        self.__std = std

    def __call__(self, sample):
        sample["image"] = (sample["image"] - self.__mean) / self.__std

        return sample


class PrepareForNet(object):
    """Prepare sample for usage as network input."""

    def __init__(self):
        pass

    def __call__(self, sample):
        image = np.transpose(sample["image"], (2, 0, 1))
        sample["image"] = np.ascontiguousarray(image).astype(np.float32)

        if "mask" in sample:
            sample["mask"] = sample["mask"].astype(np.float32)
            sample["mask"] = np.ascontiguousarray(sample["mask"])

        if "disparity" in sample:
            disparity = sample["disparity"].astype(np.float32)
            sample["disparity"] = np.ascontiguousarray(disparity)

        if "depth" in sample:
            depth = sample["depth"].astype(np.float32)
            sample["depth"] = np.ascontiguousarray(depth)

        return sample
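These transforms operate on a dict-style sample and compose naturally. A minimal usage sketch, assuming the module import path from this repository; the normalization constants and resize parameters below are illustrative placeholders, not values taken from elsewhere in the repo:

import numpy as np

from models.spatracker.models.core.spatracker.dpt.transforms import (
    NormalizeImage, PrepareForNet, Resize,
)

resize = Resize(
    384, 384,
    resize_target=False,
    keep_aspect_ratio=True,
    ensure_multiple_of=32,
    resize_method="minimal",
)
normalize = NormalizeImage(mean=0.5, std=0.5)    # placeholder normalization constants
prepare = PrepareForNet()

sample = {"image": np.random.rand(480, 640, 3).astype(np.float32)}
for t in (resize, normalize, prepare):
    sample = t(sample)
# sample["image"] is now a contiguous float32 array of shape (3, H, W),
# with H and W rounded to multiples of 32.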
models/spatracker/models/core/spatracker/dpt/vit.py
ADDED
@@ -0,0 +1,596 @@
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import timm
|
4 |
+
import types
|
5 |
+
import math
|
6 |
+
import torch.nn.functional as F
|
7 |
+
|
8 |
+
|
9 |
+
activations = {}
|
10 |
+
|
11 |
+
|
12 |
+
def get_activation(name):
|
13 |
+
def hook(model, input, output):
|
14 |
+
activations[name] = output
|
15 |
+
|
16 |
+
return hook
|
17 |
+
|
18 |
+
|
19 |
+
attention = {}
|
20 |
+
|
21 |
+
|
22 |
+
def get_attention(name):
|
23 |
+
def hook(module, input, output):
|
24 |
+
x = input[0]
|
25 |
+
B, N, C = x.shape
|
26 |
+
qkv = (
|
27 |
+
module.qkv(x)
|
28 |
+
.reshape(B, N, 3, module.num_heads, C // module.num_heads)
|
29 |
+
.permute(2, 0, 3, 1, 4)
|
30 |
+
)
|
31 |
+
q, k, v = (
|
32 |
+
qkv[0],
|
33 |
+
qkv[1],
|
34 |
+
qkv[2],
|
35 |
+
) # make torchscript happy (cannot use tensor as tuple)
|
36 |
+
|
37 |
+
attn = (q @ k.transpose(-2, -1)) * module.scale
|
38 |
+
|
39 |
+
attn = attn.softmax(dim=-1) # [:,:,1,1:]
|
40 |
+
attention[name] = attn
|
41 |
+
|
42 |
+
return hook
|
43 |
+
|
44 |
+
|
45 |
+
def get_mean_attention_map(attn, token, shape):
|
46 |
+
attn = attn[:, :, token, 1:]
|
47 |
+
attn = attn.unflatten(2, torch.Size([shape[2] // 16, shape[3] // 16])).float()
|
48 |
+
attn = torch.nn.functional.interpolate(
|
49 |
+
attn, size=shape[2:], mode="bicubic", align_corners=False
|
50 |
+
).squeeze(0)
|
51 |
+
|
52 |
+
all_attn = torch.mean(attn, 0)
|
53 |
+
|
54 |
+
return all_attn
|
55 |
+
|
56 |
+
|
57 |
+
class Slice(nn.Module):
|
58 |
+
def __init__(self, start_index=1):
|
59 |
+
super(Slice, self).__init__()
|
60 |
+
self.start_index = start_index
|
61 |
+
|
62 |
+
def forward(self, x):
|
63 |
+
return x[:, self.start_index :]
|
64 |
+
|
65 |
+
|
66 |
+
class AddReadout(nn.Module):
|
67 |
+
def __init__(self, start_index=1):
|
68 |
+
super(AddReadout, self).__init__()
|
69 |
+
self.start_index = start_index
|
70 |
+
|
71 |
+
def forward(self, x):
|
72 |
+
if self.start_index == 2:
|
73 |
+
readout = (x[:, 0] + x[:, 1]) / 2
|
74 |
+
else:
|
75 |
+
readout = x[:, 0]
|
76 |
+
return x[:, self.start_index :] + readout.unsqueeze(1)
|
77 |
+
|
78 |
+
|
79 |
+
class ProjectReadout(nn.Module):
|
80 |
+
def __init__(self, in_features, start_index=1):
|
81 |
+
super(ProjectReadout, self).__init__()
|
82 |
+
self.start_index = start_index
|
83 |
+
|
84 |
+
self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU())
|
85 |
+
|
86 |
+
def forward(self, x):
|
87 |
+
readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index :])
|
88 |
+
features = torch.cat((x[:, self.start_index :], readout), -1)
|
89 |
+
|
90 |
+
return self.project(features)
|
91 |
+
|
92 |
+
|
93 |
+
class Transpose(nn.Module):
|
94 |
+
def __init__(self, dim0, dim1):
|
95 |
+
super(Transpose, self).__init__()
|
96 |
+
self.dim0 = dim0
|
97 |
+
self.dim1 = dim1
|
98 |
+
|
99 |
+
def forward(self, x):
|
100 |
+
x = x.transpose(self.dim0, self.dim1)
|
101 |
+
return x
|
102 |
+
|
103 |
+
|
104 |
+
def forward_vit(pretrained, x):
|
105 |
+
b, c, h, w = x.shape
|
106 |
+
|
107 |
+
glob = pretrained.model.forward_flex(x)
|
108 |
+
|
109 |
+
layer_1 = pretrained.activations["1"]
|
110 |
+
layer_2 = pretrained.activations["2"]
|
111 |
+
layer_3 = pretrained.activations["3"]
|
112 |
+
layer_4 = pretrained.activations["4"]
|
113 |
+
|
114 |
+
layer_1 = pretrained.act_postprocess1[0:2](layer_1)
|
115 |
+
layer_2 = pretrained.act_postprocess2[0:2](layer_2)
|
116 |
+
layer_3 = pretrained.act_postprocess3[0:2](layer_3)
|
117 |
+
layer_4 = pretrained.act_postprocess4[0:2](layer_4)
|
118 |
+
|
119 |
+
unflatten = nn.Sequential(
|
120 |
+
nn.Unflatten(
|
121 |
+
2,
|
122 |
+
torch.Size(
|
123 |
+
[
|
124 |
+
h // pretrained.model.patch_size[1],
|
125 |
+
w // pretrained.model.patch_size[0],
|
126 |
+
]
|
127 |
+
),
|
128 |
+
)
|
129 |
+
)
|
130 |
+
|
131 |
+
if layer_1.ndim == 3:
|
132 |
+
layer_1 = unflatten(layer_1)
|
133 |
+
if layer_2.ndim == 3:
|
134 |
+
layer_2 = unflatten(layer_2)
|
135 |
+
if layer_3.ndim == 3:
|
136 |
+
layer_3 = unflatten(layer_3)
|
137 |
+
if layer_4.ndim == 3:
|
138 |
+
layer_4 = unflatten(layer_4)
|
139 |
+
|
140 |
+
layer_1 = pretrained.act_postprocess1[3 : len(pretrained.act_postprocess1)](layer_1)
|
141 |
+
layer_2 = pretrained.act_postprocess2[3 : len(pretrained.act_postprocess2)](layer_2)
|
142 |
+
layer_3 = pretrained.act_postprocess3[3 : len(pretrained.act_postprocess3)](layer_3)
|
143 |
+
layer_4 = pretrained.act_postprocess4[3 : len(pretrained.act_postprocess4)](layer_4)
|
144 |
+
|
145 |
+
return layer_1, layer_2, layer_3, layer_4
|
146 |
+
|
147 |
+
|
148 |
+
def _resize_pos_embed(self, posemb, gs_h, gs_w):
|
149 |
+
posemb_tok, posemb_grid = (
|
150 |
+
posemb[:, : self.start_index],
|
151 |
+
posemb[0, self.start_index :],
|
152 |
+
)
|
153 |
+
|
154 |
+
gs_old = int(math.sqrt(len(posemb_grid)))
|
155 |
+
|
156 |
+
posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
|
157 |
+
posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear")
|
158 |
+
posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1)
|
159 |
+
|
160 |
+
posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
|
161 |
+
|
162 |
+
return posemb
|
163 |
+
|
164 |
+
|
165 |
+
def forward_flex(self, x):
|
166 |
+
b, c, h, w = x.shape
|
167 |
+
|
168 |
+
pos_embed = self._resize_pos_embed(
|
169 |
+
self.pos_embed, h // self.patch_size[1], w // self.patch_size[0]
|
170 |
+
)
|
171 |
+
|
172 |
+
B = x.shape[0]
|
173 |
+
|
174 |
+
if hasattr(self.patch_embed, "backbone"):
|
175 |
+
x = self.patch_embed.backbone(x)
|
176 |
+
if isinstance(x, (list, tuple)):
|
177 |
+
x = x[-1] # last feature if backbone outputs list/tuple of features
|
178 |
+
x = self.patch_embed.proj(x).flatten(2).transpose(1, 2)
|
179 |
+
|
180 |
+
if getattr(self, "dist_token", None) is not None:
|
181 |
+
cls_tokens = self.cls_token.expand(
|
182 |
+
B, -1, -1
|
183 |
+
) # stole cls_tokens impl from Phil Wang, thanks
|
184 |
+
dist_token = self.dist_token.expand(B, -1, -1)
|
185 |
+
x = torch.cat((cls_tokens, dist_token, x), dim=1)
|
186 |
+
else:
|
187 |
+
cls_tokens = self.cls_token.expand(
|
188 |
+
B, -1, -1
|
189 |
+
) # stole cls_tokens impl from Phil Wang, thanks
|
190 |
+
x = torch.cat((cls_tokens, x), dim=1)
|
191 |
+
|
192 |
+
x = x + pos_embed
|
193 |
+
x = self.pos_drop(x)
|
194 |
+
|
195 |
+
for blk in self.blocks:
|
196 |
+
x = blk(x)
|
197 |
+
|
198 |
+
x = self.norm(x)
|
199 |
+
|
200 |
+
return x
|
201 |
+
|
202 |
+
|
203 |
+
def get_readout_oper(vit_features, features, use_readout, start_index=1):
|
204 |
+
if use_readout == "ignore":
|
205 |
+
readout_oper = [Slice(start_index)] * len(features)
|
206 |
+
elif use_readout == "add":
|
207 |
+
readout_oper = [AddReadout(start_index)] * len(features)
|
208 |
+
elif use_readout == "project":
|
209 |
+
readout_oper = [
|
210 |
+
ProjectReadout(vit_features, start_index) for out_feat in features
|
211 |
+
]
|
212 |
+
else:
|
213 |
+
assert (
|
214 |
+
False
|
215 |
+
), "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'"
|
216 |
+
|
217 |
+
return readout_oper
|
218 |
+
|
219 |
+
|
220 |
+
def _make_vit_b16_backbone(
|
221 |
+
model,
|
222 |
+
features=[96, 192, 384, 768],
|
223 |
+
size=[384, 384],
|
224 |
+
hooks=[2, 5, 8, 11],
|
225 |
+
vit_features=768,
|
226 |
+
use_readout="ignore",
|
227 |
+
start_index=1,
|
228 |
+
enable_attention_hooks=False,
|
229 |
+
):
|
230 |
+
pretrained = nn.Module()
|
231 |
+
|
232 |
+
pretrained.model = model
|
233 |
+
pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
|
234 |
+
pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
|
235 |
+
pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
|
236 |
+
pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
|
237 |
+
|
238 |
+
pretrained.activations = activations
|
239 |
+
|
240 |
+
if enable_attention_hooks:
|
241 |
+
pretrained.model.blocks[hooks[0]].attn.register_forward_hook(
|
242 |
+
get_attention("attn_1")
|
243 |
+
)
|
244 |
+
pretrained.model.blocks[hooks[1]].attn.register_forward_hook(
|
245 |
+
get_attention("attn_2")
|
246 |
+
)
|
247 |
+
pretrained.model.blocks[hooks[2]].attn.register_forward_hook(
|
248 |
+
get_attention("attn_3")
|
249 |
+
)
|
250 |
+
pretrained.model.blocks[hooks[3]].attn.register_forward_hook(
|
251 |
+
get_attention("attn_4")
|
252 |
+
)
|
253 |
+
pretrained.attention = attention
|
254 |
+
|
255 |
+
readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
|
256 |
+
|
257 |
+
# 32, 48, 136, 384
|
258 |
+
pretrained.act_postprocess1 = nn.Sequential(
|
259 |
+
readout_oper[0],
|
260 |
+
Transpose(1, 2),
|
261 |
+
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
|
262 |
+
nn.Conv2d(
|
263 |
+
in_channels=vit_features,
|
264 |
+
out_channels=features[0],
|
265 |
+
kernel_size=1,
|
266 |
+
stride=1,
|
267 |
+
padding=0,
|
268 |
+
),
|
269 |
+
nn.ConvTranspose2d(
|
270 |
+
in_channels=features[0],
|
271 |
+
out_channels=features[0],
|
272 |
+
kernel_size=4,
|
273 |
+
stride=4,
|
274 |
+
padding=0,
|
275 |
+
bias=True,
|
276 |
+
dilation=1,
|
277 |
+
groups=1,
|
278 |
+
),
|
279 |
+
)
|
280 |
+
|
281 |
+
pretrained.act_postprocess2 = nn.Sequential(
|
282 |
+
readout_oper[1],
|
283 |
+
Transpose(1, 2),
|
284 |
+
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
|
285 |
+
nn.Conv2d(
|
286 |
+
in_channels=vit_features,
|
287 |
+
out_channels=features[1],
|
288 |
+
kernel_size=1,
|
289 |
+
stride=1,
|
290 |
+
padding=0,
|
291 |
+
),
|
292 |
+
nn.ConvTranspose2d(
|
293 |
+
in_channels=features[1],
|
294 |
+
out_channels=features[1],
|
295 |
+
kernel_size=2,
|
296 |
+
stride=2,
|
297 |
+
padding=0,
|
298 |
+
bias=True,
|
299 |
+
dilation=1,
|
300 |
+
groups=1,
|
301 |
+
),
|
302 |
+
)
|
303 |
+
|
304 |
+
pretrained.act_postprocess3 = nn.Sequential(
|
305 |
+
readout_oper[2],
|
306 |
+
Transpose(1, 2),
|
307 |
+
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
|
308 |
+
nn.Conv2d(
|
309 |
+
in_channels=vit_features,
|
310 |
+
out_channels=features[2],
|
311 |
+
kernel_size=1,
|
312 |
+
stride=1,
|
313 |
+
padding=0,
|
314 |
+
),
|
315 |
+
)
|
316 |
+
|
317 |
+
pretrained.act_postprocess4 = nn.Sequential(
|
318 |
+
readout_oper[3],
|
319 |
+
Transpose(1, 2),
|
320 |
+
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
|
321 |
+
nn.Conv2d(
|
322 |
+
in_channels=vit_features,
|
323 |
+
out_channels=features[3],
|
324 |
+
kernel_size=1,
|
325 |
+
stride=1,
|
326 |
+
padding=0,
|
327 |
+
),
|
328 |
+
nn.Conv2d(
|
329 |
+
in_channels=features[3],
|
330 |
+
out_channels=features[3],
|
331 |
+
kernel_size=3,
|
332 |
+
stride=2,
|
333 |
+
padding=1,
|
334 |
+
),
|
335 |
+
)
|
336 |
+
|
337 |
+
pretrained.model.start_index = start_index
|
338 |
+
pretrained.model.patch_size = [16, 16]
|
339 |
+
|
340 |
+
# We inject this function into the VisionTransformer instances so that
|
341 |
+
# we can use it with interpolated position embeddings without modifying the library source.
|
342 |
+
pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
|
343 |
+
pretrained.model._resize_pos_embed = types.MethodType(
|
344 |
+
_resize_pos_embed, pretrained.model
|
345 |
+
)
|
346 |
+
|
347 |
+
return pretrained
|
348 |
+
|
349 |
+
|
350 |
+
def _make_vit_b_rn50_backbone(
|
351 |
+
model,
|
352 |
+
features=[256, 512, 768, 768],
|
353 |
+
size=[384, 384],
|
354 |
+
hooks=[0, 1, 8, 11],
|
355 |
+
vit_features=384,
|
356 |
+
use_vit_only=False,
|
357 |
+
use_readout="ignore",
|
358 |
+
start_index=1,
|
359 |
+
enable_attention_hooks=False,
|
360 |
+
):
|
361 |
+
pretrained = nn.Module()
|
362 |
+
pretrained.model = model
|
363 |
+
pretrained.model.patch_size = [32, 32]
|
364 |
+
ps = pretrained.model.patch_size[0]
|
365 |
+
if use_vit_only == True:
|
366 |
+
pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
|
367 |
+
pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
|
368 |
+
else:
|
369 |
+
pretrained.model.patch_embed.backbone.stages[0].register_forward_hook(
|
370 |
+
get_activation("1")
|
371 |
+
)
|
372 |
+
pretrained.model.patch_embed.backbone.stages[1].register_forward_hook(
|
373 |
+
get_activation("2")
|
374 |
+
)
|
375 |
+
|
376 |
+
pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
|
377 |
+
pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
|
378 |
+
|
379 |
+
if enable_attention_hooks:
|
380 |
+
pretrained.model.blocks[2].attn.register_forward_hook(get_attention("attn_1"))
|
381 |
+
pretrained.model.blocks[5].attn.register_forward_hook(get_attention("attn_2"))
|
382 |
+
pretrained.model.blocks[8].attn.register_forward_hook(get_attention("attn_3"))
|
383 |
+
pretrained.model.blocks[11].attn.register_forward_hook(get_attention("attn_4"))
|
384 |
+
pretrained.attention = attention
|
385 |
+
|
386 |
+
pretrained.activations = activations
|
387 |
+
|
388 |
+
readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
|
389 |
+
|
390 |
+
if use_vit_only == True:
|
391 |
+
pretrained.act_postprocess1 = nn.Sequential(
|
392 |
+
readout_oper[0],
|
393 |
+
Transpose(1, 2),
|
394 |
+
nn.Unflatten(2, torch.Size([size[0] // ps, size[1] // ps])),
|
395 |
+
nn.Conv2d(
|
396 |
+
in_channels=vit_features,
|
397 |
+
out_channels=features[0],
|
398 |
+
kernel_size=1,
|
399 |
+
stride=1,
|
400 |
+
padding=0,
|
401 |
+
),
|
402 |
+
nn.ConvTranspose2d(
|
403 |
+
in_channels=features[0],
|
404 |
+
out_channels=features[0],
|
405 |
+
kernel_size=4,
|
406 |
+
stride=4,
|
407 |
+
padding=0,
|
408 |
+
bias=True,
|
409 |
+
dilation=1,
|
410 |
+
groups=1,
|
411 |
+
),
|
412 |
+
)
|
413 |
+
|
414 |
+
pretrained.act_postprocess2 = nn.Sequential(
|
415 |
+
readout_oper[1],
|
416 |
+
Transpose(1, 2),
|
417 |
+
nn.Unflatten(2, torch.Size([size[0] // ps, size[1] // ps])),
|
418 |
+
nn.Conv2d(
|
419 |
+
in_channels=vit_features,
|
420 |
+
out_channels=features[1],
|
421 |
+
kernel_size=1,
|
422 |
+
stride=1,
|
423 |
+
padding=0,
|
424 |
+
),
|
425 |
+
nn.ConvTranspose2d(
|
426 |
+
in_channels=features[1],
|
427 |
+
out_channels=features[1],
|
428 |
+
kernel_size=2,
|
429 |
+
stride=2,
|
430 |
+
padding=0,
|
431 |
+
bias=True,
|
432 |
+
dilation=1,
|
433 |
+
groups=1,
|
434 |
+
),
|
435 |
+
)
|
436 |
+
else:
|
437 |
+
pretrained.act_postprocess1 = nn.Sequential(
|
438 |
+
nn.Identity(), nn.Identity(), nn.Identity()
|
439 |
+
)
|
440 |
+
pretrained.act_postprocess2 = nn.Sequential(
|
441 |
+
nn.Identity(), nn.Identity(), nn.Identity()
|
442 |
+
)
|
443 |
+
|
444 |
+
pretrained.act_postprocess3 = nn.Sequential(
|
445 |
+
readout_oper[2],
|
446 |
+
Transpose(1, 2),
|
447 |
+
nn.Unflatten(2, torch.Size([size[0] // ps, size[1] // ps])),
|
448 |
+
nn.Conv2d(
|
449 |
+
in_channels=vit_features,
|
450 |
+
out_channels=features[2],
|
451 |
+
kernel_size=1,
|
452 |
+
stride=1,
|
453 |
+
padding=0,
|
454 |
+
),
|
455 |
+
)
|
456 |
+
|
457 |
+
pretrained.act_postprocess4 = nn.Sequential(
|
458 |
+
readout_oper[3],
|
459 |
+
Transpose(1, 2),
|
460 |
+
nn.Unflatten(2, torch.Size([size[0] // ps, size[1] // ps])),
|
461 |
+
nn.Conv2d(
|
462 |
+
in_channels=vit_features,
|
463 |
+
out_channels=features[3],
|
464 |
+
kernel_size=1,
|
465 |
+
stride=1,
|
466 |
+
padding=0,
|
467 |
+
),
|
468 |
+
nn.Conv2d(
|
469 |
+
in_channels=features[3],
|
470 |
+
out_channels=features[3],
|
471 |
+
kernel_size=3,
|
472 |
+
stride=2,
|
473 |
+
padding=1,
|
474 |
+
),
|
475 |
+
)
|
476 |
+
|
477 |
+
pretrained.model.start_index = start_index
|
478 |
+
pretrained.model.patch_size = [32, 32]
|
479 |
+
|
480 |
+
# We inject this function into the VisionTransformer instances so that
|
481 |
+
# we can use it with interpolated position embeddings without modifying the library source.
|
482 |
+
pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
|
483 |
+
|
484 |
+
# We inject this function into the VisionTransformer instances so that
|
485 |
+
# we can use it with interpolated position embeddings without modifying the library source.
|
486 |
+
pretrained.model._resize_pos_embed = types.MethodType(
|
487 |
+
_resize_pos_embed, pretrained.model
|
488 |
+
)
|
489 |
+
|
490 |
+
return pretrained
|
491 |
+
|
492 |
+
|
493 |
+
def _make_pretrained_vitb_rn50_384(
|
494 |
+
pretrained,
|
495 |
+
use_readout="ignore",
|
496 |
+
hooks=None,
|
497 |
+
use_vit_only=False,
|
498 |
+
enable_attention_hooks=False,
|
499 |
+
):
|
500 |
+
# model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained)
|
501 |
+
# model = timm.create_model("vit_tiny_r_s16_p8_384", pretrained=pretrained)
|
502 |
+
model = timm.create_model("vit_small_r26_s32_384", pretrained=pretrained)
|
503 |
+
hooks = [0, 1, 8, 11] if hooks == None else hooks
|
504 |
+
return _make_vit_b_rn50_backbone(
|
505 |
+
model,
|
506 |
+
features=[128, 256, 384, 384],
|
507 |
+
size=[384, 384],
|
508 |
+
hooks=hooks,
|
509 |
+
use_vit_only=use_vit_only,
|
510 |
+
use_readout=use_readout,
|
511 |
+
enable_attention_hooks=enable_attention_hooks,
|
512 |
+
)
|
513 |
+
|
514 |
+
def _make_pretrained_vit_tiny(
|
515 |
+
pretrained,
|
516 |
+
use_readout="ignore",
|
517 |
+
hooks=None,
|
518 |
+
use_vit_only=False,
|
519 |
+
enable_attention_hooks=False,
|
520 |
+
):
|
521 |
+
# model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained)
|
522 |
+
model = timm.create_model("vit_tiny_r_s16_p8_384", pretrained=pretrained)
|
523 |
+
import ipdb; ipdb.set_trace()
|
524 |
+
hooks = [0, 1, 8, 11] if hooks == None else hooks
|
525 |
+
return _make_vit_tiny_backbone(
|
526 |
+
model,
|
527 |
+
features=[256, 512, 768, 768],
|
528 |
+
size=[384, 384],
|
529 |
+
hooks=hooks,
|
530 |
+
use_vit_only=use_vit_only,
|
531 |
+
use_readout=use_readout,
|
532 |
+
enable_attention_hooks=enable_attention_hooks,
|
533 |
+
)
|
534 |
+
|
535 |
+
def _make_pretrained_vitl16_384(
|
536 |
+
pretrained, use_readout="ignore", hooks=None, enable_attention_hooks=False
|
537 |
+
):
|
538 |
+
model = timm.create_model("vit_large_patch16_384", pretrained=pretrained)
|
539 |
+
|
540 |
+
hooks = [5, 11, 17, 23] if hooks == None else hooks
|
541 |
+
return _make_vit_b16_backbone(
|
542 |
+
model,
|
543 |
+
features=[256, 512, 1024, 1024],
|
544 |
+
hooks=hooks,
|
545 |
+
vit_features=1024,
|
546 |
+
use_readout=use_readout,
|
547 |
+
enable_attention_hooks=enable_attention_hooks,
|
548 |
+
)
|
549 |
+
|
550 |
+
|
551 |
+
def _make_pretrained_vitb16_384(
|
552 |
+
pretrained, use_readout="ignore", hooks=None, enable_attention_hooks=False
|
553 |
+
):
|
554 |
+
model = timm.create_model("vit_base_patch16_384", pretrained=pretrained)
|
555 |
+
|
556 |
+
hooks = [2, 5, 8, 11] if hooks == None else hooks
|
557 |
+
return _make_vit_b16_backbone(
|
558 |
+
model,
|
559 |
+
features=[96, 192, 384, 768],
|
560 |
+
hooks=hooks,
|
561 |
+
use_readout=use_readout,
|
562 |
+
enable_attention_hooks=enable_attention_hooks,
|
563 |
+
)
|
564 |
+
|
565 |
+
|
566 |
+
def _make_pretrained_deitb16_384(
|
567 |
+
pretrained, use_readout="ignore", hooks=None, enable_attention_hooks=False
|
568 |
+
):
|
569 |
+
model = timm.create_model("vit_deit_base_patch16_384", pretrained=pretrained)
|
570 |
+
|
571 |
+
hooks = [2, 5, 8, 11] if hooks == None else hooks
|
572 |
+
return _make_vit_b16_backbone(
|
573 |
+
model,
|
574 |
+
features=[96, 192, 384, 768],
|
575 |
+
hooks=hooks,
|
576 |
+
use_readout=use_readout,
|
577 |
+
enable_attention_hooks=enable_attention_hooks,
|
578 |
+
)
|
579 |
+
|
580 |
+
|
581 |
+
def _make_pretrained_deitb16_distil_384(
|
582 |
+
pretrained, use_readout="ignore", hooks=None, enable_attention_hooks=False
|
583 |
+
):
|
584 |
+
model = timm.create_model(
|
585 |
+
"vit_deit_base_distilled_patch16_384", pretrained=pretrained
|
586 |
+
)
|
587 |
+
|
588 |
+
hooks = [2, 5, 8, 11] if hooks == None else hooks
|
589 |
+
return _make_vit_b16_backbone(
|
590 |
+
model,
|
591 |
+
features=[96, 192, 384, 768],
|
592 |
+
hooks=hooks,
|
593 |
+
use_readout=use_readout,
|
594 |
+
start_index=2,
|
595 |
+
enable_attention_hooks=enable_attention_hooks,
|
596 |
+
)
|
models/spatracker/models/core/spatracker/feature_net.py
ADDED
@@ -0,0 +1,915 @@
1 |
+
"""
|
2 |
+
Adapted from ConvONet
|
3 |
+
https://github.com/autonomousvision/convolutional_occupancy_networks/blob/838bea5b2f1314f2edbb68d05ebb0db49f1f3bd2/src/encoder/pointnet.py#L1
|
4 |
+
"""
|
5 |
+
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
import torch.nn.functional as F
|
10 |
+
# from torch_scatter import scatter_mean, scatter_max
|
11 |
+
from models.spatracker.models.core.spatracker.unet import UNet
|
12 |
+
from models.spatracker.models.core.model_utils import (
|
13 |
+
vis_PCA
|
14 |
+
)
|
15 |
+
from einops import rearrange
|
16 |
+
|
17 |
+
def compute_iou(occ1, occ2):
|
18 |
+
''' Computes the Intersection over Union (IoU) value for two sets of
|
19 |
+
occupancy values.
|
20 |
+
|
21 |
+
Args:
|
22 |
+
occ1 (tensor): first set of occupancy values
|
23 |
+
occ2 (tensor): second set of occupancy values
|
24 |
+
'''
|
25 |
+
occ1 = np.asarray(occ1)
|
26 |
+
occ2 = np.asarray(occ2)
|
27 |
+
|
28 |
+
# Put all data in second dimension
|
29 |
+
# Also works for 1-dimensional data
|
30 |
+
if occ1.ndim >= 2:
|
31 |
+
occ1 = occ1.reshape(occ1.shape[0], -1)
|
32 |
+
if occ2.ndim >= 2:
|
33 |
+
occ2 = occ2.reshape(occ2.shape[0], -1)
|
34 |
+
|
35 |
+
# Convert to boolean values
|
36 |
+
occ1 = (occ1 >= 0.5)
|
37 |
+
occ2 = (occ2 >= 0.5)
|
38 |
+
|
39 |
+
# Compute IOU
|
40 |
+
area_union = (occ1 | occ2).astype(np.float32).sum(axis=-1)
|
41 |
+
area_intersect = (occ1 & occ2).astype(np.float32).sum(axis=-1)
|
42 |
+
|
43 |
+
iou = (area_intersect / area_union)
|
44 |
+
|
45 |
+
return iou
|
46 |
+
|
47 |
+
|
48 |
+
def chamfer_distance(points1, points2, use_kdtree=True, give_id=False):
|
49 |
+
''' Returns the chamfer distance for the sets of points.
|
50 |
+
|
51 |
+
Args:
|
52 |
+
points1 (numpy array): first point set
|
53 |
+
points2 (numpy array): second point set
|
54 |
+
use_kdtree (bool): whether to use a kdtree
|
55 |
+
give_id (bool): whether to return the IDs of nearest points
|
56 |
+
'''
|
57 |
+
if use_kdtree:
|
58 |
+
return chamfer_distance_kdtree(points1, points2, give_id=give_id)
|
59 |
+
else:
|
60 |
+
return chamfer_distance_naive(points1, points2)
|
61 |
+
|
62 |
+
|
63 |
+
def chamfer_distance_naive(points1, points2):
|
64 |
+
''' Naive implementation of the Chamfer distance.
|
65 |
+
|
66 |
+
Args:
|
67 |
+
points1 (numpy array): first point set
|
68 |
+
points2 (numpy array): second point set
|
69 |
+
'''
|
70 |
+
assert(points1.size() == points2.size())
|
71 |
+
batch_size, T, _ = points1.size()
|
72 |
+
|
73 |
+
points1 = points1.view(batch_size, T, 1, 3)
|
74 |
+
points2 = points2.view(batch_size, 1, T, 3)
|
75 |
+
|
76 |
+
distances = (points1 - points2).pow(2).sum(-1)
|
77 |
+
|
78 |
+
chamfer1 = distances.min(dim=1)[0].mean(dim=1)
|
79 |
+
chamfer2 = distances.min(dim=2)[0].mean(dim=1)
|
80 |
+
|
81 |
+
chamfer = chamfer1 + chamfer2
|
82 |
+
return chamfer
|
83 |
+
|
84 |
+
|
85 |
+
def chamfer_distance_kdtree(points1, points2, give_id=False):
|
86 |
+
''' KD-tree based implementation of the Chamfer distance.
|
87 |
+
|
88 |
+
Args:
|
89 |
+
points1 (numpy array): first point set
|
90 |
+
points2 (numpy array): second point set
|
91 |
+
give_id (bool): whether to return the IDs of the nearest points
|
92 |
+
'''
|
93 |
+
# Points have size batch_size x T x 3
|
94 |
+
batch_size = points1.size(0)
|
95 |
+
|
96 |
+
# First convert points to numpy
|
97 |
+
points1_np = points1.detach().cpu().numpy()
|
98 |
+
points2_np = points2.detach().cpu().numpy()
|
99 |
+
|
100 |
+
# Get list of nearest neighbors indieces
|
101 |
+
idx_nn_12, _ = get_nearest_neighbors_indices_batch(points1_np, points2_np)
|
102 |
+
idx_nn_12 = torch.LongTensor(idx_nn_12).to(points1.device)
|
103 |
+
# Expands it as batch_size x 1 x 3
|
104 |
+
idx_nn_12_expand = idx_nn_12.view(batch_size, -1, 1).expand_as(points1)
|
105 |
+
|
106 |
+
# Get list of nearest neighbors indieces
|
107 |
+
idx_nn_21, _ = get_nearest_neighbors_indices_batch(points2_np, points1_np)
|
108 |
+
idx_nn_21 = torch.LongTensor(idx_nn_21).to(points1.device)
|
109 |
+
# Expands it as batch_size x T x 3
|
110 |
+
idx_nn_21_expand = idx_nn_21.view(batch_size, -1, 1).expand_as(points2)
|
111 |
+
|
112 |
+
# Compute nearest neighbors in points2 to points in points1
|
113 |
+
# points_12[i, j, k] = points2[i, idx_nn_12_expand[i, j, k], k]
|
114 |
+
points_12 = torch.gather(points2, dim=1, index=idx_nn_12_expand)
|
115 |
+
|
116 |
+
# Compute nearest neighbors in points1 to points in points2
|
117 |
+
# points_21[i, j, k] = points2[i, idx_nn_21_expand[i, j, k], k]
|
118 |
+
points_21 = torch.gather(points1, dim=1, index=idx_nn_21_expand)
|
119 |
+
|
120 |
+
# Compute chamfer distance
|
121 |
+
chamfer1 = (points1 - points_12).pow(2).sum(2).mean(1)
|
122 |
+
chamfer2 = (points2 - points_21).pow(2).sum(2).mean(1)
|
123 |
+
|
124 |
+
# Take sum
|
125 |
+
chamfer = chamfer1 + chamfer2
|
126 |
+
|
127 |
+
# If required, also return nearest neighbors
|
128 |
+
if give_id:
|
129 |
+
return chamfer1, chamfer2, idx_nn_12, idx_nn_21
|
130 |
+
|
131 |
+
return chamfer
|
132 |
+
|
133 |
+
|
134 |
+
def get_nearest_neighbors_indices_batch(points_src, points_tgt, k=1):
|
135 |
+
''' Returns the nearest neighbors for point sets batchwise.
|
136 |
+
|
137 |
+
Args:
|
138 |
+
points_src (numpy array): source points
|
139 |
+
points_tgt (numpy array): target points
|
140 |
+
k (int): number of nearest neighbors to return
|
141 |
+
'''
|
142 |
+
indices = []
|
143 |
+
distances = []
|
144 |
+
|
145 |
+
for (p1, p2) in zip(points_src, points_tgt):
|
146 |
+
raise NotImplementedError()
|
147 |
+
# kdtree = KDTree(p2)
|
148 |
+
dist, idx = kdtree.query(p1, k=k)
|
149 |
+
indices.append(idx)
|
150 |
+
distances.append(dist)
|
151 |
+
|
152 |
+
return indices, distances
|
153 |
+
|
154 |
+
|
155 |
+
def make_3d_grid(bb_min, bb_max, shape):
|
156 |
+
''' Makes a 3D grid.
|
157 |
+
|
158 |
+
Args:
|
159 |
+
bb_min (tuple): bounding box minimum
|
160 |
+
bb_max (tuple): bounding box maximum
|
161 |
+
shape (tuple): output shape
|
162 |
+
'''
|
163 |
+
size = shape[0] * shape[1] * shape[2]
|
164 |
+
|
165 |
+
pxs = torch.linspace(bb_min[0], bb_max[0], shape[0])
|
166 |
+
pys = torch.linspace(bb_min[1], bb_max[1], shape[1])
|
167 |
+
pzs = torch.linspace(bb_min[2], bb_max[2], shape[2])
|
168 |
+
|
169 |
+
pxs = pxs.view(-1, 1, 1).expand(*shape).contiguous().view(size)
|
170 |
+
pys = pys.view(1, -1, 1).expand(*shape).contiguous().view(size)
|
171 |
+
pzs = pzs.view(1, 1, -1).expand(*shape).contiguous().view(size)
|
172 |
+
p = torch.stack([pxs, pys, pzs], dim=1)
|
173 |
+
|
174 |
+
return p
|
175 |
+
|
176 |
+
|
177 |
+
def transform_points(points, transform):
|
178 |
+
''' Transforms points with regard to passed camera information.
|
179 |
+
|
180 |
+
Args:
|
181 |
+
points (tensor): points tensor
|
182 |
+
transform (tensor): transformation matrices
|
183 |
+
'''
|
184 |
+
assert(points.size(2) == 3)
|
185 |
+
assert(transform.size(1) == 3)
|
186 |
+
assert(points.size(0) == transform.size(0))
|
187 |
+
|
188 |
+
if transform.size(2) == 4:
|
189 |
+
R = transform[:, :, :3]
|
190 |
+
t = transform[:, :, 3:]
|
191 |
+
points_out = points @ R.transpose(1, 2) + t.transpose(1, 2)
|
192 |
+
elif transform.size(2) == 3:
|
193 |
+
K = transform
|
194 |
+
points_out = points @ K.transpose(1, 2)
|
195 |
+
|
196 |
+
return points_out
|
197 |
+
|
198 |
+
|
199 |
+
def b_inv(b_mat):
|
200 |
+
''' Performs batch matrix inversion.
|
201 |
+
|
202 |
+
Arguments:
|
203 |
+
b_mat: the batch of matrices that should be inverted
|
204 |
+
'''
|
205 |
+
|
206 |
+
eye = b_mat.new_ones(b_mat.size(-1)).diag().expand_as(b_mat)
|
207 |
+
b_inv, _ = torch.gesv(eye, b_mat)
|
208 |
+
return b_inv
|
209 |
+
|
210 |
+
def project_to_camera(points, transform):
|
211 |
+
''' Projects points to the camera plane.
|
212 |
+
|
213 |
+
Args:
|
214 |
+
points (tensor): points tensor
|
215 |
+
transform (tensor): transformation matrices
|
216 |
+
'''
|
217 |
+
p_camera = transform_points(points, transform)
|
218 |
+
p_camera = p_camera[..., :2] / p_camera[..., 2:]
|
219 |
+
return p_camera
|
220 |
+
|
221 |
+
|
222 |
+
def fix_Rt_camera(Rt, loc, scale):
|
223 |
+
''' Fixes Rt camera matrix.
|
224 |
+
|
225 |
+
Args:
|
226 |
+
Rt (tensor): Rt camera matrix
|
227 |
+
loc (tensor): location
|
228 |
+
scale (float): scale
|
229 |
+
'''
|
230 |
+
# Rt is B x 3 x 4
|
231 |
+
# loc is B x 3 and scale is B
|
232 |
+
batch_size = Rt.size(0)
|
233 |
+
R = Rt[:, :, :3]
|
234 |
+
t = Rt[:, :, 3:]
|
235 |
+
|
236 |
+
scale = scale.view(batch_size, 1, 1)
|
237 |
+
R_new = R * scale
|
238 |
+
t_new = t + R @ loc.unsqueeze(2)
|
239 |
+
|
240 |
+
Rt_new = torch.cat([R_new, t_new], dim=2)
|
241 |
+
|
242 |
+
assert(Rt_new.size() == (batch_size, 3, 4))
|
243 |
+
return Rt_new
|
244 |
+
|
245 |
+
def normalize_coordinate(p, padding=0.1, plane='xz'):
|
246 |
+
''' Normalize coordinate to [0, 1] for unit cube experiments
|
247 |
+
|
248 |
+
Args:
|
249 |
+
p (tensor): point
|
250 |
+
padding (float): conventional padding paramter of ONet for unit cube, so [-0.5, 0.5] -> [-0.55, 0.55]
|
251 |
+
plane (str): plane feature type, ['xz', 'xy', 'yz']
|
252 |
+
'''
|
253 |
+
# breakpoint()
|
254 |
+
if plane == 'xz':
|
255 |
+
xy = p[:, :, [0, 2]]
|
256 |
+
elif plane =='xy':
|
257 |
+
xy = p[:, :, [0, 1]]
|
258 |
+
else:
|
259 |
+
xy = p[:, :, [1, 2]]
|
260 |
+
|
261 |
+
xy = torch.clamp(xy, min=1e-6, max=1. - 1e-6)
|
262 |
+
|
263 |
+
# xy_new = xy / (1 + padding + 10e-6) # (-0.5, 0.5)
|
264 |
+
# xy_new = xy_new + 0.5 # range (0, 1)
|
265 |
+
|
266 |
+
# # f there are outliers out of the range
|
267 |
+
# if xy_new.max() >= 1:
|
268 |
+
# xy_new[xy_new >= 1] = 1 - 10e-6
|
269 |
+
# if xy_new.min() < 0:
|
270 |
+
# xy_new[xy_new < 0] = 0.0
|
271 |
+
# xy_new = (xy + 1.) / 2.
|
272 |
+
return xy
|
273 |
+
|
274 |
+
def normalize_3d_coordinate(p, padding=0.1):
|
275 |
+
''' Normalize coordinate to [0, 1] for unit cube experiments.
|
276 |
+
Corresponds to our 3D model
|
277 |
+
|
278 |
+
Args:
|
279 |
+
p (tensor): point
|
280 |
+
padding (float): conventional padding paramter of ONet for unit cube, so [-0.5, 0.5] -> [-0.55, 0.55]
|
281 |
+
'''
|
282 |
+
|
283 |
+
p_nor = p / (1 + padding + 10e-4) # (-0.5, 0.5)
|
284 |
+
p_nor = p_nor + 0.5 # range (0, 1)
|
285 |
+
# f there are outliers out of the range
|
286 |
+
if p_nor.max() >= 1:
|
287 |
+
p_nor[p_nor >= 1] = 1 - 10e-4
|
288 |
+
if p_nor.min() < 0:
|
289 |
+
p_nor[p_nor < 0] = 0.0
|
290 |
+
return p_nor
|
291 |
+
|
292 |
+
def normalize_coord(p, vol_range, plane='xz'):
|
293 |
+
''' Normalize coordinate to [0, 1] for sliding-window experiments
|
294 |
+
|
295 |
+
Args:
|
296 |
+
p (tensor): point
|
297 |
+
vol_range (numpy array): volume boundary
|
298 |
+
plane (str): feature type, ['xz', 'xy', 'yz'] - canonical planes; ['grid'] - grid volume
|
299 |
+
'''
|
300 |
+
p[:, 0] = (p[:, 0] - vol_range[0][0]) / (vol_range[1][0] - vol_range[0][0])
|
301 |
+
p[:, 1] = (p[:, 1] - vol_range[0][1]) / (vol_range[1][1] - vol_range[0][1])
|
302 |
+
p[:, 2] = (p[:, 2] - vol_range[0][2]) / (vol_range[1][2] - vol_range[0][2])
|
303 |
+
|
304 |
+
if plane == 'xz':
|
305 |
+
x = p[:, [0, 2]]
|
306 |
+
elif plane =='xy':
|
307 |
+
x = p[:, [0, 1]]
|
308 |
+
elif plane =='yz':
|
309 |
+
x = p[:, [1, 2]]
|
310 |
+
else:
|
311 |
+
x = p
|
312 |
+
return x
|
313 |
+
|
314 |
+
def coordinate2index(x, reso, coord_type='2d'):
|
315 |
+
''' Normalize coordinate to [0, 1] for unit cube experiments.
|
316 |
+
Corresponds to our 3D model
|
317 |
+
|
318 |
+
Args:
|
319 |
+
x (tensor): coordinate
|
320 |
+
reso (int): defined resolution
|
321 |
+
coord_type (str): coordinate type
|
322 |
+
'''
|
323 |
+
x = (x * reso).long()
|
324 |
+
if coord_type == '2d': # plane
|
325 |
+
index = x[:, :, 0] + reso * x[:, :, 1]
|
326 |
+
elif coord_type == '3d': # grid
|
327 |
+
index = x[:, :, 0] + reso * (x[:, :, 1] + reso * x[:, :, 2])
|
328 |
+
index = index[:, None, :]
|
329 |
+
return index
|
330 |
+
|
331 |
+
def coord2index(p, vol_range, reso=None, plane='xz'):
|
332 |
+
''' Normalize coordinate to [0, 1] for sliding-window experiments.
|
333 |
+
Corresponds to our 3D model
|
334 |
+
|
335 |
+
Args:
|
336 |
+
p (tensor): points
|
337 |
+
vol_range (numpy array): volume boundary
|
338 |
+
reso (int): defined resolution
|
339 |
+
plane (str): feature type, ['xz', 'xy', 'yz'] - canonical planes; ['grid'] - grid volume
|
340 |
+
'''
|
341 |
+
# normalize to [0, 1]
|
342 |
+
x = normalize_coord(p, vol_range, plane=plane)
|
343 |
+
|
344 |
+
if isinstance(x, np.ndarray):
|
345 |
+
x = np.floor(x * reso).astype(int)
|
346 |
+
else: #* pytorch tensor
|
347 |
+
x = (x * reso).long()
|
348 |
+
|
349 |
+
if x.shape[1] == 2:
|
350 |
+
index = x[:, 0] + reso * x[:, 1]
|
351 |
+
index[index > reso**2] = reso**2
|
352 |
+
elif x.shape[1] == 3:
|
353 |
+
index = x[:, 0] + reso * (x[:, 1] + reso * x[:, 2])
|
354 |
+
index[index > reso**3] = reso**3
|
355 |
+
|
356 |
+
return index[None]
|
357 |
+
|
358 |
+
def update_reso(reso, depth):
|
359 |
+
''' Update the defined resolution so that UNet can process.
|
360 |
+
|
361 |
+
Args:
|
362 |
+
reso (int): defined resolution
|
363 |
+
depth (int): U-Net number of layers
|
364 |
+
'''
|
365 |
+
base = 2**(int(depth) - 1)
|
366 |
+
if ~(reso / base).is_integer(): # when this is not integer, U-Net dimension error
|
367 |
+
for i in range(base):
|
368 |
+
if ((reso + i) / base).is_integer():
|
369 |
+
reso = reso + i
|
370 |
+
break
|
371 |
+
return reso
|
372 |
+
|
373 |
+
def decide_total_volume_range(query_vol_metric, recep_field, unit_size, unet_depth):
|
374 |
+
''' Update the defined resolution so that UNet can process.
|
375 |
+
|
376 |
+
Args:
|
377 |
+
query_vol_metric (numpy array): query volume size
|
378 |
+
recep_field (int): defined the receptive field for U-Net
|
379 |
+
unit_size (float): the defined voxel size
|
380 |
+
unet_depth (int): U-Net number of layers
|
381 |
+
'''
|
382 |
+
reso = query_vol_metric / unit_size + recep_field - 1
|
383 |
+
reso = update_reso(int(reso), unet_depth) # make sure input reso can be processed by UNet
|
384 |
+
input_vol_metric = reso * unit_size
|
385 |
+
p_c = np.array([0.0, 0.0, 0.0]).astype(np.float32)
|
386 |
+
lb_input_vol, ub_input_vol = p_c - input_vol_metric/2, p_c + input_vol_metric/2
|
387 |
+
lb_query_vol, ub_query_vol = p_c - query_vol_metric/2, p_c + query_vol_metric/2
|
388 |
+
input_vol = [lb_input_vol, ub_input_vol]
|
389 |
+
query_vol = [lb_query_vol, ub_query_vol]
|
390 |
+
|
391 |
+
# handle the case when resolution is too large
|
392 |
+
if reso > 10000:
|
393 |
+
reso = 1
|
394 |
+
|
395 |
+
return input_vol, query_vol, reso
|
396 |
+
|
397 |
+
def add_key(base, new, base_name, new_name, device=None):
|
398 |
+
''' Add new keys to the given input
|
399 |
+
|
400 |
+
Args:
|
401 |
+
base (tensor): inputs
|
402 |
+
new (tensor): new info for the inputs
|
403 |
+
base_name (str): name for the input
|
404 |
+
new_name (str): name for the new info
|
405 |
+
device (device): pytorch device
|
406 |
+
'''
|
407 |
+
if (new is not None) and (isinstance(new, dict)):
|
408 |
+
if device is not None:
|
409 |
+
for key in new.keys():
|
410 |
+
new[key] = new[key].to(device)
|
411 |
+
base = {base_name: base,
|
412 |
+
new_name: new}
|
413 |
+
return base
|
414 |
+
|
415 |
+
class map2local(object):
|
416 |
+
''' Add new keys to the given input
|
417 |
+
|
418 |
+
Args:
|
419 |
+
s (float): the defined voxel size
|
420 |
+
pos_encoding (str): method for the positional encoding, linear|sin_cos
|
421 |
+
'''
|
422 |
+
def __init__(self, s, pos_encoding='linear'):
|
423 |
+
super().__init__()
|
424 |
+
self.s = s
|
425 |
+
self.pe = positional_encoding(basis_function=pos_encoding)
|
426 |
+
|
427 |
+
def __call__(self, p):
|
428 |
+
p = torch.remainder(p, self.s) / self.s # always possitive
|
429 |
+
# p = torch.fmod(p, self.s) / self.s # same sign as input p!
|
430 |
+
p = self.pe(p)
|
431 |
+
return p
|
432 |
+
|
433 |
+
class positional_encoding(object):
|
434 |
+
''' Positional Encoding (presented in NeRF)
|
435 |
+
|
436 |
+
Args:
|
437 |
+
basis_function (str): basis function
|
438 |
+
'''
|
439 |
+
def __init__(self, basis_function='sin_cos'):
|
440 |
+
super().__init__()
|
441 |
+
self.func = basis_function
|
442 |
+
|
443 |
+
L = 10
|
444 |
+
freq_bands = 2.**(np.linspace(0, L-1, L))
|
445 |
+
self.freq_bands = freq_bands * math.pi
|
446 |
+
|
447 |
+
def __call__(self, p):
|
448 |
+
if self.func == 'sin_cos':
|
449 |
+
out = []
|
450 |
+
p = 2.0 * p - 1.0 # change to the range [-1, 1]
|
451 |
+
for freq in self.freq_bands:
|
452 |
+
out.append(torch.sin(freq * p))
|
453 |
+
out.append(torch.cos(freq * p))
|
454 |
+
p = torch.cat(out, dim=2)
|
455 |
+
return p
|
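With L = 10 frequencies and both sin and cos per frequency, each 3D input point expands to 3 * 10 * 2 = 60 channels, which is why the 'sin_cos' option is paired with nn.Linear(60, ...) further below. A quick shape check (a sketch, assuming torch is imported as in this module):
pe = positional_encoding(basis_function='sin_cos')
p = torch.rand(2, 100, 3)   # (B, N, 3) coordinates in [0, 1]
print(pe(p).shape)          # torch.Size([2, 100, 60])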
456 |
+
|
457 |
+
# Resnet Blocks
|
458 |
+
class ResnetBlockFC(nn.Module):
|
459 |
+
''' Fully connected ResNet Block class.
|
460 |
+
|
461 |
+
Args:
|
462 |
+
size_in (int): input dimension
|
463 |
+
size_out (int): output dimension
|
464 |
+
size_h (int): hidden dimension
|
465 |
+
'''
|
466 |
+
|
467 |
+
def __init__(self, size_in, size_out=None, size_h=None):
|
468 |
+
super().__init__()
|
469 |
+
# Attributes
|
470 |
+
if size_out is None:
|
471 |
+
size_out = size_in
|
472 |
+
|
473 |
+
if size_h is None:
|
474 |
+
size_h = min(size_in, size_out)
|
475 |
+
|
476 |
+
self.size_in = size_in
|
477 |
+
self.size_h = size_h
|
478 |
+
self.size_out = size_out
|
479 |
+
# Submodules
|
480 |
+
self.fc_0 = nn.Linear(size_in, size_h)
|
481 |
+
self.fc_1 = nn.Linear(size_h, size_out)
|
482 |
+
self.actvn = nn.ReLU()
|
483 |
+
|
484 |
+
if size_in == size_out:
|
485 |
+
self.shortcut = None
|
486 |
+
else:
|
487 |
+
self.shortcut = nn.Linear(size_in, size_out, bias=False)
|
488 |
+
# Initialization
|
489 |
+
nn.init.zeros_(self.fc_1.weight)
|
490 |
+
|
491 |
+
def forward(self, x):
|
492 |
+
net = self.fc_0(self.actvn(x))
|
493 |
+
dx = self.fc_1(self.actvn(net))
|
494 |
+
|
495 |
+
if self.shortcut is not None:
|
496 |
+
x_s = self.shortcut(x)
|
497 |
+
else:
|
498 |
+
x_s = x
|
499 |
+
|
500 |
+
return x_s + dx
|
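A brief usage sketch (assumed shapes, not from the original file): the block maps size_in to size_out with a zero-initialized residual branch and only adds a linear shortcut when the two sizes differ.
block = ResnetBlockFC(size_in=256, size_out=128)
x = torch.randn(8, 1024, 256)
print(block(x).shape)   # torch.Size([8, 1024, 128])
print(block.shortcut)   # a Linear(256, 128, bias=False) shortcut, since the sizes differ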
501 |
+
|
502 |
+
|
503 |
+
|
504 |
+
'''
|
505 |
+
------------------ the key model for Pointnet ----------------------------
|
506 |
+
'''
|
507 |
+
|
508 |
+
|
509 |
+
class LocalSoftSplat(nn.Module):
|
510 |
+
|
511 |
+
def __init__(self, ch=128, dim=3, hidden_dim=128, scatter_type='max',
|
512 |
+
unet=True, unet_kwargs=None, unet3d=False, unet3d_kwargs=None,
|
513 |
+
hw=None, grid_resolution=None, plane_type='xz', padding=0.1,
|
514 |
+
n_blocks=4, splat_func=None):
|
515 |
+
super().__init__()
|
516 |
+
c_dim = ch
|
517 |
+
|
518 |
+
self.c_dim = c_dim
|
519 |
+
|
520 |
+
self.fc_pos = nn.Linear(dim, 2*hidden_dim)
|
521 |
+
self.blocks = nn.ModuleList([
|
522 |
+
ResnetBlockFC(2*hidden_dim, hidden_dim) for i in range(n_blocks)
|
523 |
+
])
|
524 |
+
self.fc_c = nn.Linear(hidden_dim, c_dim)
|
525 |
+
|
526 |
+
self.actvn = nn.ReLU()
|
527 |
+
self.hidden_dim = hidden_dim
|
528 |
+
|
529 |
+
if unet:
|
530 |
+
self.unet = UNet(c_dim, in_channels=c_dim, **unet_kwargs)
|
531 |
+
else:
|
532 |
+
self.unet = None
|
533 |
+
|
534 |
+
# get splat func
|
535 |
+
self.splat_func = splat_func
|
536 |
+
def forward(self, img_feat,
|
537 |
+
Fxy2xz, Fxy2yz, Dz, gridxy=None):
|
538 |
+
"""
|
539 |
+
Args:
|
540 |
+
img_feat (tensor): image features
|
541 |
+
Fxy2xz (tensor): transformation matrix from xy to xz
|
542 |
+
Fxy2yz (tensor): transformation matrix from xy to yz
|
543 |
+
"""
|
544 |
+
B, T, _, H, W = img_feat.shape
|
545 |
+
fea_reshp = rearrange(img_feat, 'b t c h w -> (b h w) t c',
|
546 |
+
c=img_feat.shape[2], h=H, w=W)
|
547 |
+
|
548 |
+
gridyz = gridxy + Fxy2yz
|
549 |
+
gridxz = gridxy + Fxy2xz
|
550 |
+
# normalize
|
551 |
+
gridyz[:, 0, ...] = (gridyz[:, 0, ...] / (H - 1) - 0.5) * 2
|
552 |
+
gridyz[:, 1, ...] = (gridyz[:, 1, ...] / (Dz - 1) - 0.5) * 2
|
553 |
+
gridxz[:, 0, ...] = (gridxz[:, 0, ...] / (W - 1) - 0.5) * 2
|
554 |
+
gridxz[:, 1, ...] = (gridxz[:, 1, ...] / (Dz - 1) - 0.5) * 2
|
555 |
+
if len(self.blocks) > 0:
|
556 |
+
net = self.fc_pos(fea_reshp)
|
557 |
+
net = self.blocks[0](net)
|
558 |
+
for block in self.blocks[1:]:
|
559 |
+
# splat and fusion
|
560 |
+
net_plane = rearrange(net, '(b h w) t c -> (b t) c h w', b=B, h=H, w=W)
|
561 |
+
|
562 |
+
net_planeYZ = self.splat_func(net_plane, Fxy2yz, None,
|
563 |
+
strMode="avg", tenoutH=Dz, tenoutW=H)
|
564 |
+
|
565 |
+
net_planeXZ = self.splat_func(net_plane, Fxy2xz, None,
|
566 |
+
strMode="avg", tenoutH=Dz, tenoutW=W)
|
567 |
+
|
568 |
+
net_plane = net_plane + (
|
569 |
+
F.grid_sample(
|
570 |
+
net_planeYZ, gridyz.permute(0,2,3,1), mode='bilinear', padding_mode='border') +
|
571 |
+
F.grid_sample(
|
572 |
+
net_planeXZ, gridxz.permute(0,2,3,1), mode='bilinear', padding_mode='border')
|
573 |
+
)
|
574 |
+
|
575 |
+
pooled = rearrange(net_plane, 't c h w -> (h w) t c',
|
576 |
+
c=net_plane.shape[1], h=H, w=W)
|
577 |
+
|
578 |
+
net = torch.cat([net, pooled], dim=2)
|
579 |
+
net = block(net)
|
580 |
+
|
581 |
+
c = self.fc_c(net)
|
582 |
+
net_plane = rearrange(c, '(b h w) t c -> (b t) c h w', b=B, h=H, w=W)
|
583 |
+
else:
|
584 |
+
net_plane = rearrange(img_feat, 'b t c h w -> (b t) c h w',
|
585 |
+
c=img_feat.shape[2], h=H, w=W)
|
586 |
+
net_planeYZ = self.splat_func(net_plane, Fxy2yz, None,
|
587 |
+
strMode="avg", tenoutH=Dz, tenoutW=H)
|
588 |
+
net_planeXZ = self.splat_func(net_plane, Fxy2xz, None,
|
589 |
+
strMode="avg", tenoutH=Dz, tenoutW=W)
|
590 |
+
|
591 |
+
return net_plane[None], net_planeYZ[None], net_planeXZ[None]
|
592 |
+
|
593 |
+
|
594 |
+
|
595 |
+
class LocalPoolPointnet(nn.Module):
|
596 |
+
''' PointNet-based encoder network with ResNet blocks for each point.
|
597 |
+
The number of input points is fixed.
|
598 |
+
|
599 |
+
Args:
|
600 |
+
c_dim (int): dimension of latent code c
|
601 |
+
dim (int): input points dimension
|
602 |
+
hidden_dim (int): hidden dimension of the network
|
603 |
+
scatter_type (str): feature aggregation when doing local pooling
|
604 |
+
unet (bool): whether to use U-Net
|
605 |
+
unet_kwargs (str): U-Net parameters
|
606 |
+
unet3d (bool): whether to use 3D U-Net
|
607 |
+
unet3d_kwargs (str): 3D U-Net parameters
|
608 |
+
plane_resolution (int): defined resolution for plane feature
|
609 |
+
grid_resolution (int): defined resolution for grid feature
|
610 |
+
plane_type (str): feature type, 'xz' - 1-plane, ['xz', 'xy', 'yz'] - 3-plane, ['grid'] - 3D grid volume
|
611 |
+
padding (float): conventional padding parameter of ONet for unit cube, so [-0.5, 0.5] -> [-0.55, 0.55]
|
612 |
+
n_blocks (int): number of ResnetBlockFC blocks
|
613 |
+
'''
|
614 |
+
|
615 |
+
def __init__(self, ch=128, dim=3, hidden_dim=128, scatter_type='max',
|
616 |
+
unet=True, unet_kwargs=None, unet3d=False, unet3d_kwargs=None,
|
617 |
+
hw=None, grid_resolution=None, plane_type='xz', padding=0.1, n_blocks=5):
|
618 |
+
super().__init__()
|
619 |
+
c_dim = ch
|
620 |
+
unet3d = False
|
621 |
+
plane_type = ['xy', 'xz', 'yz']
|
622 |
+
plane_resolution = hw
|
623 |
+
|
624 |
+
self.c_dim = c_dim
|
625 |
+
|
626 |
+
self.fc_pos = nn.Linear(dim, 2*hidden_dim)
|
627 |
+
self.blocks = nn.ModuleList([
|
628 |
+
ResnetBlockFC(2*hidden_dim, hidden_dim) for i in range(n_blocks)
|
629 |
+
])
|
630 |
+
self.fc_c = nn.Linear(hidden_dim, c_dim)
|
631 |
+
|
632 |
+
self.actvn = nn.ReLU()
|
633 |
+
self.hidden_dim = hidden_dim
|
634 |
+
|
635 |
+
if unet:
|
636 |
+
self.unet = UNet(c_dim, in_channels=c_dim, **unet_kwargs)
|
637 |
+
else:
|
638 |
+
self.unet = None
|
639 |
+
|
640 |
+
if unet3d:
|
641 |
+
# self.unet3d = UNet3D(**unet3d_kwargs)
|
642 |
+
raise NotImplementedError()
|
643 |
+
else:
|
644 |
+
self.unet3d = None
|
645 |
+
|
646 |
+
self.reso_plane = plane_resolution
|
647 |
+
self.reso_grid = grid_resolution
|
648 |
+
self.plane_type = plane_type
|
649 |
+
self.padding = padding
|
650 |
+
|
651 |
+
if scatter_type == 'max':
|
652 |
+
self.scatter = scatter_max
|
653 |
+
elif scatter_type == 'mean':
|
654 |
+
self.scatter = scatter_mean
|
655 |
+
else:
|
656 |
+
raise ValueError('incorrect scatter type')
|
657 |
+
|
658 |
+
def generate_plane_features(self, p, c, plane='xz'):
|
659 |
+
# acquire indices of features in plane
|
660 |
+
xy = normalize_coordinate(p.clone(), plane=plane, padding=self.padding) # normalize to the range of (0, 1)
|
661 |
+
index = coordinate2index(xy, self.reso_plane)
|
662 |
+
|
663 |
+
# scatter plane features from points
|
664 |
+
fea_plane = c.new_zeros(p.size(0), self.c_dim, self.reso_plane**2)
|
665 |
+
c = c.permute(0, 2, 1) # B x 512 x T
|
666 |
+
fea_plane = scatter_mean(c, index, out=fea_plane) # B x 512 x reso^2
|
667 |
+
fea_plane = fea_plane.reshape(p.size(0), self.c_dim, self.reso_plane, self.reso_plane) # sparse matrix (B x c_dim x reso x reso)
|
668 |
+
|
669 |
+
# process the plane features with UNet
|
670 |
+
if self.unet is not None:
|
671 |
+
fea_plane = self.unet(fea_plane)
|
672 |
+
|
673 |
+
return fea_plane
|
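The core of generate_plane_features is the scatter_mean rasterization: all points that land in the same plane cell are averaged into that cell. A minimal standalone sketch of just that step (toy sizes; assumes torch_scatter is installed, as this module already relies on it):
from torch_scatter import scatter_mean

B, C, N, reso = 1, 8, 5, 4
c = torch.randn(B, C, N)                      # per-point features
index = torch.randint(0, reso**2, (B, 1, N))  # flat plane cell for every point
plane = c.new_zeros(B, C, reso**2)
plane = scatter_mean(c, index, out=plane)     # average features per cell
plane = plane.reshape(B, C, reso, reso)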
674 |
+
|
675 |
+
def generate_grid_features(self, p, c):
|
676 |
+
p_nor = normalize_3d_coordinate(p.clone(), padding=self.padding)
|
677 |
+
index = coordinate2index(p_nor, self.reso_grid, coord_type='3d')
|
678 |
+
# scatter grid features from points
|
679 |
+
fea_grid = c.new_zeros(p.size(0), self.c_dim, self.reso_grid**3)
|
680 |
+
c = c.permute(0, 2, 1)
|
681 |
+
fea_grid = scatter_mean(c, index, out=fea_grid) # B x C x reso^3
|
682 |
+
fea_grid = fea_grid.reshape(p.size(0), self.c_dim, self.reso_grid, self.reso_grid, self.reso_grid) # sparse matrix (B x c_dim x reso x reso x reso)
|
683 |
+
|
684 |
+
if self.unet3d is not None:
|
685 |
+
fea_grid = self.unet3d(fea_grid)
|
686 |
+
|
687 |
+
return fea_grid
|
688 |
+
|
689 |
+
def pool_local(self, xy, index, c):
|
690 |
+
bs, fea_dim = c.size(0), c.size(2)
|
691 |
+
keys = xy.keys()
|
692 |
+
|
693 |
+
c_out = 0
|
694 |
+
for key in keys:
|
695 |
+
# scatter plane features from points
|
696 |
+
if key == 'grid':
|
697 |
+
fea = self.scatter(c.permute(0, 2, 1), index[key], dim_size=self.reso_grid**3)
|
698 |
+
else:
|
699 |
+
c_permute = c.permute(0, 2, 1)
|
700 |
+
fea = self.scatter(c_permute, index[key], dim_size=self.reso_plane**2)
|
701 |
+
if self.scatter == scatter_max:
|
702 |
+
fea = fea[0]
|
703 |
+
# gather feature back to points
|
704 |
+
fea = fea.gather(dim=2, index=index[key].expand(-1, fea_dim, -1))
|
705 |
+
c_out = c_out + fea
|
706 |
+
return c_out.permute(0, 2, 1)
|
707 |
+
|
708 |
+
|
709 |
+
def forward(self, p_input, img_feats=None):
|
710 |
+
"""
|
711 |
+
Args:
|
712 |
+
p_input (tensor): input points T 3 H W
|
713 |
+
img_feats (tensor): image features T C H W
|
714 |
+
"""
|
715 |
+
T, _, H, W = img_feats.size()
|
716 |
+
p = rearrange(p_input, 't c h w -> (h w) t c', c=3, h=H, w=W)
|
717 |
+
fea_reshp = rearrange(img_feats, 't c h w -> (h w) t c',
|
718 |
+
c=img_feats.shape[1], h=H, w=W)
|
719 |
+
|
720 |
+
# acquire the index for each point
|
721 |
+
coord = {}
|
722 |
+
index = {}
|
723 |
+
if 'xz' in self.plane_type:
|
724 |
+
coord['xz'] = normalize_coordinate(p.clone(), plane='xz', padding=self.padding)
|
725 |
+
index['xz'] = coordinate2index(coord['xz'], self.reso_plane)
|
726 |
+
if 'xy' in self.plane_type:
|
727 |
+
coord['xy'] = normalize_coordinate(p.clone(), plane='xy', padding=self.padding)
|
728 |
+
index['xy'] = coordinate2index(coord['xy'], self.reso_plane)
|
729 |
+
if 'yz' in self.plane_type:
|
730 |
+
coord['yz'] = normalize_coordinate(p.clone(), plane='yz', padding=self.padding)
|
731 |
+
index['yz'] = coordinate2index(coord['yz'], self.reso_plane)
|
732 |
+
if 'grid' in self.plane_type:
|
733 |
+
coord['grid'] = normalize_3d_coordinate(p.clone(), padding=self.padding)
|
734 |
+
index['grid'] = coordinate2index(coord['grid'], self.reso_grid, coord_type='3d')
|
735 |
+
|
736 |
+
net = self.fc_pos(p) + fea_reshp
|
737 |
+
net = self.blocks[0](net)
|
738 |
+
for block in self.blocks[1:]:
|
739 |
+
pooled = self.pool_local(coord, index, net)
|
740 |
+
net = torch.cat([net, pooled], dim=2)
|
741 |
+
net = block(net)
|
742 |
+
|
743 |
+
c = self.fc_c(net)
|
744 |
+
|
745 |
+
fea = {}
|
746 |
+
|
747 |
+
if 'grid' in self.plane_type:
|
748 |
+
fea['grid'] = self.generate_grid_features(p, c)
|
749 |
+
if 'xz' in self.plane_type:
|
750 |
+
fea['xz'] = self.generate_plane_features(p, c, plane='xz')
|
751 |
+
if 'xy' in self.plane_type:
|
752 |
+
fea['xy'] = self.generate_plane_features(p, c, plane='xy')
|
753 |
+
if 'yz' in self.plane_type:
|
754 |
+
fea['yz'] = self.generate_plane_features(p, c, plane='yz')
|
755 |
+
|
756 |
+
ret = torch.stack([fea['xy'], fea['xz'], fea['yz']]).permute((1, 0, 2, 3, 4))
|
757 |
+
return ret
|
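The return value packs the three canonical planes into a single tensor in the fixed order [xy, xz, yz]; a minimal shape check of that last stacking step (dummy tensors, not the real features):
B, C, R = 2, 32, 64
fea = {k: torch.randn(B, C, R, R) for k in ('xy', 'xz', 'yz')}
ret = torch.stack([fea['xy'], fea['xz'], fea['yz']]).permute((1, 0, 2, 3, 4))
print(ret.shape)   # torch.Size([2, 3, 32, 64, 64])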
758 |
+
|
759 |
+
class PatchLocalPoolPointnet(nn.Module):
|
760 |
+
''' PointNet-based encoder network with ResNet blocks.
|
761 |
+
First transform input points to local system based on the given voxel size.
|
762 |
+
Supports a non-fixed number of points, but the scatter indices must be precomputed.
|
763 |
+
|
764 |
+
Args:
|
765 |
+
c_dim (int): dimension of latent code c
|
766 |
+
dim (int): input points dimension
|
767 |
+
hidden_dim (int): hidden dimension of the network
|
768 |
+
scatter_type (str): feature aggregation when doing local pooling
|
769 |
+
unet (bool): whether to use U-Net
|
770 |
+
unet_kwargs (str): U-Net parameters
|
771 |
+
unet3d (bool): whether to use 3D U-Net
|
772 |
+
unet3d_kwargs (str): 3D U-Net parameters
|
773 |
+
plane_resolution (int): defined resolution for plane feature
|
774 |
+
grid_resolution (int): defined resolution for grid feature
|
775 |
+
plane_type (str): feature type, 'xz' - 1-plane, ['xz', 'xy', 'yz'] - 3-plane, ['grid'] - 3D grid volume
|
776 |
+
padding (float): conventional padding parameter of ONet for unit cube, so [-0.5, 0.5] -> [-0.55, 0.55]
|
777 |
+
n_blocks (int): number of ResnetBlockFC blocks
|
778 |
+
local_coord (bool): whether to use local coordinate
|
779 |
+
pos_encoding (str): method for the positional encoding, linear|sin_cos
|
780 |
+
unit_size (float): defined voxel unit size for local system
|
781 |
+
'''
|
782 |
+
|
783 |
+
def __init__(self, c_dim=128, dim=3, hidden_dim=128, scatter_type='max',
|
784 |
+
unet=False, unet_kwargs=None, unet3d=False, unet3d_kwargs=None,
|
785 |
+
plane_resolution=None, grid_resolution=None, plane_type='xz', padding=0.1, n_blocks=5,
|
786 |
+
local_coord=False, pos_encoding='linear', unit_size=0.1):
|
787 |
+
super().__init__()
|
788 |
+
self.c_dim = c_dim
|
789 |
+
|
790 |
+
self.blocks = nn.ModuleList([
|
791 |
+
ResnetBlockFC(2*hidden_dim, hidden_dim) for i in range(n_blocks)
|
792 |
+
])
|
793 |
+
self.fc_c = nn.Linear(hidden_dim, c_dim)
|
794 |
+
|
795 |
+
self.actvn = nn.ReLU()
|
796 |
+
self.hidden_dim = hidden_dim
|
797 |
+
self.reso_plane = plane_resolution
|
798 |
+
self.reso_grid = grid_resolution
|
799 |
+
self.plane_type = plane_type
|
800 |
+
self.padding = padding
|
801 |
+
|
802 |
+
if unet:
|
803 |
+
self.unet = UNet(c_dim, in_channels=c_dim, **unet_kwargs)
|
804 |
+
else:
|
805 |
+
self.unet = None
|
806 |
+
|
807 |
+
if unet3d:
|
808 |
+
# self.unet3d = UNet3D(**unet3d_kwargs)
|
809 |
+
raise NotImplementedError()
|
810 |
+
else:
|
811 |
+
self.unet3d = None
|
812 |
+
|
813 |
+
if scatter_type == 'max':
|
814 |
+
self.scatter = scatter_max
|
815 |
+
elif scatter_type == 'mean':
|
816 |
+
self.scatter = scatter_mean
|
817 |
+
else:
|
818 |
+
raise ValueError('incorrect scatter type')
|
819 |
+
|
820 |
+
if local_coord:
|
821 |
+
self.map2local = map2local(unit_size, pos_encoding=pos_encoding)
|
822 |
+
else:
|
823 |
+
self.map2local = None
|
824 |
+
|
825 |
+
if pos_encoding == 'sin_cos':
|
826 |
+
self.fc_pos = nn.Linear(60, 2*hidden_dim)
|
827 |
+
else:
|
828 |
+
self.fc_pos = nn.Linear(dim, 2*hidden_dim)
|
829 |
+
|
830 |
+
def generate_plane_features(self, index, c):
|
831 |
+
c = c.permute(0, 2, 1)
|
832 |
+
# scatter plane features from points
|
833 |
+
if index.max() < self.reso_plane**2:
|
834 |
+
fea_plane = c.new_zeros(c.size(0), self.c_dim, self.reso_plane**2)
|
835 |
+
fea_plane = scatter_mean(c, index, out=fea_plane) # B x c_dim x reso^2
|
836 |
+
else:
|
837 |
+
fea_plane = scatter_mean(c, index) # B x c_dim x reso^2
|
838 |
+
if fea_plane.shape[-1] > self.reso_plane**2: # deal with outliers
|
839 |
+
fea_plane = fea_plane[:, :, :-1]
|
840 |
+
|
841 |
+
fea_plane = fea_plane.reshape(c.size(0), self.c_dim, self.reso_plane, self.reso_plane)
|
842 |
+
|
843 |
+
# process the plane features with UNet
|
844 |
+
if self.unet is not None:
|
845 |
+
fea_plane = self.unet(fea_plane)
|
846 |
+
|
847 |
+
return fea_plane
|
848 |
+
|
849 |
+
def generate_grid_features(self, index, c):
|
850 |
+
# scatter grid features from points
|
851 |
+
c = c.permute(0, 2, 1)
|
852 |
+
if index.max() < self.reso_grid**3:
|
853 |
+
fea_grid = c.new_zeros(c.size(0), self.c_dim, self.reso_grid**3)
|
854 |
+
fea_grid = scatter_mean(c, index, out=fea_grid) # B x c_dim x reso^3
|
855 |
+
else:
|
856 |
+
fea_grid = scatter_mean(c, index) # B x c_dim x reso^3
|
857 |
+
if fea_grid.shape[-1] > self.reso_grid**3: # deal with outliers
|
858 |
+
fea_grid = fea_grid[:, :, :-1]
|
859 |
+
fea_grid = fea_grid.reshape(c.size(0), self.c_dim, self.reso_grid, self.reso_grid, self.reso_grid)
|
860 |
+
|
861 |
+
if self.unet3d is not None:
|
862 |
+
fea_grid = self.unet3d(fea_grid)
|
863 |
+
|
864 |
+
return fea_grid
|
865 |
+
|
866 |
+
def pool_local(self, index, c):
|
867 |
+
bs, fea_dim = c.size(0), c.size(2)
|
868 |
+
keys = index.keys()
|
869 |
+
|
870 |
+
c_out = 0
|
871 |
+
for key in keys:
|
872 |
+
# scatter plane features from points
|
873 |
+
if key == 'grid':
|
874 |
+
fea = self.scatter(c.permute(0, 2, 1), index[key])
|
875 |
+
else:
|
876 |
+
fea = self.scatter(c.permute(0, 2, 1), index[key])
|
877 |
+
if self.scatter == scatter_max:
|
878 |
+
fea = fea[0]
|
879 |
+
# gather feature back to points
|
880 |
+
fea = fea.gather(dim=2, index=index[key].expand(-1, fea_dim, -1))
|
881 |
+
c_out += fea
|
882 |
+
return c_out.permute(0, 2, 1)
|
883 |
+
|
884 |
+
|
885 |
+
def forward(self, inputs):
|
886 |
+
p = inputs['points']
|
887 |
+
index = inputs['index']
|
888 |
+
|
889 |
+
batch_size, T, D = p.size()
|
890 |
+
|
891 |
+
if self.map2local:
|
892 |
+
pp = self.map2local(p)
|
893 |
+
net = self.fc_pos(pp)
|
894 |
+
else:
|
895 |
+
net = self.fc_pos(p)
|
896 |
+
|
897 |
+
net = self.blocks[0](net)
|
898 |
+
for block in self.blocks[1:]:
|
899 |
+
pooled = self.pool_local(index, net)
|
900 |
+
net = torch.cat([net, pooled], dim=2)
|
901 |
+
net = block(net)
|
902 |
+
|
903 |
+
c = self.fc_c(net)
|
904 |
+
|
905 |
+
fea = {}
|
906 |
+
if 'grid' in self.plane_type:
|
907 |
+
fea['grid'] = self.generate_grid_features(index['grid'], c)
|
908 |
+
if 'xz' in self.plane_type:
|
909 |
+
fea['xz'] = self.generate_plane_features(index['xz'], c)
|
910 |
+
if 'xy' in self.plane_type:
|
911 |
+
fea['xy'] = self.generate_plane_features(index['xy'], c)
|
912 |
+
if 'yz' in self.plane_type:
|
913 |
+
fea['yz'] = self.generate_plane_features(index['yz'], c)
|
914 |
+
|
915 |
+
return fea
|
models/spatracker/models/core/spatracker/loftr/__init__.py
ADDED
@@ -0,0 +1 @@
1 |
+
from .transformer import LocalFeatureTransformer
|
models/spatracker/models/core/spatracker/loftr/linear_attention.py
ADDED
@@ -0,0 +1,81 @@
1 |
+
"""
|
2 |
+
Linear Transformer proposed in "Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention"
|
3 |
+
Modified from: https://github.com/idiap/fast-transformers/blob/master/fast_transformers/attention/linear_attention.py
|
4 |
+
"""
|
5 |
+
|
6 |
+
import torch
|
7 |
+
from torch.nn import Module, Dropout
|
8 |
+
|
9 |
+
|
10 |
+
def elu_feature_map(x):
|
11 |
+
return torch.nn.functional.elu(x) + 1
|
12 |
+
|
13 |
+
|
14 |
+
class LinearAttention(Module):
|
15 |
+
def __init__(self, eps=1e-6):
|
16 |
+
super().__init__()
|
17 |
+
self.feature_map = elu_feature_map
|
18 |
+
self.eps = eps
|
19 |
+
|
20 |
+
def forward(self, queries, keys, values, q_mask=None, kv_mask=None):
|
21 |
+
""" Multi-Head linear attention proposed in "Transformers are RNNs"
|
22 |
+
Args:
|
23 |
+
queries: [N, L, H, D]
|
24 |
+
keys: [N, S, H, D]
|
25 |
+
values: [N, S, H, D]
|
26 |
+
q_mask: [N, L]
|
27 |
+
kv_mask: [N, S]
|
28 |
+
Returns:
|
29 |
+
queried_values: (N, L, H, D)
|
30 |
+
"""
|
31 |
+
Q = self.feature_map(queries)
|
32 |
+
K = self.feature_map(keys)
|
33 |
+
|
34 |
+
# set padded position to zero
|
35 |
+
if q_mask is not None:
|
36 |
+
Q = Q * q_mask[:, :, None, None]
|
37 |
+
if kv_mask is not None:
|
38 |
+
K = K * kv_mask[:, :, None, None]
|
39 |
+
values = values * kv_mask[:, :, None, None]
|
40 |
+
|
41 |
+
v_length = values.size(1)
|
42 |
+
values = values / v_length # prevent fp16 overflow
|
43 |
+
KV = torch.einsum("nshd,nshv->nhdv", K, values) # (S,D)' @ S,V
|
44 |
+
Z = 1 / (torch.einsum("nlhd,nhd->nlh", Q, K.sum(dim=1)) + self.eps)
|
45 |
+
queried_values = torch.einsum("nlhd,nhdv,nlh->nlhv", Q, KV, Z) * v_length
|
46 |
+
|
47 |
+
return queried_values.contiguous()
|
48 |
+
|
49 |
+
|
50 |
+
class FullAttention(Module):
|
51 |
+
def __init__(self, use_dropout=False, attention_dropout=0.1):
|
52 |
+
super().__init__()
|
53 |
+
self.use_dropout = use_dropout
|
54 |
+
self.dropout = Dropout(attention_dropout)
|
55 |
+
|
56 |
+
def forward(self, queries, keys, values, q_mask=None, kv_mask=None):
|
57 |
+
""" Multi-head scaled dot-product attention, a.k.a full attention.
|
58 |
+
Args:
|
59 |
+
queries: [N, L, H, D]
|
60 |
+
keys: [N, S, H, D]
|
61 |
+
values: [N, S, H, D]
|
62 |
+
q_mask: [N, L]
|
63 |
+
kv_mask: [N, S]
|
64 |
+
Returns:
|
65 |
+
queried_values: (N, L, H, D)
|
66 |
+
"""
|
67 |
+
|
68 |
+
# Compute the unnormalized attention and apply the masks
|
69 |
+
QK = torch.einsum("nlhd,nshd->nlsh", queries, keys)
|
70 |
+
if kv_mask is not None:
|
71 |
+
QK.masked_fill_(~(q_mask[:, :, None, None] * kv_mask[:, None, :, None]), float('-inf'))
|
72 |
+
|
73 |
+
# Compute the attention and the weighted average
|
74 |
+
softmax_temp = 1. / queries.size(3)**.5 # sqrt(D)
|
75 |
+
A = torch.softmax(softmax_temp * QK, dim=2)
|
76 |
+
if self.use_dropout:
|
77 |
+
A = self.dropout(A)
|
78 |
+
|
79 |
+
queried_values = torch.einsum("nlsh,nshd->nlhd", A, values)
|
80 |
+
|
81 |
+
return queried_values.contiguous()
|
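LinearAttention never materializes the L x S attention matrix: it contracts keys with values once into an (H, D, D) summary and then applies that summary to every query, so the cost scales linearly with sequence length. A quick usage sketch (random tensors, assumed sizes):
import torch

attn = LinearAttention()
N, L, S, H, D = 2, 1024, 1024, 8, 32
q = torch.randn(N, L, H, D)
k = torch.randn(N, S, H, D)
v = torch.randn(N, S, H, D)
out = attn(q, k, v)    # no L x S attention map is ever built
print(out.shape)       # torch.Size([2, 1024, 8, 32])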
models/spatracker/models/core/spatracker/loftr/transformer.py
ADDED
@@ -0,0 +1,142 @@
1 |
+
'''
|
2 |
+
modified from
|
3 |
+
https://github.com/zju3dv/LoFTR/blob/master/src/loftr/loftr_module/transformer.py
|
4 |
+
'''
|
5 |
+
import torch
|
6 |
+
from torch.nn import Module, Dropout
|
7 |
+
import copy
|
8 |
+
import torch.nn as nn
|
9 |
+
import torch.nn.functional as F
|
10 |
+
|
11 |
+
|
12 |
+
def elu_feature_map(x):
|
13 |
+
return torch.nn.functional.elu(x) + 1
|
14 |
+
|
15 |
+
class FullAttention(Module):
|
16 |
+
def __init__(self, use_dropout=False, attention_dropout=0.1):
|
17 |
+
super().__init__()
|
18 |
+
self.use_dropout = use_dropout
|
19 |
+
self.dropout = Dropout(attention_dropout)
|
20 |
+
|
21 |
+
def forward(self, queries, keys, values, q_mask=None, kv_mask=None):
|
22 |
+
""" Multi-head scaled dot-product attention, a.k.a full attention.
|
23 |
+
Args:
|
24 |
+
queries: [N, L, H, D]
|
25 |
+
keys: [N, S, H, D]
|
26 |
+
values: [N, S, H, D]
|
27 |
+
q_mask: [N, L]
|
28 |
+
kv_mask: [N, S]
|
29 |
+
Returns:
|
30 |
+
queried_values: (N, L, H, D)
|
31 |
+
"""
|
32 |
+
|
33 |
+
# Compute the unnormalized attention and apply the masks
|
34 |
+
# QK = torch.einsum("nlhd,nshd->nlsh", queries, keys)
|
35 |
+
# if kv_mask is not None:
|
36 |
+
# QK.masked_fill_(~(q_mask[:, :, None, None] * kv_mask[:, None, :, None]), float(-1e12))
|
37 |
+
# softmax_temp = 1. / queries.size(3)**.5 # sqrt(D)
|
38 |
+
# A = torch.softmax(softmax_temp * QK, dim=2)
|
39 |
+
# if self.use_dropout:
|
40 |
+
# A = self.dropout(A)
|
41 |
+
# queried_values_ = torch.einsum("nlsh,nshd->nlhd", A, values)
|
42 |
+
|
43 |
+
# Compute the attention and the weighted average
|
44 |
+
input_args = [x.half().contiguous() for x in [queries.permute(0,2,1,3), keys.permute(0,2,1,3), values.permute(0,2,1,3)]]
|
45 |
+
queried_values = F.scaled_dot_product_attention(*input_args).permute(0,2,1,3).float() # type: ignore
|
46 |
+
|
47 |
+
|
48 |
+
return queried_values.contiguous()
|
49 |
+
|
50 |
+
class TransformerEncoderLayer(nn.Module):
|
51 |
+
def __init__(self,
|
52 |
+
d_model,
|
53 |
+
nhead,):
|
54 |
+
super(TransformerEncoderLayer, self).__init__()
|
55 |
+
|
56 |
+
self.dim = d_model // nhead
|
57 |
+
self.nhead = nhead
|
58 |
+
|
59 |
+
# multi-head attention
|
60 |
+
self.q_proj = nn.Linear(d_model, d_model, bias=False)
|
61 |
+
self.k_proj = nn.Linear(d_model, d_model, bias=False)
|
62 |
+
self.v_proj = nn.Linear(d_model, d_model, bias=False)
|
63 |
+
self.attention = FullAttention()
|
64 |
+
self.merge = nn.Linear(d_model, d_model, bias=False)
|
65 |
+
|
66 |
+
# feed-forward network
|
67 |
+
self.mlp = nn.Sequential(
|
68 |
+
nn.Linear(d_model*2, d_model*2, bias=False),
|
69 |
+
nn.ReLU(True),
|
70 |
+
nn.Linear(d_model*2, d_model, bias=False),
|
71 |
+
)
|
72 |
+
|
73 |
+
# norm and dropout
|
74 |
+
self.norm1 = nn.LayerNorm(d_model)
|
75 |
+
self.norm2 = nn.LayerNorm(d_model)
|
76 |
+
|
77 |
+
def forward(self, x, source, x_mask=None, source_mask=None):
|
78 |
+
"""
|
79 |
+
Args:
|
80 |
+
x (torch.Tensor): [N, L, C]
|
81 |
+
source (torch.Tensor): [N, S, C]
|
82 |
+
x_mask (torch.Tensor): [N, L] (optional)
|
83 |
+
source_mask (torch.Tensor): [N, S] (optional)
|
84 |
+
"""
|
85 |
+
bs = x.size(0)
|
86 |
+
query, key, value = x, source, source
|
87 |
+
|
88 |
+
# multi-head attention
|
89 |
+
query = self.q_proj(query).view(bs, -1, self.nhead, self.dim) # [N, L, (H, D)]
|
90 |
+
key = self.k_proj(key).view(bs, -1, self.nhead, self.dim) # [N, S, (H, D)]
|
91 |
+
value = self.v_proj(value).view(bs, -1, self.nhead, self.dim)
|
92 |
+
message = self.attention(query, key, value, q_mask=x_mask, kv_mask=source_mask) # [N, L, (H, D)]
|
93 |
+
message = self.merge(message.view(bs, -1, self.nhead*self.dim)) # [N, L, C]
|
94 |
+
message = self.norm1(message)
|
95 |
+
|
96 |
+
# feed-forward network
|
97 |
+
message = self.mlp(torch.cat([x, message], dim=2))
|
98 |
+
message = self.norm2(message)
|
99 |
+
|
100 |
+
return x + message
|
101 |
+
|
102 |
+
class LocalFeatureTransformer(nn.Module):
|
103 |
+
"""A Local Feature Transformer module."""
|
104 |
+
|
105 |
+
def __init__(self, config):
|
106 |
+
super(LocalFeatureTransformer, self).__init__()
|
107 |
+
|
108 |
+
self.config = config
|
109 |
+
self.d_model = config['d_model']
|
110 |
+
self.nhead = config['nhead']
|
111 |
+
self.layer_names = config['layer_names']
|
112 |
+
encoder_layer = TransformerEncoderLayer(config['d_model'], config['nhead'])
|
113 |
+
self.layers = nn.ModuleList([copy.deepcopy(encoder_layer) for _ in range(len(self.layer_names))])
|
114 |
+
self._reset_parameters()
|
115 |
+
|
116 |
+
def _reset_parameters(self):
|
117 |
+
for p in self.parameters():
|
118 |
+
if p.dim() > 1:
|
119 |
+
nn.init.xavier_uniform_(p)
|
120 |
+
|
121 |
+
def forward(self, feat0, feat1, mask0=None, mask1=None):
|
122 |
+
"""
|
123 |
+
Args:
|
124 |
+
feat0 (torch.Tensor): [N, L, C]
|
125 |
+
feat1 (torch.Tensor): [N, S, C]
|
126 |
+
mask0 (torch.Tensor): [N, L] (optional)
|
127 |
+
mask1 (torch.Tensor): [N, S] (optional)
|
128 |
+
"""
|
129 |
+
|
130 |
+
assert self.d_model == feat0.size(2), "the feature number of src and transformer must be equal"
|
131 |
+
|
132 |
+
for layer, name in zip(self.layers, self.layer_names):
|
133 |
+
if name == 'self':
|
134 |
+
feat0 = layer(feat0, feat0, mask0, mask0)
|
135 |
+
feat1 = layer(feat1, feat1, mask1, mask1)
|
136 |
+
elif name == 'cross':
|
137 |
+
feat0 = layer(feat0, feat1, mask0, mask1)
|
138 |
+
feat1 = layer(feat1, feat0, mask1, mask0)
|
139 |
+
else:
|
140 |
+
raise KeyError
|
141 |
+
|
142 |
+
return feat0, feat1
|
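LocalFeatureTransformer is driven entirely by a small config dict (d_model, nhead, layer_names); the layer_names list fixes the interleaving of self- and cross-attention. A usage sketch (assumed sizes and an assumed layer schedule; run on GPU, since FullAttention above casts to half precision for scaled_dot_product_attention):
config = {'d_model': 128, 'nhead': 8, 'layer_names': ['self', 'cross'] * 2}
loftr = LocalFeatureTransformer(config).cuda()
feat0 = torch.randn(1, 4096, 128, device='cuda')   # flattened tokens of view 0
feat1 = torch.randn(1, 4096, 128, device='cuda')   # flattened tokens of view 1
feat0, feat1 = loftr(feat0, feat1)                 # same shapes, now cross-conditioned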
models/spatracker/models/core/spatracker/losses.py
ADDED
@@ -0,0 +1,90 @@
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import torch.nn.functional as F
|
9 |
+
from models.spatracker.models.core.model_utils import reduce_masked_mean
|
10 |
+
from models.spatracker.models.core.spatracker.blocks import (
|
11 |
+
pix2cam
|
12 |
+
)
|
13 |
+
from models.spatracker.models.core.model_utils import (
|
14 |
+
bilinear_sample2d
|
15 |
+
)
|
16 |
+
|
17 |
+
EPS = 1e-6
|
18 |
+
import torchvision.transforms.functional as TF
|
19 |
+
|
20 |
+
sigma = 3
|
21 |
+
x_grid = torch.arange(-7,8,1)
|
22 |
+
y_grid = torch.arange(-7,8,1)
|
23 |
+
x_grid, y_grid = torch.meshgrid(x_grid, y_grid)
|
24 |
+
gridxy = torch.stack([x_grid, y_grid], dim=-1).float()
|
25 |
+
gs_kernel = torch.exp(-torch.sum(gridxy**2, dim=-1)/(2*sigma**2))
|
26 |
+
|
27 |
+
|
28 |
+
def balanced_ce_loss(pred, gt, valid=None):
|
29 |
+
total_balanced_loss = 0.0
|
30 |
+
for j in range(len(gt)):
|
31 |
+
B, S, N = gt[j].shape
|
32 |
+
# pred and gt are the same shape
|
33 |
+
for (a, b) in zip(pred[j].size(), gt[j].size()):
|
34 |
+
assert a == b # some shape mismatch!
|
35 |
+
# if valid is not None:
|
36 |
+
for (a, b) in zip(pred[j].size(), valid[j].size()):
|
37 |
+
assert a == b # some shape mismatch!
|
38 |
+
|
39 |
+
pos = (gt[j] > 0.95).float()
|
40 |
+
neg = (gt[j] < 0.05).float()
|
41 |
+
|
42 |
+
label = pos * 2.0 - 1.0
|
43 |
+
a = -label * pred[j]
|
44 |
+
b = F.relu(a)
|
45 |
+
loss = b + torch.log(torch.exp(-b) + torch.exp(a - b))
|
46 |
+
|
47 |
+
pos_loss = reduce_masked_mean(loss, pos * valid[j])
|
48 |
+
neg_loss = reduce_masked_mean(loss, neg * valid[j])
|
49 |
+
balanced_loss = pos_loss + neg_loss
|
50 |
+
total_balanced_loss += balanced_loss / float(N)
|
51 |
+
|
52 |
+
return total_balanced_loss
|
53 |
+
|
54 |
+
|
55 |
+
def sequence_loss(flow_preds, flow_gt, vis, valids, gamma=0.8,
|
56 |
+
intr=None, trajs_g_all=None):
|
57 |
+
"""Loss function defined over sequence of flow predictions"""
|
58 |
+
total_flow_loss = 0.0
|
59 |
+
|
60 |
+
for j in range(len(flow_gt)):
|
61 |
+
B, S, N, D = flow_gt[j].shape
|
62 |
+
# assert D == 3
|
63 |
+
B, S1, N = vis[j].shape
|
64 |
+
B, S2, N = valids[j].shape
|
65 |
+
assert S == S1
|
66 |
+
assert S == S2
|
67 |
+
n_predictions = len(flow_preds[j])
|
68 |
+
if intr is not None:
|
69 |
+
intr_i = intr[j]
|
70 |
+
flow_loss = 0.0
|
71 |
+
for i in range(n_predictions):
|
72 |
+
i_weight = gamma ** (n_predictions - i - 1)
|
73 |
+
flow_pred = flow_preds[j][i][..., -N:, :D]
|
74 |
+
flow_gt_j = flow_gt[j].clone()
|
75 |
+
if intr is not None:
|
76 |
+
xyz_j_gt = pix2cam(flow_gt_j, intr_i)
|
77 |
+
i_loss = (flow_pred - flow_gt_j).abs() # B, S, N, 3
|
81 |
+
if D==3:
|
82 |
+
i_loss[...,2]*=30
|
83 |
+
i_loss = torch.mean(i_loss, dim=3) # B, S, N
|
84 |
+
flow_loss += i_weight * (reduce_masked_mean(i_loss, valids[j]))
|
85 |
+
|
86 |
+
flow_loss = flow_loss / n_predictions
|
87 |
+
total_flow_loss += flow_loss / float(N)
|
88 |
+
|
89 |
+
|
90 |
+
return total_flow_loss
|
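sequence_loss uses the usual exponentially decayed schedule over refinement iterations: prediction i out of n gets weight gamma ** (n - i - 1), so later (more refined) iterates dominate. For example (gamma = 0.8, n = 4 iterations; a worked example, not from the original file):
gamma, n = 0.8, 4
weights = [gamma ** (n - i - 1) for i in range(n)]
print(weights)   # [0.512..., 0.64..., 0.8, 1.0] -> the final iterate counts most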
models/spatracker/models/core/spatracker/softsplat.py
ADDED
@@ -0,0 +1,539 @@
1 |
+
#!/usr/bin/env python
|
2 |
+
|
3 |
+
"""The code of softsplat function is modified from:
|
4 |
+
https://github.com/sniklaus/softmax-splatting/blob/master/softsplat.py
|
5 |
+
|
6 |
+
"""
|
7 |
+
|
8 |
+
|
9 |
+
import collections
|
10 |
+
import cupy
|
11 |
+
import os
|
12 |
+
import re
|
13 |
+
import torch
|
14 |
+
import typing
|
15 |
+
|
16 |
+
|
17 |
+
##########################################################
|
18 |
+
|
19 |
+
|
20 |
+
objCudacache = {}
|
21 |
+
|
22 |
+
|
23 |
+
def cuda_int32(intIn:int):
|
24 |
+
return cupy.int32(intIn)
|
25 |
+
# end
|
26 |
+
|
27 |
+
|
28 |
+
def cuda_float32(fltIn:float):
|
29 |
+
return cupy.float32(fltIn)
|
30 |
+
# end
|
31 |
+
|
32 |
+
|
33 |
+
def cuda_kernel(strFunction:str, strKernel:str, objVariables:typing.Dict):
|
34 |
+
if 'device' not in objCudacache:
|
35 |
+
objCudacache['device'] = torch.cuda.get_device_name()
|
36 |
+
# end
|
37 |
+
|
38 |
+
strKey = strFunction
|
39 |
+
|
40 |
+
for strVariable in objVariables:
|
41 |
+
objValue = objVariables[strVariable]
|
42 |
+
|
43 |
+
strKey += strVariable
|
44 |
+
|
45 |
+
if objValue is None:
|
46 |
+
continue
|
47 |
+
|
48 |
+
elif type(objValue) == int:
|
49 |
+
strKey += str(objValue)
|
50 |
+
|
51 |
+
elif type(objValue) == float:
|
52 |
+
strKey += str(objValue)
|
53 |
+
|
54 |
+
elif type(objValue) == bool:
|
55 |
+
strKey += str(objValue)
|
56 |
+
|
57 |
+
elif type(objValue) == str:
|
58 |
+
strKey += objValue
|
59 |
+
|
60 |
+
elif type(objValue) == torch.Tensor:
|
61 |
+
strKey += str(objValue.dtype)
|
62 |
+
strKey += str(objValue.shape)
|
63 |
+
strKey += str(objValue.stride())
|
64 |
+
|
65 |
+
elif True:
|
66 |
+
print(strVariable, type(objValue))
|
67 |
+
assert(False)
|
68 |
+
|
69 |
+
# end
|
70 |
+
# end
|
71 |
+
|
72 |
+
strKey += objCudacache['device']
|
73 |
+
|
74 |
+
if strKey not in objCudacache:
|
75 |
+
for strVariable in objVariables:
|
76 |
+
objValue = objVariables[strVariable]
|
77 |
+
|
78 |
+
if objValue is None:
|
79 |
+
continue
|
80 |
+
|
81 |
+
elif type(objValue) == int:
|
82 |
+
strKernel = strKernel.replace('{{' + strVariable + '}}', str(objValue))
|
83 |
+
|
84 |
+
elif type(objValue) == float:
|
85 |
+
strKernel = strKernel.replace('{{' + strVariable + '}}', str(objValue))
|
86 |
+
|
87 |
+
elif type(objValue) == bool:
|
88 |
+
strKernel = strKernel.replace('{{' + strVariable + '}}', str(objValue))
|
89 |
+
|
90 |
+
elif type(objValue) == str:
|
91 |
+
strKernel = strKernel.replace('{{' + strVariable + '}}', objValue)
|
92 |
+
|
93 |
+
elif type(objValue) == torch.Tensor and objValue.dtype == torch.uint8:
|
94 |
+
strKernel = strKernel.replace('{{type}}', 'unsigned char')
|
95 |
+
|
96 |
+
elif type(objValue) == torch.Tensor and objValue.dtype == torch.float16:
|
97 |
+
strKernel = strKernel.replace('{{type}}', 'half')
|
98 |
+
|
99 |
+
elif type(objValue) == torch.Tensor and objValue.dtype == torch.float32:
|
100 |
+
strKernel = strKernel.replace('{{type}}', 'float')
|
101 |
+
|
102 |
+
elif type(objValue) == torch.Tensor and objValue.dtype == torch.float64:
|
103 |
+
strKernel = strKernel.replace('{{type}}', 'double')
|
104 |
+
|
105 |
+
elif type(objValue) == torch.Tensor and objValue.dtype == torch.int32:
|
106 |
+
strKernel = strKernel.replace('{{type}}', 'int')
|
107 |
+
|
108 |
+
elif type(objValue) == torch.Tensor and objValue.dtype == torch.int64:
|
109 |
+
strKernel = strKernel.replace('{{type}}', 'long')
|
110 |
+
|
111 |
+
elif type(objValue) == torch.Tensor:
|
112 |
+
print(strVariable, objValue.dtype)
|
113 |
+
assert(False)
|
114 |
+
|
115 |
+
elif True:
|
116 |
+
print(strVariable, type(objValue))
|
117 |
+
assert(False)
|
118 |
+
|
119 |
+
# end
|
120 |
+
# end
|
121 |
+
|
122 |
+
while True:
|
123 |
+
objMatch = re.search('(SIZE_)([0-4])(\()([^\)]*)(\))', strKernel)
|
124 |
+
|
125 |
+
if objMatch is None:
|
126 |
+
break
|
127 |
+
# end
|
128 |
+
|
129 |
+
intArg = int(objMatch.group(2))
|
130 |
+
|
131 |
+
strTensor = objMatch.group(4)
|
132 |
+
intSizes = objVariables[strTensor].size()
|
133 |
+
|
134 |
+
strKernel = strKernel.replace(objMatch.group(), str(intSizes[intArg] if torch.is_tensor(intSizes[intArg]) == False else intSizes[intArg].item()))
|
135 |
+
# end
|
136 |
+
|
137 |
+
while True:
|
138 |
+
objMatch = re.search('(OFFSET_)([0-4])(\()', strKernel)
|
139 |
+
|
140 |
+
if objMatch is None:
|
141 |
+
break
|
142 |
+
# end
|
143 |
+
|
144 |
+
intStart = objMatch.span()[1]
|
145 |
+
intStop = objMatch.span()[1]
|
146 |
+
intParentheses = 1
|
147 |
+
|
148 |
+
while True:
|
149 |
+
intParentheses += 1 if strKernel[intStop] == '(' else 0
|
150 |
+
intParentheses -= 1 if strKernel[intStop] == ')' else 0
|
151 |
+
|
152 |
+
if intParentheses == 0:
|
153 |
+
break
|
154 |
+
# end
|
155 |
+
|
156 |
+
intStop += 1
|
157 |
+
# end
|
158 |
+
|
159 |
+
intArgs = int(objMatch.group(2))
|
160 |
+
strArgs = strKernel[intStart:intStop].split(',')
|
161 |
+
|
162 |
+
assert(intArgs == len(strArgs) - 1)
|
163 |
+
|
164 |
+
strTensor = strArgs[0]
|
165 |
+
intStrides = objVariables[strTensor].stride()
|
166 |
+
|
167 |
+
strIndex = []
|
168 |
+
|
169 |
+
for intArg in range(intArgs):
|
170 |
+
strIndex.append('((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + ')*' + str(intStrides[intArg] if torch.is_tensor(intStrides[intArg]) == False else intStrides[intArg].item()) + ')')
|
171 |
+
# end
|
172 |
+
|
173 |
+
strKernel = strKernel.replace('OFFSET_' + str(intArgs) + '(' + strKernel[intStart:intStop] + ')', '(' + str.join('+', strIndex) + ')')
|
174 |
+
# end
|
175 |
+
|
176 |
+
while True:
|
177 |
+
objMatch = re.search('(VALUE_)([0-4])(\()', strKernel)
|
178 |
+
|
179 |
+
if objMatch is None:
|
180 |
+
break
|
181 |
+
# end
|
182 |
+
|
183 |
+
intStart = objMatch.span()[1]
|
184 |
+
intStop = objMatch.span()[1]
|
185 |
+
intParentheses = 1
|
186 |
+
|
187 |
+
while True:
|
188 |
+
intParentheses += 1 if strKernel[intStop] == '(' else 0
|
189 |
+
intParentheses -= 1 if strKernel[intStop] == ')' else 0
|
190 |
+
|
191 |
+
if intParentheses == 0:
|
192 |
+
break
|
193 |
+
# end
|
194 |
+
|
195 |
+
intStop += 1
|
196 |
+
# end
|
197 |
+
|
198 |
+
intArgs = int(objMatch.group(2))
|
199 |
+
strArgs = strKernel[intStart:intStop].split(',')
|
200 |
+
|
201 |
+
assert(intArgs == len(strArgs) - 1)
|
202 |
+
|
203 |
+
strTensor = strArgs[0]
|
204 |
+
intStrides = objVariables[strTensor].stride()
|
205 |
+
|
206 |
+
strIndex = []
|
207 |
+
|
208 |
+
for intArg in range(intArgs):
|
209 |
+
strIndex.append('((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + ')*' + str(intStrides[intArg] if torch.is_tensor(intStrides[intArg]) == False else intStrides[intArg].item()) + ')')
|
210 |
+
# end
|
211 |
+
|
212 |
+
strKernel = strKernel.replace('VALUE_' + str(intArgs) + '(' + strKernel[intStart:intStop] + ')', strTensor + '[' + str.join('+', strIndex) + ']')
|
213 |
+
# end
|
214 |
+
|
215 |
+
objCudacache[strKey] = {
|
216 |
+
'strFunction': strFunction,
|
217 |
+
'strKernel': strKernel
|
218 |
+
}
|
219 |
+
# end
|
220 |
+
|
221 |
+
return strKey
|
222 |
+
# end
|
223 |
+
|
224 |
+
|
225 |
+
@cupy.memoize(for_each_device=True)
|
226 |
+
def cuda_launch(strKey:str):
|
227 |
+
if 'CUDA_HOME' not in os.environ:
|
228 |
+
os.environ['CUDA_HOME'] = cupy.cuda.get_cuda_path()
|
229 |
+
# end
|
230 |
+
|
231 |
+
return cupy.RawKernel(objCudacache[strKey]['strKernel'], objCudacache[strKey]['strFunction'])
|
232 |
+
# end
|
233 |
+
|
234 |
+
|
235 |
+
##########################################################
|
236 |
+
|
237 |
+
|
238 |
+
def softsplat(tenIn:torch.Tensor, tenFlow:torch.Tensor,
|
239 |
+
tenMetric:torch.Tensor, strMode:str, tenoutH=None, tenoutW=None):
|
240 |
+
assert(strMode.split('-')[0] in ['sum', 'avg', 'linear', 'soft'])
|
241 |
+
|
242 |
+
if strMode == 'sum': assert(tenMetric is None)
|
243 |
+
if strMode == 'avg': assert(tenMetric is None)
|
244 |
+
if strMode.split('-')[0] == 'linear': assert(tenMetric is not None)
|
245 |
+
if strMode.split('-')[0] == 'soft': assert(tenMetric is not None)
|
246 |
+
|
247 |
+
if strMode == 'avg':
|
248 |
+
tenIn = torch.cat([tenIn, tenIn.new_ones([tenIn.shape[0], 1, tenIn.shape[2], tenIn.shape[3]])], 1)
|
249 |
+
|
250 |
+
elif strMode.split('-')[0] == 'linear':
|
251 |
+
tenIn = torch.cat([tenIn * tenMetric, tenMetric], 1)
|
252 |
+
|
253 |
+
elif strMode.split('-')[0] == 'soft':
|
254 |
+
tenIn = torch.cat([tenIn * tenMetric.exp(), tenMetric.exp()], 1)
|
255 |
+
|
256 |
+
# end
|
257 |
+
|
258 |
+
tenOut = softsplat_func.apply(tenIn, tenFlow, tenoutH, tenoutW)
|
259 |
+
|
260 |
+
if strMode.split('-')[0] in ['avg', 'linear', 'soft']:
|
261 |
+
tenNormalize = tenOut[:, -1:, :, :]
|
262 |
+
|
263 |
+
if len(strMode.split('-')) == 1:
|
264 |
+
tenNormalize = tenNormalize + 0.0000001
|
265 |
+
|
266 |
+
elif strMode.split('-')[1] == 'addeps':
|
267 |
+
tenNormalize = tenNormalize + 0.0000001
|
268 |
+
|
269 |
+
elif strMode.split('-')[1] == 'zeroeps':
|
270 |
+
tenNormalize[tenNormalize == 0.0] = 1.0
|
271 |
+
|
272 |
+
elif strMode.split('-')[1] == 'clipeps':
|
273 |
+
tenNormalize = tenNormalize.clip(0.0000001, None)
|
274 |
+
|
275 |
+
# end
|
276 |
+
tenOut = tenOut[:, :-1, :, :] / tenNormalize
|
277 |
+
# end
|
278 |
+
|
279 |
+
return tenOut
|
280 |
+
# end
|
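softsplat forward-warps features along a flow field; 'avg' normalizes by splatted counts, while 'soft' treats tenMetric as a per-pixel log-weight (it is exponentiated before splatting), so overlapping sources resolve toward the pixel with the larger metric. A usage sketch (CUDA only, since the kernels are JIT-compiled with CuPy; toy sizes, zero flow):
B, C, H, W = 1, 64, 32, 32
feat   = torch.randn(B, C, H, W, device='cuda')
flow   = torch.zeros(B, 2, H, W, device='cuda')    # forward flow in pixels
metric = torch.randn(B, 1, H, W, device='cuda')    # e.g. a learned importance score

warped_avg  = softsplat(feat, flow, None,   strMode='avg')
warped_soft = softsplat(feat, flow, metric, strMode='soft')
print(warped_avg.shape, warped_soft.shape)         # both torch.Size([1, 64, 32, 32])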
281 |
+
|
282 |
+
|
283 |
+
class softsplat_func(torch.autograd.Function):
|
284 |
+
@staticmethod
|
285 |
+
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
|
286 |
+
def forward(self, tenIn, tenFlow, H=None, W=None):
|
287 |
+
if H is None:
|
288 |
+
tenOut = tenIn.new_zeros([tenIn.shape[0], tenIn.shape[1], tenIn.shape[2], tenIn.shape[3]])
|
289 |
+
else:
|
290 |
+
tenOut = tenIn.new_zeros([tenIn.shape[0], tenIn.shape[1], H, W])
|
291 |
+
|
292 |
+
if tenIn.is_cuda == True:
|
293 |
+
cuda_launch(cuda_kernel('softsplat_out', '''
|
294 |
+
extern "C" __global__ void __launch_bounds__(512) softsplat_out(
|
295 |
+
const int n,
|
296 |
+
const {{type}}* __restrict__ tenIn,
|
297 |
+
const {{type}}* __restrict__ tenFlow,
|
298 |
+
{{type}}* __restrict__ tenOut
|
299 |
+
) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
|
300 |
+
const int intN = ( intIndex / SIZE_3(tenIn) / SIZE_2(tenIn) / SIZE_1(tenIn) ) % SIZE_0(tenIn);
|
301 |
+
const int intC = ( intIndex / SIZE_3(tenIn) / SIZE_2(tenIn) ) % SIZE_1(tenIn);
|
302 |
+
const int intY = ( intIndex / SIZE_3(tenIn) ) % SIZE_2(tenIn);
|
303 |
+
const int intX = ( intIndex ) % SIZE_3(tenIn);
|
304 |
+
|
305 |
+
assert(SIZE_1(tenFlow) == 2);
|
306 |
+
|
307 |
+
{{type}} fltX = ({{type}}) (intX) + VALUE_4(tenFlow, intN, 0, intY, intX);
|
308 |
+
{{type}} fltY = ({{type}}) (intY) + VALUE_4(tenFlow, intN, 1, intY, intX);
|
309 |
+
|
310 |
+
if (isfinite(fltX) == false) { return; }
|
311 |
+
if (isfinite(fltY) == false) { return; }
|
312 |
+
|
313 |
+
{{type}} fltIn = VALUE_4(tenIn, intN, intC, intY, intX);
|
314 |
+
|
315 |
+
int intNorthwestX = (int) (floor(fltX));
|
316 |
+
int intNorthwestY = (int) (floor(fltY));
|
317 |
+
int intNortheastX = intNorthwestX + 1;
|
318 |
+
int intNortheastY = intNorthwestY;
|
319 |
+
int intSouthwestX = intNorthwestX;
|
320 |
+
int intSouthwestY = intNorthwestY + 1;
|
321 |
+
int intSoutheastX = intNorthwestX + 1;
|
322 |
+
int intSoutheastY = intNorthwestY + 1;
|
323 |
+
|
324 |
+
{{type}} fltNorthwest = (({{type}}) (intSoutheastX) - fltX) * (({{type}}) (intSoutheastY) - fltY);
|
325 |
+
{{type}} fltNortheast = (fltX - ({{type}}) (intSouthwestX)) * (({{type}}) (intSouthwestY) - fltY);
|
326 |
+
{{type}} fltSouthwest = (({{type}}) (intNortheastX) - fltX) * (fltY - ({{type}}) (intNortheastY));
|
327 |
+
{{type}} fltSoutheast = (fltX - ({{type}}) (intNorthwestX)) * (fltY - ({{type}}) (intNorthwestY));
|
328 |
+
|
329 |
+
if ((intNorthwestX >= 0) && (intNorthwestX < SIZE_3(tenOut)) && (intNorthwestY >= 0) && (intNorthwestY < SIZE_2(tenOut))) {
|
330 |
+
atomicAdd(&tenOut[OFFSET_4(tenOut, intN, intC, intNorthwestY, intNorthwestX)], fltIn * fltNorthwest);
|
331 |
+
}
|
332 |
+
|
333 |
+
if ((intNortheastX >= 0) && (intNortheastX < SIZE_3(tenOut)) && (intNortheastY >= 0) && (intNortheastY < SIZE_2(tenOut))) {
|
334 |
+
atomicAdd(&tenOut[OFFSET_4(tenOut, intN, intC, intNortheastY, intNortheastX)], fltIn * fltNortheast);
|
335 |
+
}
|
336 |
+
|
337 |
+
if ((intSouthwestX >= 0) && (intSouthwestX < SIZE_3(tenOut)) && (intSouthwestY >= 0) && (intSouthwestY < SIZE_2(tenOut))) {
|
338 |
+
atomicAdd(&tenOut[OFFSET_4(tenOut, intN, intC, intSouthwestY, intSouthwestX)], fltIn * fltSouthwest);
|
339 |
+
}
|
340 |
+
|
341 |
+
if ((intSoutheastX >= 0) && (intSoutheastX < SIZE_3(tenOut)) && (intSoutheastY >= 0) && (intSoutheastY < SIZE_2(tenOut))) {
|
342 |
+
atomicAdd(&tenOut[OFFSET_4(tenOut, intN, intC, intSoutheastY, intSoutheastX)], fltIn * fltSoutheast);
|
343 |
+
}
|
344 |
+
} }
|
345 |
+
''', {
|
346 |
+
'tenIn': tenIn,
|
347 |
+
'tenFlow': tenFlow,
|
348 |
+
'tenOut': tenOut
|
349 |
+
}))(
|
350 |
+
grid=tuple([int((tenIn.nelement() + 512 - 1) / 512), 1, 1]),
|
351 |
+
block=tuple([512, 1, 1]),
|
352 |
+
args=[cuda_int32(tenOut.nelement()), tenIn.data_ptr(), tenFlow.data_ptr(), tenOut.data_ptr()],
|
353 |
+
stream=collections.namedtuple('Stream', 'ptr')(torch.cuda.current_stream().cuda_stream)
|
354 |
+
)
|
355 |
+
|
356 |
+
elif tenIn.is_cuda != True:
|
357 |
+
assert(False)
|
358 |
+
|
359 |
+
# end
|
360 |
+
|
361 |
+
self.save_for_backward(tenIn, tenFlow)
|
362 |
+
|
363 |
+
return tenOut
|
364 |
+
# end
|
365 |
+
|
366 |
+
@staticmethod
|
367 |
+
@torch.cuda.amp.custom_bwd
|
368 |
+
def backward(self, tenOutgrad):
|
369 |
+
tenIn, tenFlow = self.saved_tensors
|
370 |
+
|
371 |
+
tenOutgrad = tenOutgrad.contiguous(); assert(tenOutgrad.is_cuda == True)
|
372 |
+
|
373 |
+
tenIngrad = tenIn.new_zeros([tenIn.shape[0], tenIn.shape[1], tenIn.shape[2], tenIn.shape[3]]) if self.needs_input_grad[0] == True else None
|
374 |
+
tenFlowgrad = tenFlow.new_zeros([tenFlow.shape[0], tenFlow.shape[1], tenFlow.shape[2], tenFlow.shape[3]]) if self.needs_input_grad[1] == True else None
|
375 |
+
Hgrad = None
|
376 |
+
Wgrad = None
|
377 |
+
|
378 |
+
if tenIngrad is not None:
|
379 |
+
cuda_launch(cuda_kernel('softsplat_ingrad', '''
|
380 |
+
extern "C" __global__ void __launch_bounds__(512) softsplat_ingrad(
|
381 |
+
const int n,
|
382 |
+
const {{type}}* __restrict__ tenIn,
|
383 |
+
const {{type}}* __restrict__ tenFlow,
|
384 |
+
const {{type}}* __restrict__ tenOutgrad,
|
385 |
+
{{type}}* __restrict__ tenIngrad,
|
386 |
+
{{type}}* __restrict__ tenFlowgrad
|
387 |
+
) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
|
388 |
+
const int intN = ( intIndex / SIZE_3(tenIngrad) / SIZE_2(tenIngrad) / SIZE_1(tenIngrad) ) % SIZE_0(tenIngrad);
|
389 |
+
const int intC = ( intIndex / SIZE_3(tenIngrad) / SIZE_2(tenIngrad) ) % SIZE_1(tenIngrad);
|
390 |
+
const int intY = ( intIndex / SIZE_3(tenIngrad) ) % SIZE_2(tenIngrad);
|
391 |
+
const int intX = ( intIndex ) % SIZE_3(tenIngrad);
|
392 |
+
|
393 |
+
assert(SIZE_1(tenFlow) == 2);
|
394 |
+
|
395 |
+
{{type}} fltIngrad = 0.0f;
|
396 |
+
|
397 |
+
{{type}} fltX = ({{type}}) (intX) + VALUE_4(tenFlow, intN, 0, intY, intX);
|
398 |
+
{{type}} fltY = ({{type}}) (intY) + VALUE_4(tenFlow, intN, 1, intY, intX);
|
399 |
+
|
400 |
+
if (isfinite(fltX) == false) { return; }
|
401 |
+
if (isfinite(fltY) == false) { return; }
|
402 |
+
|
403 |
+
int intNorthwestX = (int) (floor(fltX));
|
404 |
+
int intNorthwestY = (int) (floor(fltY));
|
405 |
+
int intNortheastX = intNorthwestX + 1;
|
406 |
+
int intNortheastY = intNorthwestY;
|
407 |
+
int intSouthwestX = intNorthwestX;
|
408 |
+
int intSouthwestY = intNorthwestY + 1;
|
409 |
+
int intSoutheastX = intNorthwestX + 1;
|
410 |
+
int intSoutheastY = intNorthwestY + 1;
|
411 |
+
|
412 |
+
{{type}} fltNorthwest = (({{type}}) (intSoutheastX) - fltX) * (({{type}}) (intSoutheastY) - fltY);
|
413 |
+
{{type}} fltNortheast = (fltX - ({{type}}) (intSouthwestX)) * (({{type}}) (intSouthwestY) - fltY);
|
414 |
+
{{type}} fltSouthwest = (({{type}}) (intNortheastX) - fltX) * (fltY - ({{type}}) (intNortheastY));
|
415 |
+
{{type}} fltSoutheast = (fltX - ({{type}}) (intNorthwestX)) * (fltY - ({{type}}) (intNorthwestY));
|
416 |
+
|
417 |
+
if ((intNorthwestX >= 0) && (intNorthwestX < SIZE_3(tenOutgrad)) && (intNorthwestY >= 0) && (intNorthwestY < SIZE_2(tenOutgrad))) {
|
418 |
+
fltIngrad += VALUE_4(tenOutgrad, intN, intC, intNorthwestY, intNorthwestX) * fltNorthwest;
|
419 |
+
}
|
420 |
+
|
421 |
+
if ((intNortheastX >= 0) && (intNortheastX < SIZE_3(tenOutgrad)) && (intNortheastY >= 0) && (intNortheastY < SIZE_2(tenOutgrad))) {
|
422 |
+
fltIngrad += VALUE_4(tenOutgrad, intN, intC, intNortheastY, intNortheastX) * fltNortheast;
|
423 |
+
}
|
424 |
+
|
425 |
+
if ((intSouthwestX >= 0) && (intSouthwestX < SIZE_3(tenOutgrad)) && (intSouthwestY >= 0) && (intSouthwestY < SIZE_2(tenOutgrad))) {
|
426 |
+
fltIngrad += VALUE_4(tenOutgrad, intN, intC, intSouthwestY, intSouthwestX) * fltSouthwest;
|
427 |
+
}
|
428 |
+
|
429 |
+
if ((intSoutheastX >= 0) && (intSoutheastX < SIZE_3(tenOutgrad)) && (intSoutheastY >= 0) && (intSoutheastY < SIZE_2(tenOutgrad))) {
|
430 |
+
fltIngrad += VALUE_4(tenOutgrad, intN, intC, intSoutheastY, intSoutheastX) * fltSoutheast;
|
431 |
+
}
|
432 |
+
|
433 |
+
tenIngrad[intIndex] = fltIngrad;
|
434 |
+
} }
|
435 |
+
''', {
|
436 |
+
'tenIn': tenIn,
|
437 |
+
'tenFlow': tenFlow,
|
438 |
+
'tenOutgrad': tenOutgrad,
|
439 |
+
'tenIngrad': tenIngrad,
|
440 |
+
'tenFlowgrad': tenFlowgrad
|
441 |
+
}))(
|
442 |
+
grid=tuple([int((tenIngrad.nelement() + 512 - 1) / 512), 1, 1]),
|
443 |
+
block=tuple([512, 1, 1]),
|
444 |
+
args=[cuda_int32(tenIngrad.nelement()), tenIn.data_ptr(), tenFlow.data_ptr(), tenOutgrad.data_ptr(), tenIngrad.data_ptr(), None],
|
445 |
+
stream=collections.namedtuple('Stream', 'ptr')(torch.cuda.current_stream().cuda_stream)
|
446 |
+
)
|
447 |
+
# end
|
448 |
+
|
449 |
+
if tenFlowgrad is not None:
|
450 |
+
cuda_launch(cuda_kernel('softsplat_flowgrad', '''
|
451 |
+
extern "C" __global__ void __launch_bounds__(512) softsplat_flowgrad(
|
452 |
+
const int n,
|
453 |
+
const {{type}}* __restrict__ tenIn,
|
454 |
+
const {{type}}* __restrict__ tenFlow,
|
455 |
+
const {{type}}* __restrict__ tenOutgrad,
|
456 |
+
{{type}}* __restrict__ tenIngrad,
|
457 |
+
{{type}}* __restrict__ tenFlowgrad
|
458 |
+
) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
|
459 |
+
const int intN = ( intIndex / SIZE_3(tenFlowgrad) / SIZE_2(tenFlowgrad) / SIZE_1(tenFlowgrad) ) % SIZE_0(tenFlowgrad);
|
460 |
+
const int intC = ( intIndex / SIZE_3(tenFlowgrad) / SIZE_2(tenFlowgrad) ) % SIZE_1(tenFlowgrad);
|
461 |
+
const int intY = ( intIndex / SIZE_3(tenFlowgrad) ) % SIZE_2(tenFlowgrad);
|
462 |
+
const int intX = ( intIndex ) % SIZE_3(tenFlowgrad);
|
463 |
+
|
464 |
+
assert(SIZE_1(tenFlow) == 2);
|
465 |
+
|
466 |
+
{{type}} fltFlowgrad = 0.0f;
|
467 |
+
|
468 |
+
{{type}} fltX = ({{type}}) (intX) + VALUE_4(tenFlow, intN, 0, intY, intX);
|
469 |
+
{{type}} fltY = ({{type}}) (intY) + VALUE_4(tenFlow, intN, 1, intY, intX);
|
470 |
+
|
471 |
+
if (isfinite(fltX) == false) { return; }
|
472 |
+
if (isfinite(fltY) == false) { return; }
|
473 |
+
|
474 |
+
int intNorthwestX = (int) (floor(fltX));
|
475 |
+
int intNorthwestY = (int) (floor(fltY));
|
476 |
+
int intNortheastX = intNorthwestX + 1;
|
477 |
+
int intNortheastY = intNorthwestY;
|
478 |
+
int intSouthwestX = intNorthwestX;
|
479 |
+
int intSouthwestY = intNorthwestY + 1;
|
480 |
+
int intSoutheastX = intNorthwestX + 1;
|
481 |
+
int intSoutheastY = intNorthwestY + 1;
|
482 |
+
|
483 |
+
{{type}} fltNorthwest = 0.0f;
|
484 |
+
{{type}} fltNortheast = 0.0f;
|
485 |
+
{{type}} fltSouthwest = 0.0f;
|
486 |
+
{{type}} fltSoutheast = 0.0f;
|
487 |
+
|
488 |
+
if (intC == 0) {
|
489 |
+
fltNorthwest = (({{type}}) (-1.0f)) * (({{type}}) (intSoutheastY) - fltY);
|
490 |
+
fltNortheast = (({{type}}) (+1.0f)) * (({{type}}) (intSouthwestY) - fltY);
|
491 |
+
fltSouthwest = (({{type}}) (-1.0f)) * (fltY - ({{type}}) (intNortheastY));
|
492 |
+
fltSoutheast = (({{type}}) (+1.0f)) * (fltY - ({{type}}) (intNorthwestY));
|
493 |
+
|
494 |
+
} else if (intC == 1) {
|
495 |
+
fltNorthwest = (({{type}}) (intSoutheastX) - fltX) * (({{type}}) (-1.0f));
|
496 |
+
fltNortheast = (fltX - ({{type}}) (intSouthwestX)) * (({{type}}) (-1.0f));
|
497 |
+
fltSouthwest = (({{type}}) (intNortheastX) - fltX) * (({{type}}) (+1.0f));
|
498 |
+
fltSoutheast = (fltX - ({{type}}) (intNorthwestX)) * (({{type}}) (+1.0f));
|
499 |
+
|
500 |
+
}
|
501 |
+
|
502 |
+
for (int intChannel = 0; intChannel < SIZE_1(tenOutgrad); intChannel += 1) {
|
503 |
+
{{type}} fltIn = VALUE_4(tenIn, intN, intChannel, intY, intX);
|
504 |
+
|
505 |
+
if ((intNorthwestX >= 0) && (intNorthwestX < SIZE_3(tenOutgrad)) && (intNorthwestY >= 0) && (intNorthwestY < SIZE_2(tenOutgrad))) {
|
506 |
+
fltFlowgrad += VALUE_4(tenOutgrad, intN, intChannel, intNorthwestY, intNorthwestX) * fltIn * fltNorthwest;
|
507 |
+
}
|
508 |
+
|
509 |
+
if ((intNortheastX >= 0) && (intNortheastX < SIZE_3(tenOutgrad)) && (intNortheastY >= 0) && (intNortheastY < SIZE_2(tenOutgrad))) {
|
510 |
+
fltFlowgrad += VALUE_4(tenOutgrad, intN, intChannel, intNortheastY, intNortheastX) * fltIn * fltNortheast;
|
511 |
+
}
|
512 |
+
|
513 |
+
if ((intSouthwestX >= 0) && (intSouthwestX < SIZE_3(tenOutgrad)) && (intSouthwestY >= 0) && (intSouthwestY < SIZE_2(tenOutgrad))) {
|
514 |
+
fltFlowgrad += VALUE_4(tenOutgrad, intN, intChannel, intSouthwestY, intSouthwestX) * fltIn * fltSouthwest;
|
515 |
+
}
|
516 |
+
|
517 |
+
if ((intSoutheastX >= 0) && (intSoutheastX < SIZE_3(tenOutgrad)) && (intSoutheastY >= 0) && (intSoutheastY < SIZE_2(tenOutgrad))) {
|
518 |
+
fltFlowgrad += VALUE_4(tenOutgrad, intN, intChannel, intSoutheastY, intSoutheastX) * fltIn * fltSoutheast;
|
519 |
+
}
|
520 |
+
}
|
521 |
+
|
522 |
+
tenFlowgrad[intIndex] = fltFlowgrad;
|
523 |
+
} }
|
524 |
+
''', {
|
525 |
+
'tenIn': tenIn,
|
526 |
+
'tenFlow': tenFlow,
|
527 |
+
'tenOutgrad': tenOutgrad,
|
528 |
+
'tenIngrad': tenIngrad,
|
529 |
+
'tenFlowgrad': tenFlowgrad
|
530 |
+
}))(
|
531 |
+
grid=tuple([int((tenFlowgrad.nelement() + 512 - 1) / 512), 1, 1]),
|
532 |
+
block=tuple([512, 1, 1]),
|
533 |
+
args=[cuda_int32(tenFlowgrad.nelement()), tenIn.data_ptr(), tenFlow.data_ptr(), tenOutgrad.data_ptr(), None, tenFlowgrad.data_ptr()],
|
534 |
+
stream=collections.namedtuple('Stream', 'ptr')(torch.cuda.current_stream().cuda_stream)
|
535 |
+
)
|
536 |
+
# end
|
537 |
+
return tenIngrad, tenFlowgrad, Hgrad, Wgrad
|
538 |
+
# end
|
539 |
+
# end
|
models/spatracker/models/core/spatracker/spatracker.py
ADDED
@@ -0,0 +1,732 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
from easydict import EasyDict as edict
|
10 |
+
from einops import rearrange
|
11 |
+
from sklearn.cluster import SpectralClustering
|
12 |
+
from models.spatracker.models.core.spatracker.blocks import Lie
|
13 |
+
import matplotlib.pyplot as plt
|
14 |
+
import cv2
|
15 |
+
|
16 |
+
import torch.nn.functional as F
|
17 |
+
from models.spatracker.models.core.spatracker.blocks import (
|
18 |
+
BasicEncoder,
|
19 |
+
CorrBlock,
|
20 |
+
EUpdateFormer,
|
21 |
+
FusionFormer,
|
22 |
+
pix2cam,
|
23 |
+
cam2pix,
|
24 |
+
edgeMat,
|
25 |
+
VitEncoder,
|
26 |
+
DPTEnc,
|
27 |
+
Dinov2
|
28 |
+
)
|
29 |
+
|
30 |
+
from models.spatracker.models.core.spatracker.feature_net import (
|
31 |
+
LocalSoftSplat
|
32 |
+
)
|
33 |
+
|
34 |
+
from models.spatracker.models.core.model_utils import (
|
35 |
+
meshgrid2d, bilinear_sample2d, smart_cat, sample_features5d, vis_PCA
|
36 |
+
)
|
37 |
+
from models.spatracker.models.core.embeddings import (
|
38 |
+
get_2d_embedding,
|
39 |
+
get_3d_embedding,
|
40 |
+
get_1d_sincos_pos_embed_from_grid,
|
41 |
+
get_2d_sincos_pos_embed,
|
42 |
+
get_3d_sincos_pos_embed_from_grid,
|
43 |
+
Embedder_Fourier,
|
44 |
+
)
|
45 |
+
import numpy as np
|
46 |
+
from models.spatracker.models.core.spatracker.softsplat import softsplat
|
47 |
+
|
48 |
+
torch.manual_seed(0)
|
49 |
+
|
50 |
+
|
51 |
+
def get_points_on_a_grid(grid_size, interp_shape,
|
52 |
+
grid_center=(0, 0), device="cuda"):
|
53 |
+
if grid_size == 1:
|
54 |
+
return torch.tensor([interp_shape[1] / 2,
|
55 |
+
interp_shape[0] / 2], device=device)[
|
56 |
+
None, None
|
57 |
+
]
|
58 |
+
|
59 |
+
grid_y, grid_x = meshgrid2d(
|
60 |
+
1, grid_size, grid_size, stack=False, norm=False, device=device
|
61 |
+
)
|
62 |
+
step = interp_shape[1] // 64
|
63 |
+
if grid_center[0] != 0 or grid_center[1] != 0:
|
64 |
+
grid_y = grid_y - grid_size / 2.0
|
65 |
+
grid_x = grid_x - grid_size / 2.0
|
66 |
+
grid_y = step + grid_y.reshape(1, -1) / float(grid_size - 1) * (
|
67 |
+
interp_shape[0] - step * 2
|
68 |
+
)
|
69 |
+
grid_x = step + grid_x.reshape(1, -1) / float(grid_size - 1) * (
|
70 |
+
interp_shape[1] - step * 2
|
71 |
+
)
|
72 |
+
|
73 |
+
grid_y = grid_y + grid_center[0]
|
74 |
+
grid_x = grid_x + grid_center[1]
|
75 |
+
xy = torch.stack([grid_x, grid_y], dim=-1).to(device)
|
76 |
+
return xy
|
77 |
+
|
78 |
+
|
79 |
+
def sample_pos_embed(grid_size, embed_dim, coords):
|
80 |
+
if coords.shape[-1] == 2:
|
81 |
+
pos_embed = get_2d_sincos_pos_embed(embed_dim=embed_dim,
|
82 |
+
grid_size=grid_size)
|
83 |
+
pos_embed = (
|
84 |
+
torch.from_numpy(pos_embed)
|
85 |
+
.reshape(grid_size[0], grid_size[1], embed_dim)
|
86 |
+
.float()
|
87 |
+
.unsqueeze(0)
|
88 |
+
.to(coords.device)
|
89 |
+
)
|
90 |
+
sampled_pos_embed = bilinear_sample2d(
|
91 |
+
pos_embed.permute(0, 3, 1, 2),
|
92 |
+
coords[:, 0, :, 0], coords[:, 0, :, 1]
|
93 |
+
)
|
94 |
+
elif coords.shape[-1] == 3:
|
95 |
+
sampled_pos_embed = get_3d_sincos_pos_embed_from_grid(
|
96 |
+
embed_dim, coords[:, :1, ...]
|
97 |
+
).float()[:,0,...].permute(0, 2, 1)
|
98 |
+
|
99 |
+
return sampled_pos_embed
|
100 |
+
|
101 |
+
|
102 |
+
class SpaTracker(nn.Module):
|
103 |
+
def __init__(
|
104 |
+
self,
|
105 |
+
S=8,
|
106 |
+
stride=8,
|
107 |
+
add_space_attn=True,
|
108 |
+
num_heads=8,
|
109 |
+
hidden_size=384,
|
110 |
+
space_depth=12,
|
111 |
+
time_depth=12,
|
112 |
+
args=edict({})
|
113 |
+
):
|
114 |
+
super(SpaTracker, self).__init__()
|
115 |
+
|
116 |
+
# step1: config the arch of the model
|
117 |
+
self.args=args
|
118 |
+
# step1.1: config the default value of the model
|
119 |
+
if getattr(args, "depth_color", None) == None:
|
120 |
+
self.args.depth_color = False
|
121 |
+
if getattr(args, "if_ARAP", None) == None:
|
122 |
+
self.args.if_ARAP = True
|
123 |
+
if getattr(args, "flash_attn", None) == None:
|
124 |
+
self.args.flash_attn = True
|
125 |
+
if getattr(args, "backbone", None) == None:
|
126 |
+
self.args.backbone = "CNN"
|
127 |
+
if getattr(args, "Nblock", None) == None:
|
128 |
+
self.args.Nblock = 0
|
129 |
+
if getattr(args, "Embed3D", None) == None:
|
130 |
+
self.args.Embed3D = True
|
131 |
+
|
132 |
+
# step1.2: config the model parameters
|
133 |
+
self.S = S
|
134 |
+
self.stride = stride
|
135 |
+
self.hidden_dim = 256
|
136 |
+
self.latent_dim = latent_dim = 128
|
137 |
+
self.b_latent_dim = self.latent_dim//3
|
138 |
+
self.corr_levels = 4
|
139 |
+
self.corr_radius = 3
|
140 |
+
self.add_space_attn = add_space_attn
|
141 |
+
self.lie = Lie()
|
142 |
+
|
143 |
+
# step2: config the model components
|
144 |
+
# @Encoder
|
145 |
+
self.fnet = BasicEncoder(input_dim=3,
|
146 |
+
output_dim=self.latent_dim, norm_fn="instance", dropout=0,
|
147 |
+
stride=stride, Embed3D=False
|
148 |
+
)
|
149 |
+
|
150 |
+
# conv head for the tri-plane features
|
151 |
+
self.headyz = nn.Sequential(
|
152 |
+
nn.Conv2d(self.latent_dim, self.latent_dim, 3, padding=1),
|
153 |
+
nn.ReLU(inplace=True),
|
154 |
+
nn.Conv2d(self.latent_dim, self.latent_dim, 3, padding=1))
|
155 |
+
|
156 |
+
self.headxz = nn.Sequential(
|
157 |
+
nn.Conv2d(self.latent_dim, self.latent_dim, 3, padding=1),
|
158 |
+
nn.ReLU(inplace=True),
|
159 |
+
nn.Conv2d(self.latent_dim, self.latent_dim, 3, padding=1))
|
160 |
+
|
161 |
+
# @UpdateFormer
|
162 |
+
self.updateformer = EUpdateFormer(
|
163 |
+
space_depth=space_depth,
|
164 |
+
time_depth=time_depth,
|
165 |
+
input_dim=456,
|
166 |
+
hidden_size=hidden_size,
|
167 |
+
num_heads=num_heads,
|
168 |
+
output_dim=latent_dim + 3,
|
169 |
+
mlp_ratio=4.0,
|
170 |
+
add_space_attn=add_space_attn,
|
171 |
+
flash=getattr(self.args, "flash_attn", True)
|
172 |
+
)
|
173 |
+
self.support_features = torch.zeros(100, 384).to("cuda") + 0.1
|
174 |
+
|
175 |
+
self.norm = nn.GroupNorm(1, self.latent_dim)
|
176 |
+
|
177 |
+
self.ffeat_updater = nn.Sequential(
|
178 |
+
nn.Linear(self.latent_dim, self.latent_dim),
|
179 |
+
nn.GELU(),
|
180 |
+
)
|
181 |
+
self.ffeatyz_updater = nn.Sequential(
|
182 |
+
nn.Linear(self.latent_dim, self.latent_dim),
|
183 |
+
nn.GELU(),
|
184 |
+
)
|
185 |
+
self.ffeatxz_updater = nn.Sequential(
|
186 |
+
nn.Linear(self.latent_dim, self.latent_dim),
|
187 |
+
nn.GELU(),
|
188 |
+
)
|
189 |
+
|
190 |
+
#TODO @NeuralArap: optimize the arap
|
191 |
+
self.embed_traj = Embedder_Fourier(
|
192 |
+
input_dim=5, max_freq_log2=5.0, N_freqs=3, include_input=True
|
193 |
+
)
|
194 |
+
self.embed3d = Embedder_Fourier(
|
195 |
+
input_dim=3, max_freq_log2=10.0, N_freqs=10, include_input=True
|
196 |
+
)
|
197 |
+
self.embedConv = nn.Conv2d(self.latent_dim+63,
|
198 |
+
self.latent_dim, 3, padding=1)
|
199 |
+
|
200 |
+
# @Vis_predictor
|
201 |
+
self.vis_predictor = nn.Sequential(
|
202 |
+
nn.Linear(128, 1),
|
203 |
+
)
|
204 |
+
|
205 |
+
self.embedProj = nn.Linear(63, 456)
|
206 |
+
self.zeroMLPflow = nn.Linear(195, 130)
|
207 |
+
|
208 |
+
def prepare_track(self, rgbds, queries):
|
209 |
+
"""
|
210 |
+
NOTE:
|
211 |
+
Normalized the rgbs and sorted the queries via their first appeared time
|
212 |
+
Args:
|
213 |
+
rgbds: the input rgbd images (B T 4 H W)
|
214 |
+
queries: the input queries (B N 4)
|
215 |
+
Return:
|
216 |
+
rgbds: the normalized rgbds (B T 4 H W)
|
217 |
+
queries: the sorted queries (B N 4)
|
218 |
+
track_mask:
|
219 |
+
"""
|
220 |
+
assert (rgbds.shape[2]==4) and (queries.shape[2]==4)
|
221 |
+
#Step1: normalize the rgbs input
|
222 |
+
device = rgbds.device
|
223 |
+
rgbds[:, :, :3, ...] = 2 * (rgbds[:, :, :3, ...] / 255.0) - 1.0
|
224 |
+
B, T, C, H, W = rgbds.shape
|
225 |
+
B, N, __ = queries.shape
|
226 |
+
self.traj_e = torch.zeros((B, T, N, 3), device=device)
|
227 |
+
self.vis_e = torch.zeros((B, T, N), device=device)
|
228 |
+
|
229 |
+
#Step2: sort the points via their first appeared time
|
230 |
+
first_positive_inds = queries[0, :, 0].long()
|
231 |
+
__, sort_inds = torch.sort(first_positive_inds, dim=0, descending=False)
|
232 |
+
inv_sort_inds = torch.argsort(sort_inds, dim=0)
|
233 |
+
first_positive_sorted_inds = first_positive_inds[sort_inds]
|
234 |
+
# check if can be inverse
|
235 |
+
assert torch.allclose(
|
236 |
+
first_positive_inds, first_positive_inds[sort_inds][inv_sort_inds]
|
237 |
+
)
|
238 |
+
|
239 |
+
# filter those points never appear points during 1 - T
|
240 |
+
ind_array = torch.arange(T, device=device)
|
241 |
+
ind_array = ind_array[None, :, None].repeat(B, 1, N)
|
242 |
+
track_mask = (ind_array >=
|
243 |
+
first_positive_inds[None, None, :]).unsqueeze(-1)
|
244 |
+
|
245 |
+
# scale the coords_init
|
246 |
+
coords_init = queries[:, :, 1:].reshape(B, 1, N, 3).repeat(
|
247 |
+
1, self.S, 1, 1
|
248 |
+
)
|
249 |
+
coords_init[..., :2] /= float(self.stride)
|
250 |
+
|
251 |
+
#Step3: initial the regular grid
|
252 |
+
gridx = torch.linspace(0, W//self.stride - 1, W//self.stride)
|
253 |
+
gridy = torch.linspace(0, H//self.stride - 1, H//self.stride)
|
254 |
+
gridx, gridy = torch.meshgrid(gridx, gridy)
|
255 |
+
gridxy = torch.stack([gridx, gridy], dim=-1).to(rgbds.device).permute(
|
256 |
+
2, 1, 0
|
257 |
+
)
|
258 |
+
vis_init = torch.ones((B, self.S, N, 1), device=device).float() * 10
|
259 |
+
|
260 |
+
# Step4: initial traj for neural arap
|
261 |
+
T_series = torch.linspace(0, 5, T).reshape(1, T, 1 , 1).cuda() # 1 T 1 1
|
262 |
+
T_series = T_series.repeat(B, 1, N, 1)
|
263 |
+
# get the 3d traj in the camera coordinates
|
264 |
+
intr_init = self.intrs[:,queries[0,:,0].long()]
|
265 |
+
Traj_series = pix2cam(queries[:,:,None,1:].double(), intr_init.double())
|
266 |
+
#torch.inverse(intr_init.double())@queries[:,:,1:,None].double() # B N 3 1
|
267 |
+
Traj_series = Traj_series.repeat(1, 1, T, 1).permute(0, 2, 1, 3).float()
|
268 |
+
Traj_series = torch.cat([T_series, Traj_series], dim=-1)
|
269 |
+
# get the indicator for the neural arap
|
270 |
+
Traj_mask = -1e2*torch.ones_like(T_series)
|
271 |
+
Traj_series = torch.cat([Traj_series, Traj_mask], dim=-1)
|
272 |
+
|
273 |
+
return (
|
274 |
+
rgbds,
|
275 |
+
first_positive_inds,
|
276 |
+
first_positive_sorted_inds,
|
277 |
+
sort_inds, inv_sort_inds,
|
278 |
+
track_mask, gridxy, coords_init[..., sort_inds, :].clone(),
|
279 |
+
vis_init, Traj_series[..., sort_inds, :].clone()
|
280 |
+
)
|
281 |
+
|
282 |
+
def sample_trifeat(self, t,
|
283 |
+
coords,
|
284 |
+
featMapxy,
|
285 |
+
featMapyz,
|
286 |
+
featMapxz):
|
287 |
+
"""
|
288 |
+
Sample the features from the 5D triplane feature map 3*(B S C H W)
|
289 |
+
Args:
|
290 |
+
t: the time index
|
291 |
+
coords: the coordinates of the points B S N 3
|
292 |
+
featMapxy: the feature map B S C Hx Wy
|
293 |
+
featMapyz: the feature map B S C Hy Wz
|
294 |
+
featMapxz: the feature map B S C Hx Wz
|
295 |
+
"""
|
296 |
+
# get xy_t yz_t xz_t
|
297 |
+
queried_t = t.reshape(1, 1, -1, 1)
|
298 |
+
xy_t = torch.cat(
|
299 |
+
[queried_t, coords[..., [0,1]]],
|
300 |
+
dim=-1
|
301 |
+
)
|
302 |
+
yz_t = torch.cat(
|
303 |
+
[queried_t, coords[..., [1, 2]]],
|
304 |
+
dim=-1
|
305 |
+
)
|
306 |
+
xz_t = torch.cat(
|
307 |
+
[queried_t, coords[..., [0, 2]]],
|
308 |
+
dim=-1
|
309 |
+
)
|
310 |
+
featxy_init = sample_features5d(featMapxy, xy_t)
|
311 |
+
|
312 |
+
featyz_init = sample_features5d(featMapyz, yz_t)
|
313 |
+
featxz_init = sample_features5d(featMapxz, xz_t)
|
314 |
+
|
315 |
+
featxy_init = featxy_init.repeat(1, self.S, 1, 1)
|
316 |
+
featyz_init = featyz_init.repeat(1, self.S, 1, 1)
|
317 |
+
featxz_init = featxz_init.repeat(1, self.S, 1, 1)
|
318 |
+
|
319 |
+
return featxy_init, featyz_init, featxz_init
|
320 |
+
|
321 |
+
def neural_arap(self, coords, Traj_arap, intrs_S, T_mark):
|
322 |
+
""" calculate the ARAP embedding and offset
|
323 |
+
Args:
|
324 |
+
coords: the coordinates of the current points 1 S N' 3
|
325 |
+
Traj_arap: the trajectory of the points 1 T N' 5
|
326 |
+
intrs_S: the camera intrinsics B S 3 3
|
327 |
+
|
328 |
+
"""
|
329 |
+
coords_out = coords.clone()
|
330 |
+
coords_out[..., :2] *= float(self.stride)
|
331 |
+
coords_out[..., 2] = coords_out[..., 2]/self.Dz
|
332 |
+
coords_out[..., 2] = coords_out[..., 2]*(self.d_far-self.d_near) + self.d_near
|
333 |
+
intrs_S = intrs_S[:, :, None, ...].repeat(1, 1, coords_out.shape[2], 1, 1)
|
334 |
+
B, S, N, D = coords_out.shape
|
335 |
+
if S != intrs_S.shape[1]:
|
336 |
+
intrs_S = torch.cat(
|
337 |
+
[intrs_S, intrs_S[:, -1:].repeat(1, S - intrs_S.shape[1],1,1,1)], dim=1
|
338 |
+
)
|
339 |
+
T_mark = torch.cat(
|
340 |
+
[T_mark, T_mark[:, -1:].repeat(1, S - T_mark.shape[1],1)], dim=1
|
341 |
+
)
|
342 |
+
xyz_ = pix2cam(coords_out.double(), intrs_S.double()[:,:,0])
|
343 |
+
xyz_ = xyz_.float()
|
344 |
+
xyz_embed = torch.cat([T_mark[...,None], xyz_,
|
345 |
+
torch.zeros_like(T_mark[...,None])], dim=-1)
|
346 |
+
|
347 |
+
xyz_embed = self.embed_traj(xyz_embed)
|
348 |
+
Traj_arap_embed = self.embed_traj(Traj_arap)
|
349 |
+
d_xyz,traj_feat = self.arapFormer(xyz_embed, Traj_arap_embed)
|
350 |
+
# update in camera coordinate
|
351 |
+
xyz_ = xyz_ + d_xyz.clamp(-5, 5)
|
352 |
+
# project back to the image plane
|
353 |
+
coords_out = cam2pix(xyz_.double(), intrs_S[:,:,0].double()).float()
|
354 |
+
# resize back
|
355 |
+
coords_out[..., :2] /= float(self.stride)
|
356 |
+
coords_out[..., 2] = (coords_out[..., 2] - self.d_near)/(self.d_far-self.d_near)
|
357 |
+
coords_out[..., 2] *= self.Dz
|
358 |
+
|
359 |
+
return xyz_, coords_out, traj_feat
|
360 |
+
|
361 |
+
def gradient_arap(self, coords, aff_avg=None, aff_std=None, aff_f_sg=None,
|
362 |
+
iter=0, iter_num=4, neigh_idx=None, intr=None, msk_track=None):
|
363 |
+
with torch.enable_grad():
|
364 |
+
coords.requires_grad_(True)
|
365 |
+
y = self.ARAP_ln(coords, aff_f_sg=aff_f_sg, neigh_idx=neigh_idx,
|
366 |
+
iter=iter, iter_num=iter_num, intr=intr,msk_track=msk_track)
|
367 |
+
d_output = torch.ones_like(y, requires_grad=False, device=y.device)
|
368 |
+
gradients = torch.autograd.grad(
|
369 |
+
outputs=y,
|
370 |
+
inputs=coords,
|
371 |
+
grad_outputs=d_output,
|
372 |
+
create_graph=True,
|
373 |
+
retain_graph=True,
|
374 |
+
only_inputs=True, allow_unused=True)[0]
|
375 |
+
|
376 |
+
return gradients.detach()
|
377 |
+
|
378 |
+
def forward_iteration(
|
379 |
+
self,
|
380 |
+
fmapXY,
|
381 |
+
fmapYZ,
|
382 |
+
fmapXZ,
|
383 |
+
coords_init,
|
384 |
+
feat_init=None,
|
385 |
+
vis_init=None,
|
386 |
+
track_mask=None,
|
387 |
+
iters=4,
|
388 |
+
intrs_S=None,
|
389 |
+
):
|
390 |
+
B, S_init, N, D = coords_init.shape
|
391 |
+
assert D == 3
|
392 |
+
assert B == 1
|
393 |
+
B, S, __, H8, W8 = fmapXY.shape
|
394 |
+
device = fmapXY.device
|
395 |
+
|
396 |
+
if S_init < S:
|
397 |
+
coords = torch.cat(
|
398 |
+
[coords_init, coords_init[:, -1].repeat(1, S - S_init, 1, 1)],
|
399 |
+
dim=1
|
400 |
+
)
|
401 |
+
vis_init = torch.cat(
|
402 |
+
[vis_init, vis_init[:, -1].repeat(1, S - S_init, 1, 1)], dim=1
|
403 |
+
)
|
404 |
+
intrs_S = torch.cat(
|
405 |
+
[intrs_S, intrs_S[:, -1].repeat(1, S - S_init, 1, 1)], dim=1
|
406 |
+
)
|
407 |
+
else:
|
408 |
+
coords = coords_init.clone()
|
409 |
+
|
410 |
+
fcorr_fnXY = CorrBlock(
|
411 |
+
fmapXY, num_levels=self.corr_levels, radius=self.corr_radius
|
412 |
+
)
|
413 |
+
fcorr_fnYZ = CorrBlock(
|
414 |
+
fmapYZ, num_levels=self.corr_levels, radius=self.corr_radius
|
415 |
+
)
|
416 |
+
fcorr_fnXZ = CorrBlock(
|
417 |
+
fmapXZ, num_levels=self.corr_levels, radius=self.corr_radius
|
418 |
+
)
|
419 |
+
|
420 |
+
ffeats = torch.split(feat_init.clone(), dim=-1, split_size_or_sections=1)
|
421 |
+
ffeats = [f.squeeze(-1) for f in ffeats]
|
422 |
+
|
423 |
+
times_ = torch.linspace(0, S - 1, S).reshape(1, S, 1)
|
424 |
+
pos_embed = sample_pos_embed(
|
425 |
+
grid_size=(H8, W8),
|
426 |
+
embed_dim=456,
|
427 |
+
coords=coords[..., :2],
|
428 |
+
)
|
429 |
+
pos_embed = rearrange(pos_embed, "b e n -> (b n) e").unsqueeze(1)
|
430 |
+
|
431 |
+
times_embed = (
|
432 |
+
torch.from_numpy(get_1d_sincos_pos_embed_from_grid(456, times_[0]))[None]
|
433 |
+
.repeat(B, 1, 1)
|
434 |
+
.float()
|
435 |
+
.to(device)
|
436 |
+
)
|
437 |
+
coord_predictions = []
|
438 |
+
attn_predictions = []
|
439 |
+
Rot_ln = 0
|
440 |
+
support_feat = self.support_features
|
441 |
+
|
442 |
+
for __ in range(iters):
|
443 |
+
coords = coords.detach()
|
444 |
+
# if self.args.if_ARAP == True:
|
445 |
+
# # refine the track with arap
|
446 |
+
# xyz_pred, coords, flows_cat0 = self.neural_arap(coords.detach(),
|
447 |
+
# Traj_arap.detach(),
|
448 |
+
# intrs_S, T_mark)
|
449 |
+
with torch.no_grad():
|
450 |
+
fcorrsXY = fcorr_fnXY.corr_sample(ffeats[0], coords[..., :2])
|
451 |
+
fcorrsYZ = fcorr_fnYZ.corr_sample(ffeats[1], coords[..., [1,2]])
|
452 |
+
fcorrsXZ = fcorr_fnXZ.corr_sample(ffeats[2], coords[..., [0,2]])
|
453 |
+
# fcorrs = fcorrsXY
|
454 |
+
fcorrs = fcorrsXY + fcorrsYZ + fcorrsXZ
|
455 |
+
LRR = fcorrs.shape[3]
|
456 |
+
fcorrs_ = fcorrs.permute(0, 2, 1, 3).reshape(B * N, S, LRR)
|
457 |
+
|
458 |
+
flows_ = (coords - coords[:, 0:1]).permute(0, 2, 1, 3).reshape(B * N, S, 3)
|
459 |
+
flows_cat = get_3d_embedding(flows_, 64, cat_coords=True)
|
460 |
+
flows_cat = self.zeroMLPflow(flows_cat)
|
461 |
+
|
462 |
+
|
463 |
+
ffeats_xy = ffeats[0].permute(0,
|
464 |
+
2, 1, 3).reshape(B * N, S, self.latent_dim)
|
465 |
+
ffeats_yz = ffeats[1].permute(0,
|
466 |
+
2, 1, 3).reshape(B * N, S, self.latent_dim)
|
467 |
+
ffeats_xz = ffeats[2].permute(0,
|
468 |
+
2, 1, 3).reshape(B * N, S, self.latent_dim)
|
469 |
+
ffeats_ = ffeats_xy + ffeats_yz + ffeats_xz
|
470 |
+
|
471 |
+
if track_mask.shape[1] < vis_init.shape[1]:
|
472 |
+
track_mask = torch.cat(
|
473 |
+
[
|
474 |
+
track_mask,
|
475 |
+
torch.zeros_like(track_mask[:, 0]).repeat(
|
476 |
+
1, vis_init.shape[1] - track_mask.shape[1], 1, 1
|
477 |
+
),
|
478 |
+
],
|
479 |
+
dim=1,
|
480 |
+
)
|
481 |
+
concat = (
|
482 |
+
torch.cat([track_mask, vis_init], dim=2)
|
483 |
+
.permute(0, 2, 1, 3)
|
484 |
+
.reshape(B * N, S, 2)
|
485 |
+
)
|
486 |
+
|
487 |
+
transformer_input = torch.cat([flows_cat, fcorrs_, ffeats_, concat], dim=2)
|
488 |
+
|
489 |
+
if transformer_input.shape[-1] < pos_embed.shape[-1]:
|
490 |
+
# padding the transformer_input to the same dimension as pos_embed
|
491 |
+
transformer_input = F.pad(
|
492 |
+
transformer_input, (0, pos_embed.shape[-1] - transformer_input.shape[-1]),
|
493 |
+
"constant", 0
|
494 |
+
)
|
495 |
+
|
496 |
+
x = transformer_input + pos_embed + times_embed
|
497 |
+
x = rearrange(x, "(b n) t d -> b n t d", b=B)
|
498 |
+
|
499 |
+
delta, AttnMap, so3_dist, delta_se3F, so3 = self.updateformer(x, support_feat)
|
500 |
+
support_feat = support_feat + delta_se3F[0]/100
|
501 |
+
delta = rearrange(delta, " b n t d -> (b n) t d")
|
502 |
+
d_coord = delta[:, :, :3]
|
503 |
+
d_feats = delta[:, :, 3:]
|
504 |
+
|
505 |
+
ffeats_xy = self.ffeat_updater(self.norm(d_feats.view(-1, self.latent_dim))) + ffeats_xy.reshape(-1, self.latent_dim)
|
506 |
+
ffeats_yz = self.ffeatyz_updater(self.norm(d_feats.view(-1, self.latent_dim))) + ffeats_yz.reshape(-1, self.latent_dim)
|
507 |
+
ffeats_xz = self.ffeatxz_updater(self.norm(d_feats.view(-1, self.latent_dim))) + ffeats_xz.reshape(-1, self.latent_dim)
|
508 |
+
ffeats[0] = ffeats_xy.reshape(B, N, S, self.latent_dim).permute(
|
509 |
+
0, 2, 1, 3
|
510 |
+
) # B,S,N,C
|
511 |
+
ffeats[1] = ffeats_yz.reshape(B, N, S, self.latent_dim).permute(
|
512 |
+
0, 2, 1, 3
|
513 |
+
) # B,S,N,C
|
514 |
+
ffeats[2] = ffeats_xz.reshape(B, N, S, self.latent_dim).permute(
|
515 |
+
0, 2, 1, 3
|
516 |
+
) # B,S,N,C
|
517 |
+
coords = coords + d_coord.reshape(B, N, S, 3).permute(0, 2, 1, 3)
|
518 |
+
if torch.isnan(coords).any():
|
519 |
+
import ipdb; ipdb.set_trace()
|
520 |
+
|
521 |
+
coords_out = coords.clone()
|
522 |
+
coords_out[..., :2] *= float(self.stride)
|
523 |
+
|
524 |
+
coords_out[..., 2] = coords_out[..., 2]/self.Dz
|
525 |
+
coords_out[..., 2] = coords_out[..., 2]*(self.d_far-self.d_near) + self.d_near
|
526 |
+
|
527 |
+
coord_predictions.append(coords_out)
|
528 |
+
attn_predictions.append(AttnMap)
|
529 |
+
|
530 |
+
ffeats_f = ffeats[0] + ffeats[1] + ffeats[2]
|
531 |
+
vis_e = self.vis_predictor(ffeats_f.reshape(B * S * N, self.latent_dim)).reshape(
|
532 |
+
B, S, N
|
533 |
+
)
|
534 |
+
self.support_features = support_feat.detach()
|
535 |
+
return coord_predictions, attn_predictions, vis_e, feat_init, Rot_ln
|
536 |
+
|
537 |
+
|
538 |
+
def forward(self, rgbds, queries, iters=4, feat_init=None,
|
539 |
+
is_train=False, intrs=None, wind_S=None):
|
540 |
+
self.support_features = torch.zeros(100, 384).to("cuda") + 0.1
|
541 |
+
self.is_train=is_train
|
542 |
+
B, T, C, H, W = rgbds.shape
|
543 |
+
# set the intrinsic or simply initialized
|
544 |
+
if intrs is None:
|
545 |
+
intrs = torch.from_numpy(np.array([[W, 0.0, W//2],
|
546 |
+
[0.0, W, H//2],
|
547 |
+
[0.0, 0.0, 1.0]]))
|
548 |
+
intrs = intrs[None,
|
549 |
+
None,...].repeat(B, T, 1, 1).float().to(rgbds.device)
|
550 |
+
self.intrs = intrs
|
551 |
+
|
552 |
+
# prepare the input for tracking
|
553 |
+
(
|
554 |
+
rgbds,
|
555 |
+
first_positive_inds,
|
556 |
+
first_positive_sorted_inds, sort_inds,
|
557 |
+
inv_sort_inds, track_mask, gridxy,
|
558 |
+
coords_init, vis_init, Traj_arap
|
559 |
+
) = self.prepare_track(rgbds.clone(), queries)
|
560 |
+
coords_init_ = coords_init.clone()
|
561 |
+
vis_init_ = vis_init[:, :, sort_inds].clone()
|
562 |
+
|
563 |
+
depth_all = rgbds[:, :, 3,...]
|
564 |
+
d_near = self.d_near = depth_all[depth_all>0.01].min().item()
|
565 |
+
d_far = self.d_far = depth_all[depth_all>0.01].max().item()
|
566 |
+
|
567 |
+
if wind_S is not None:
|
568 |
+
self.S = wind_S
|
569 |
+
|
570 |
+
B, N, __ = queries.shape
|
571 |
+
self.Dz = Dz = W//self.stride
|
572 |
+
w_idx_start = 0
|
573 |
+
p_idx_end = 0
|
574 |
+
p_idx_start = 0
|
575 |
+
fmaps_ = None
|
576 |
+
vis_predictions = []
|
577 |
+
coord_predictions = []
|
578 |
+
attn_predictions = []
|
579 |
+
p_idx_end_list = []
|
580 |
+
Rigid_ln_total = 0
|
581 |
+
while w_idx_start < T - self.S // 2:
|
582 |
+
curr_wind_points = torch.nonzero(
|
583 |
+
first_positive_sorted_inds < w_idx_start + self.S)
|
584 |
+
if curr_wind_points.shape[0] == 0:
|
585 |
+
w_idx_start = w_idx_start + self.S // 2
|
586 |
+
continue
|
587 |
+
p_idx_end = curr_wind_points[-1] + 1
|
588 |
+
p_idx_end_list.append(p_idx_end)
|
589 |
+
# the T may not be divided by self.S
|
590 |
+
rgbds_seq = rgbds[:, w_idx_start:w_idx_start + self.S].clone()
|
591 |
+
S = S_local = rgbds_seq.shape[1]
|
592 |
+
if S < self.S:
|
593 |
+
rgbds_seq = torch.cat(
|
594 |
+
[rgbds_seq,
|
595 |
+
rgbds_seq[:, -1, None].repeat(1, self.S - S, 1, 1, 1)],
|
596 |
+
dim=1,
|
597 |
+
)
|
598 |
+
S = rgbds_seq.shape[1]
|
599 |
+
|
600 |
+
rgbs_ = rgbds_seq.reshape(B * S, C, H, W)[:, :3]
|
601 |
+
depths = rgbds_seq.reshape(B * S, C, H, W)[:, 3:].clone()
|
602 |
+
# open the mask
|
603 |
+
# Traj_arap[:, w_idx_start:w_idx_start + self.S, :p_idx_end, -1] = 0
|
604 |
+
#step1: normalize the depth map
|
605 |
+
|
606 |
+
depths = (depths - d_near)/(d_far-d_near)
|
607 |
+
depths_dn = nn.functional.interpolate(
|
608 |
+
depths, scale_factor=1.0 / self.stride, mode="nearest")
|
609 |
+
depths_dnG = depths_dn*Dz
|
610 |
+
|
611 |
+
#step2: normalize the coordinate
|
612 |
+
coords_init_[:, :, p_idx_start:p_idx_end, 2] = (
|
613 |
+
coords_init[:, :, p_idx_start:p_idx_end, 2] - d_near
|
614 |
+
)/(d_far-d_near)
|
615 |
+
coords_init_[:, :, p_idx_start:p_idx_end, 2] *= Dz
|
616 |
+
|
617 |
+
# efficient triplane splatting
|
618 |
+
gridxyz = torch.cat([gridxy[None,...].repeat(
|
619 |
+
depths_dn.shape[0],1,1,1), depths_dnG], dim=1)
|
620 |
+
Fxy2yz = gridxyz[:,[1, 2], ...] - gridxyz[:,:2]
|
621 |
+
Fxy2xz = gridxyz[:,[0, 2], ...] - gridxyz[:,:2]
|
622 |
+
if getattr(self.args, "Embed3D", None) == True:
|
623 |
+
gridxyz_nm = gridxyz.clone()
|
624 |
+
gridxyz_nm[:,0,...] = (gridxyz_nm[:,0,...]-gridxyz_nm[:,0,...].min())/(gridxyz_nm[:,0,...].max()-gridxyz_nm[:,0,...].min())
|
625 |
+
gridxyz_nm[:,1,...] = (gridxyz_nm[:,1,...]-gridxyz_nm[:,1,...].min())/(gridxyz_nm[:,1,...].max()-gridxyz_nm[:,1,...].min())
|
626 |
+
gridxyz_nm[:,2,...] = (gridxyz_nm[:,2,...]-gridxyz_nm[:,2,...].min())/(gridxyz_nm[:,2,...].max()-gridxyz_nm[:,2,...].min())
|
627 |
+
gridxyz_nm = 2*(gridxyz_nm-0.5)
|
628 |
+
_,_,h4,w4 = gridxyz_nm.shape
|
629 |
+
gridxyz_nm = gridxyz_nm.permute(0,2,3,1).reshape(S*h4*w4, 3)
|
630 |
+
featPE = self.embed3d(gridxyz_nm).view(S, h4, w4, -1).permute(0,3,1,2)
|
631 |
+
if fmaps_ is None:
|
632 |
+
fmaps_ = torch.cat([self.fnet(rgbs_),featPE], dim=1)
|
633 |
+
fmaps_ = self.embedConv(fmaps_)
|
634 |
+
else:
|
635 |
+
fmaps_new = torch.cat([self.fnet(rgbs_[self.S // 2 :]),featPE[self.S // 2 :]], dim=1)
|
636 |
+
fmaps_new = self.embedConv(fmaps_new)
|
637 |
+
fmaps_ = torch.cat(
|
638 |
+
[fmaps_[self.S // 2 :], fmaps_new], dim=0
|
639 |
+
)
|
640 |
+
else:
|
641 |
+
if fmaps_ is None:
|
642 |
+
fmaps_ = self.fnet(rgbs_)
|
643 |
+
else:
|
644 |
+
fmaps_ = torch.cat(
|
645 |
+
[fmaps_[self.S // 2 :], self.fnet(rgbs_[self.S // 2 :])], dim=0
|
646 |
+
)
|
647 |
+
|
648 |
+
fmapXY = fmaps_[:, :self.latent_dim].reshape(
|
649 |
+
B, S, self.latent_dim, H // self.stride, W // self.stride
|
650 |
+
)
|
651 |
+
|
652 |
+
fmapYZ = softsplat(fmapXY[0], Fxy2yz, None,
|
653 |
+
strMode="avg", tenoutH=self.Dz, tenoutW=H//self.stride)
|
654 |
+
fmapXZ = softsplat(fmapXY[0], Fxy2xz, None,
|
655 |
+
strMode="avg", tenoutH=self.Dz, tenoutW=W//self.stride)
|
656 |
+
|
657 |
+
fmapYZ = self.headyz(fmapYZ)[None, ...]
|
658 |
+
fmapXZ = self.headxz(fmapXZ)[None, ...]
|
659 |
+
|
660 |
+
if p_idx_end - p_idx_start > 0:
|
661 |
+
queried_t = (first_positive_sorted_inds[p_idx_start:p_idx_end]
|
662 |
+
- w_idx_start)
|
663 |
+
(featxy_init,
|
664 |
+
featyz_init,
|
665 |
+
featxz_init) = self.sample_trifeat(
|
666 |
+
t=queried_t,featMapxy=fmapXY,
|
667 |
+
featMapyz=fmapYZ,featMapxz=fmapXZ,
|
668 |
+
coords=coords_init_[:, :1, p_idx_start:p_idx_end]
|
669 |
+
)
|
670 |
+
# T, S, N, C, 3
|
671 |
+
feat_init_curr = torch.stack([featxy_init,
|
672 |
+
featyz_init, featxz_init], dim=-1)
|
673 |
+
feat_init = smart_cat(feat_init, feat_init_curr, dim=2)
|
674 |
+
|
675 |
+
if p_idx_start > 0:
|
676 |
+
# preprocess the coordinates of last windows
|
677 |
+
last_coords = coords[-1][:, self.S // 2 :].clone()
|
678 |
+
last_coords[..., :2] /= float(self.stride)
|
679 |
+
last_coords[..., 2:] = (last_coords[..., 2:]-d_near)/(d_far-d_near)
|
680 |
+
last_coords[..., 2:] = last_coords[..., 2:]*Dz
|
681 |
+
|
682 |
+
coords_init_[:, : self.S // 2, :p_idx_start] = last_coords
|
683 |
+
coords_init_[:, self.S // 2 :, :p_idx_start] = last_coords[
|
684 |
+
:, -1
|
685 |
+
].repeat(1, self.S // 2, 1, 1)
|
686 |
+
|
687 |
+
last_vis = vis[:, self.S // 2 :].unsqueeze(-1)
|
688 |
+
vis_init_[:, : self.S // 2, :p_idx_start] = last_vis
|
689 |
+
vis_init_[:, self.S // 2 :, :p_idx_start] = last_vis[:, -1].repeat(
|
690 |
+
1, self.S // 2, 1, 1
|
691 |
+
)
|
692 |
+
|
693 |
+
coords, attns, vis, __, Rigid_ln = self.forward_iteration(
|
694 |
+
fmapXY=fmapXY,
|
695 |
+
fmapYZ=fmapYZ,
|
696 |
+
fmapXZ=fmapXZ,
|
697 |
+
coords_init=coords_init_[:, :, :p_idx_end],
|
698 |
+
feat_init=feat_init[:, :, :p_idx_end],
|
699 |
+
vis_init=vis_init_[:, :, :p_idx_end],
|
700 |
+
track_mask=track_mask[:, w_idx_start : w_idx_start + self.S, :p_idx_end],
|
701 |
+
iters=iters,
|
702 |
+
intrs_S=self.intrs[:, w_idx_start : w_idx_start + self.S],
|
703 |
+
)
|
704 |
+
|
705 |
+
Rigid_ln_total+=Rigid_ln
|
706 |
+
|
707 |
+
if is_train:
|
708 |
+
vis_predictions.append(torch.sigmoid(vis[:, :S_local]))
|
709 |
+
coord_predictions.append([coord[:, :S_local] for coord in coords])
|
710 |
+
attn_predictions.append(attns)
|
711 |
+
|
712 |
+
self.traj_e[:, w_idx_start:w_idx_start+self.S, :p_idx_end] = coords[-1][:, :S_local]
|
713 |
+
self.vis_e[:, w_idx_start:w_idx_start+self.S, :p_idx_end] = vis[:, :S_local]
|
714 |
+
|
715 |
+
track_mask[:, : w_idx_start + self.S, :p_idx_end] = 0.0
|
716 |
+
w_idx_start = w_idx_start + self.S // 2
|
717 |
+
|
718 |
+
p_idx_start = p_idx_end
|
719 |
+
|
720 |
+
self.traj_e = self.traj_e[:, :, inv_sort_inds]
|
721 |
+
self.vis_e = self.vis_e[:, :, inv_sort_inds]
|
722 |
+
|
723 |
+
self.vis_e = torch.sigmoid(self.vis_e)
|
724 |
+
train_data = (
|
725 |
+
(vis_predictions, coord_predictions, attn_predictions,
|
726 |
+
p_idx_end_list, sort_inds, Rigid_ln_total)
|
727 |
+
)
|
728 |
+
if self.is_train:
|
729 |
+
return self.traj_e, feat_init, self.vis_e, train_data
|
730 |
+
else:
|
731 |
+
return self.traj_e, feat_init, self.vis_e
|
732 |
+
|
models/spatracker/models/core/spatracker/unet.py
ADDED
@@ -0,0 +1,258 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
'''
|
2 |
+
Codes are from:
|
3 |
+
https://github.com/jaxony/unet-pytorch/blob/master/model.py
|
4 |
+
'''
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
import torch.nn.functional as F
|
9 |
+
from torch.autograd import Variable
|
10 |
+
from collections import OrderedDict
|
11 |
+
from torch.nn import init
|
12 |
+
import numpy as np
|
13 |
+
|
14 |
+
def conv3x3(in_channels, out_channels, stride=1,
|
15 |
+
padding=1, bias=True, groups=1):
|
16 |
+
return nn.Conv2d(
|
17 |
+
in_channels,
|
18 |
+
out_channels,
|
19 |
+
kernel_size=3,
|
20 |
+
stride=stride,
|
21 |
+
padding=padding,
|
22 |
+
bias=bias,
|
23 |
+
groups=groups)
|
24 |
+
|
25 |
+
def upconv2x2(in_channels, out_channels, mode='transpose'):
|
26 |
+
if mode == 'transpose':
|
27 |
+
return nn.ConvTranspose2d(
|
28 |
+
in_channels,
|
29 |
+
out_channels,
|
30 |
+
kernel_size=2,
|
31 |
+
stride=2)
|
32 |
+
else:
|
33 |
+
# out_channels is always going to be the same
|
34 |
+
# as in_channels
|
35 |
+
return nn.Sequential(
|
36 |
+
nn.Upsample(mode='bilinear', scale_factor=2),
|
37 |
+
conv1x1(in_channels, out_channels))
|
38 |
+
|
39 |
+
def conv1x1(in_channels, out_channels, groups=1):
|
40 |
+
return nn.Conv2d(
|
41 |
+
in_channels,
|
42 |
+
out_channels,
|
43 |
+
kernel_size=1,
|
44 |
+
groups=groups,
|
45 |
+
stride=1)
|
46 |
+
|
47 |
+
|
48 |
+
class DownConv(nn.Module):
|
49 |
+
"""
|
50 |
+
A helper Module that performs 2 convolutions and 1 MaxPool.
|
51 |
+
A ReLU activation follows each convolution.
|
52 |
+
"""
|
53 |
+
def __init__(self, in_channels, out_channels, pooling=True):
|
54 |
+
super(DownConv, self).__init__()
|
55 |
+
|
56 |
+
self.in_channels = in_channels
|
57 |
+
self.out_channels = out_channels
|
58 |
+
self.pooling = pooling
|
59 |
+
|
60 |
+
self.conv1 = conv3x3(self.in_channels, self.out_channels)
|
61 |
+
self.conv2 = conv3x3(self.out_channels, self.out_channels)
|
62 |
+
|
63 |
+
if self.pooling:
|
64 |
+
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
|
65 |
+
|
66 |
+
def forward(self, x):
|
67 |
+
x = F.relu(self.conv1(x))
|
68 |
+
x = F.relu(self.conv2(x))
|
69 |
+
before_pool = x
|
70 |
+
if self.pooling:
|
71 |
+
x = self.pool(x)
|
72 |
+
return x, before_pool
|
73 |
+
|
74 |
+
|
75 |
+
class UpConv(nn.Module):
|
76 |
+
"""
|
77 |
+
A helper Module that performs 2 convolutions and 1 UpConvolution.
|
78 |
+
A ReLU activation follows each convolution.
|
79 |
+
"""
|
80 |
+
def __init__(self, in_channels, out_channels,
|
81 |
+
merge_mode='concat', up_mode='transpose'):
|
82 |
+
super(UpConv, self).__init__()
|
83 |
+
|
84 |
+
self.in_channels = in_channels
|
85 |
+
self.out_channels = out_channels
|
86 |
+
self.merge_mode = merge_mode
|
87 |
+
self.up_mode = up_mode
|
88 |
+
|
89 |
+
self.upconv = upconv2x2(self.in_channels, self.out_channels,
|
90 |
+
mode=self.up_mode)
|
91 |
+
|
92 |
+
if self.merge_mode == 'concat':
|
93 |
+
self.conv1 = conv3x3(
|
94 |
+
2*self.out_channels, self.out_channels)
|
95 |
+
else:
|
96 |
+
# num of input channels to conv2 is same
|
97 |
+
self.conv1 = conv3x3(self.out_channels, self.out_channels)
|
98 |
+
self.conv2 = conv3x3(self.out_channels, self.out_channels)
|
99 |
+
|
100 |
+
|
101 |
+
def forward(self, from_down, from_up):
|
102 |
+
""" Forward pass
|
103 |
+
Arguments:
|
104 |
+
from_down: tensor from the encoder pathway
|
105 |
+
from_up: upconv'd tensor from the decoder pathway
|
106 |
+
"""
|
107 |
+
from_up = self.upconv(from_up)
|
108 |
+
if self.merge_mode == 'concat':
|
109 |
+
x = torch.cat((from_up, from_down), 1)
|
110 |
+
else:
|
111 |
+
x = from_up + from_down
|
112 |
+
x = F.relu(self.conv1(x))
|
113 |
+
x = F.relu(self.conv2(x))
|
114 |
+
return x
|
115 |
+
|
116 |
+
|
117 |
+
class UNet(nn.Module):
|
118 |
+
""" `UNet` class is based on https://arxiv.org/abs/1505.04597
|
119 |
+
|
120 |
+
The U-Net is a convolutional encoder-decoder neural network.
|
121 |
+
Contextual spatial information (from the decoding,
|
122 |
+
expansive pathway) about an input tensor is merged with
|
123 |
+
information representing the localization of details
|
124 |
+
(from the encoding, compressive pathway).
|
125 |
+
|
126 |
+
Modifications to the original paper:
|
127 |
+
(1) padding is used in 3x3 convolutions to prevent loss
|
128 |
+
of border pixels
|
129 |
+
(2) merging outputs does not require cropping due to (1)
|
130 |
+
(3) residual connections can be used by specifying
|
131 |
+
UNet(merge_mode='add')
|
132 |
+
(4) if non-parametric upsampling is used in the decoder
|
133 |
+
pathway (specified by upmode='upsample'), then an
|
134 |
+
additional 1x1 2d convolution occurs after upsampling
|
135 |
+
to reduce channel dimensionality by a factor of 2.
|
136 |
+
This channel halving happens with the convolution in
|
137 |
+
the tranpose convolution (specified by upmode='transpose')
|
138 |
+
"""
|
139 |
+
|
140 |
+
def __init__(self, num_classes, in_channels=3, depth=5,
|
141 |
+
start_filts=64, up_mode='transpose',
|
142 |
+
merge_mode='concat', **kwargs):
|
143 |
+
"""
|
144 |
+
Arguments:
|
145 |
+
in_channels: int, number of channels in the input tensor.
|
146 |
+
Default is 3 for RGB images.
|
147 |
+
depth: int, number of MaxPools in the U-Net.
|
148 |
+
start_filts: int, number of convolutional filters for the
|
149 |
+
first conv.
|
150 |
+
up_mode: string, type of upconvolution. Choices: 'transpose'
|
151 |
+
for transpose convolution or 'upsample' for nearest neighbour
|
152 |
+
upsampling.
|
153 |
+
"""
|
154 |
+
super(UNet, self).__init__()
|
155 |
+
|
156 |
+
if up_mode in ('transpose', 'upsample'):
|
157 |
+
self.up_mode = up_mode
|
158 |
+
else:
|
159 |
+
raise ValueError("\"{}\" is not a valid mode for "
|
160 |
+
"upsampling. Only \"transpose\" and "
|
161 |
+
"\"upsample\" are allowed.".format(up_mode))
|
162 |
+
|
163 |
+
if merge_mode in ('concat', 'add'):
|
164 |
+
self.merge_mode = merge_mode
|
165 |
+
else:
|
166 |
+
raise ValueError("\"{}\" is not a valid mode for"
|
167 |
+
"merging up and down paths. "
|
168 |
+
"Only \"concat\" and "
|
169 |
+
"\"add\" are allowed.".format(up_mode))
|
170 |
+
|
171 |
+
# NOTE: up_mode 'upsample' is incompatible with merge_mode 'add'
|
172 |
+
if self.up_mode == 'upsample' and self.merge_mode == 'add':
|
173 |
+
raise ValueError("up_mode \"upsample\" is incompatible "
|
174 |
+
"with merge_mode \"add\" at the moment "
|
175 |
+
"because it doesn't make sense to use "
|
176 |
+
"nearest neighbour to reduce "
|
177 |
+
"depth channels (by half).")
|
178 |
+
|
179 |
+
self.num_classes = num_classes
|
180 |
+
self.in_channels = in_channels
|
181 |
+
self.start_filts = start_filts
|
182 |
+
self.depth = depth
|
183 |
+
|
184 |
+
self.down_convs = []
|
185 |
+
self.up_convs = []
|
186 |
+
|
187 |
+
# create the encoder pathway and add to a list
|
188 |
+
for i in range(depth):
|
189 |
+
ins = self.in_channels if i == 0 else outs
|
190 |
+
outs = self.start_filts*(2**i)
|
191 |
+
pooling = True if i < depth-1 else False
|
192 |
+
|
193 |
+
down_conv = DownConv(ins, outs, pooling=pooling)
|
194 |
+
self.down_convs.append(down_conv)
|
195 |
+
|
196 |
+
# create the decoder pathway and add to a list
|
197 |
+
# - careful! decoding only requires depth-1 blocks
|
198 |
+
for i in range(depth-1):
|
199 |
+
ins = outs
|
200 |
+
outs = ins // 2
|
201 |
+
up_conv = UpConv(ins, outs, up_mode=up_mode,
|
202 |
+
merge_mode=merge_mode)
|
203 |
+
self.up_convs.append(up_conv)
|
204 |
+
|
205 |
+
# add the list of modules to current module
|
206 |
+
self.down_convs = nn.ModuleList(self.down_convs)
|
207 |
+
self.up_convs = nn.ModuleList(self.up_convs)
|
208 |
+
|
209 |
+
self.conv_final = conv1x1(outs, self.num_classes)
|
210 |
+
|
211 |
+
self.reset_params()
|
212 |
+
|
213 |
+
@staticmethod
|
214 |
+
def weight_init(m):
|
215 |
+
if isinstance(m, nn.Conv2d):
|
216 |
+
init.xavier_normal_(m.weight)
|
217 |
+
init.constant_(m.bias, 0)
|
218 |
+
|
219 |
+
|
220 |
+
def reset_params(self):
|
221 |
+
for i, m in enumerate(self.modules()):
|
222 |
+
self.weight_init(m)
|
223 |
+
|
224 |
+
|
225 |
+
def forward(self, x):
|
226 |
+
encoder_outs = []
|
227 |
+
# encoder pathway, save outputs for merging
|
228 |
+
for i, module in enumerate(self.down_convs):
|
229 |
+
x, before_pool = module(x)
|
230 |
+
encoder_outs.append(before_pool)
|
231 |
+
for i, module in enumerate(self.up_convs):
|
232 |
+
before_pool = encoder_outs[-(i+2)]
|
233 |
+
x = module(before_pool, x)
|
234 |
+
|
235 |
+
# No softmax is used. This means you need to use
|
236 |
+
# nn.CrossEntropyLoss is your training script,
|
237 |
+
# as this module includes a softmax already.
|
238 |
+
x = self.conv_final(x)
|
239 |
+
return x
|
240 |
+
|
241 |
+
if __name__ == "__main__":
|
242 |
+
"""
|
243 |
+
testing
|
244 |
+
"""
|
245 |
+
model = UNet(1, depth=5, merge_mode='concat', in_channels=1, start_filts=32)
|
246 |
+
print(model)
|
247 |
+
print(sum(p.numel() for p in model.parameters()))
|
248 |
+
|
249 |
+
reso = 176
|
250 |
+
x = np.zeros((1, 1, reso, reso))
|
251 |
+
x[:,:,int(reso/2-1), int(reso/2-1)] = np.nan
|
252 |
+
x = torch.FloatTensor(x)
|
253 |
+
|
254 |
+
out = model(x)
|
255 |
+
print('%f'%(torch.sum(torch.isnan(out)).detach().cpu().numpy()/(reso*reso)))
|
256 |
+
|
257 |
+
# loss = torch.sum(out)
|
258 |
+
# loss.backward()
|
models/spatracker/models/core/spatracker/vit/__init__.py
ADDED
File without changes
|
models/spatracker/models/core/spatracker/vit/common.py
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
|
10 |
+
from typing import Type
|
11 |
+
|
12 |
+
|
13 |
+
class MLPBlock(nn.Module):
|
14 |
+
def __init__(
|
15 |
+
self,
|
16 |
+
embedding_dim: int,
|
17 |
+
mlp_dim: int,
|
18 |
+
act: Type[nn.Module] = nn.GELU,
|
19 |
+
) -> None:
|
20 |
+
super().__init__()
|
21 |
+
self.lin1 = nn.Linear(embedding_dim, mlp_dim)
|
22 |
+
self.lin2 = nn.Linear(mlp_dim, embedding_dim)
|
23 |
+
self.act = act()
|
24 |
+
|
25 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
26 |
+
return self.lin2(self.act(self.lin1(x)))
|
27 |
+
|
28 |
+
|
29 |
+
# From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa
|
30 |
+
# Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa
|
31 |
+
class LayerNorm2d(nn.Module):
|
32 |
+
def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
|
33 |
+
super().__init__()
|
34 |
+
self.weight = nn.Parameter(torch.ones(num_channels))
|
35 |
+
self.bias = nn.Parameter(torch.zeros(num_channels))
|
36 |
+
self.eps = eps
|
37 |
+
|
38 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
39 |
+
u = x.mean(1, keepdim=True)
|
40 |
+
s = (x - u).pow(2).mean(1, keepdim=True)
|
41 |
+
x = (x - u) / torch.sqrt(s + self.eps)
|
42 |
+
x = self.weight[:, None, None] * x + self.bias[:, None, None]
|
43 |
+
return x
|
models/spatracker/models/core/spatracker/vit/encoder.py
ADDED
@@ -0,0 +1,397 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
import torch.nn.functional as F
|
10 |
+
|
11 |
+
from typing import Optional, Tuple, Type
|
12 |
+
|
13 |
+
from models.spatracker.models.core.spatracker.vit.common import (
|
14 |
+
LayerNorm2d, MLPBlock
|
15 |
+
)
|
16 |
+
|
17 |
+
# This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa
|
18 |
+
class ImageEncoderViT(nn.Module):
|
19 |
+
def __init__(
|
20 |
+
self,
|
21 |
+
img_size: int = 1024,
|
22 |
+
patch_size: int = 16,
|
23 |
+
in_chans: int = 3,
|
24 |
+
embed_dim: int = 768,
|
25 |
+
depth: int = 12,
|
26 |
+
num_heads: int = 12,
|
27 |
+
mlp_ratio: float = 4.0,
|
28 |
+
out_chans: int = 256,
|
29 |
+
qkv_bias: bool = True,
|
30 |
+
norm_layer: Type[nn.Module] = nn.LayerNorm,
|
31 |
+
act_layer: Type[nn.Module] = nn.GELU,
|
32 |
+
use_abs_pos: bool = True,
|
33 |
+
use_rel_pos: bool = False,
|
34 |
+
rel_pos_zero_init: bool = True,
|
35 |
+
window_size: int = 0,
|
36 |
+
global_attn_indexes: Tuple[int, ...] = (),
|
37 |
+
) -> None:
|
38 |
+
"""
|
39 |
+
Args:
|
40 |
+
img_size (int): Input image size.
|
41 |
+
patch_size (int): Patch size.
|
42 |
+
in_chans (int): Number of input image channels.
|
43 |
+
embed_dim (int): Patch embedding dimension.
|
44 |
+
depth (int): Depth of ViT.
|
45 |
+
num_heads (int): Number of attention heads in each ViT block.
|
46 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
47 |
+
qkv_bias (bool): If True, add a learnable bias to query, key, value.
|
48 |
+
norm_layer (nn.Module): Normalization layer.
|
49 |
+
act_layer (nn.Module): Activation layer.
|
50 |
+
use_abs_pos (bool): If True, use absolute positional embeddings.
|
51 |
+
use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
|
52 |
+
rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
|
53 |
+
window_size (int): Window size for window attention blocks.
|
54 |
+
global_attn_indexes (list): Indexes for blocks using global attention.
|
55 |
+
"""
|
56 |
+
super().__init__()
|
57 |
+
self.img_size = img_size
|
58 |
+
|
59 |
+
self.patch_embed = PatchEmbed(
|
60 |
+
kernel_size=(patch_size, patch_size),
|
61 |
+
stride=(patch_size, patch_size),
|
62 |
+
in_chans=in_chans,
|
63 |
+
embed_dim=embed_dim,
|
64 |
+
)
|
65 |
+
|
66 |
+
self.pos_embed: Optional[nn.Parameter] = None
|
67 |
+
if use_abs_pos:
|
68 |
+
# Initialize absolute positional embedding with pretrain image size.
|
69 |
+
self.pos_embed = nn.Parameter(
|
70 |
+
torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim)
|
71 |
+
)
|
72 |
+
|
73 |
+
self.blocks = nn.ModuleList()
|
74 |
+
for i in range(depth):
|
75 |
+
block = Block(
|
76 |
+
dim=embed_dim,
|
77 |
+
num_heads=num_heads,
|
78 |
+
mlp_ratio=mlp_ratio,
|
79 |
+
qkv_bias=qkv_bias,
|
80 |
+
norm_layer=norm_layer,
|
81 |
+
act_layer=act_layer,
|
82 |
+
use_rel_pos=use_rel_pos,
|
83 |
+
rel_pos_zero_init=rel_pos_zero_init,
|
84 |
+
window_size=window_size if i not in global_attn_indexes else 0,
|
85 |
+
input_size=(img_size // patch_size, img_size // patch_size),
|
86 |
+
)
|
87 |
+
self.blocks.append(block)
|
88 |
+
|
89 |
+
self.neck = nn.Sequential(
|
90 |
+
nn.Conv2d(
|
91 |
+
embed_dim,
|
92 |
+
out_chans,
|
93 |
+
kernel_size=1,
|
94 |
+
bias=False,
|
95 |
+
),
|
96 |
+
LayerNorm2d(out_chans),
|
97 |
+
nn.Conv2d(
|
98 |
+
out_chans,
|
99 |
+
out_chans,
|
100 |
+
kernel_size=3,
|
101 |
+
padding=1,
|
102 |
+
bias=False,
|
103 |
+
),
|
104 |
+
LayerNorm2d(out_chans),
|
105 |
+
)
|
106 |
+
|
107 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
108 |
+
|
109 |
+
x = self.patch_embed(x)
|
110 |
+
if self.pos_embed is not None:
|
111 |
+
x = x + self.pos_embed
|
112 |
+
|
113 |
+
for blk in self.blocks:
|
114 |
+
x = blk(x)
|
115 |
+
|
116 |
+
x = self.neck(x.permute(0, 3, 1, 2))
|
117 |
+
|
118 |
+
return x
|
119 |
+
|
120 |
+
|
121 |
+
class Block(nn.Module):
|
122 |
+
"""Transformer blocks with support of window attention and residual propagation blocks"""
|
123 |
+
|
124 |
+
def __init__(
|
125 |
+
self,
|
126 |
+
dim: int,
|
127 |
+
num_heads: int,
|
128 |
+
mlp_ratio: float = 4.0,
|
129 |
+
qkv_bias: bool = True,
|
130 |
+
norm_layer: Type[nn.Module] = nn.LayerNorm,
|
131 |
+
act_layer: Type[nn.Module] = nn.GELU,
|
132 |
+
use_rel_pos: bool = False,
|
133 |
+
rel_pos_zero_init: bool = True,
|
134 |
+
window_size: int = 0,
|
135 |
+
input_size: Optional[Tuple[int, int]] = None,
|
136 |
+
) -> None:
|
137 |
+
"""
|
138 |
+
Args:
|
139 |
+
dim (int): Number of input channels.
|
140 |
+
num_heads (int): Number of attention heads in each ViT block.
|
141 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
142 |
+
qkv_bias (bool): If True, add a learnable bias to query, key, value.
|
143 |
+
norm_layer (nn.Module): Normalization layer.
|
144 |
+
act_layer (nn.Module): Activation layer.
|
145 |
+
use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
|
146 |
+
rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
|
147 |
+
window_size (int): Window size for window attention blocks. If it equals 0, then
|
148 |
+
use global attention.
|
149 |
+
input_size (tuple(int, int) or None): Input resolution for calculating the relative
|
150 |
+
positional parameter size.
|
151 |
+
"""
|
152 |
+
super().__init__()
|
153 |
+
self.norm1 = norm_layer(dim)
|
154 |
+
self.attn = Attention(
|
155 |
+
dim,
|
156 |
+
num_heads=num_heads,
|
157 |
+
qkv_bias=qkv_bias,
|
158 |
+
use_rel_pos=use_rel_pos,
|
159 |
+
rel_pos_zero_init=rel_pos_zero_init,
|
160 |
+
input_size=input_size if window_size == 0 else (window_size, window_size),
|
161 |
+
)
|
162 |
+
|
163 |
+
self.norm2 = norm_layer(dim)
|
164 |
+
self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer)
|
165 |
+
|
166 |
+
self.window_size = window_size
|
167 |
+
|
168 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
169 |
+
shortcut = x
|
170 |
+
x = self.norm1(x)
|
171 |
+
# Window partition
|
172 |
+
if self.window_size > 0:
|
173 |
+
H, W = x.shape[1], x.shape[2]
|
174 |
+
x, pad_hw = window_partition(x, self.window_size)
|
175 |
+
|
176 |
+
x = self.attn(x)
|
177 |
+
# Reverse window partition
|
178 |
+
if self.window_size > 0:
|
179 |
+
x = window_unpartition(x, self.window_size, pad_hw, (H, W))
|
180 |
+
|
181 |
+
        x = shortcut + x
        x = x + self.mlp(self.norm2(x))

        return x


class Attention(nn.Module):
    """Multi-head Attention block with relative position embeddings."""

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = True,
        use_rel_pos: bool = False,
        rel_pos_zero_init: bool = True,
        input_size: Optional[Tuple[int, int]] = None,
    ) -> None:
        """
        Args:
            dim (int): Number of input channels.
            num_heads (int): Number of attention heads.
            qkv_bias (bool): If True, add a learnable bias to query, key, value.
            use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
            rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
            input_size (tuple(int, int) or None): Input resolution for calculating the relative
                positional parameter size.
        """
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim**-0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)

        self.use_rel_pos = use_rel_pos
        if self.use_rel_pos:
            assert (
                input_size is not None
            ), "Input size must be provided if using relative positional encoding."
            # initialize relative positional embeddings
            self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
            self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, H, W, _ = x.shape
        # qkv with shape (3, B, nHead, H * W, C)
        qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        # q, k, v with shape (B * nHead, H * W, C)
        q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)

        attn = (q * self.scale) @ k.transpose(-2, -1)

        if self.use_rel_pos:
            attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))

        attn = attn.softmax(dim=-1)
        x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
        x = self.proj(x)

        return x


def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:
    """
    Partition into non-overlapping windows with padding if needed.
    Args:
        x (tensor): input tokens with [B, H, W, C].
        window_size (int): window size.

    Returns:
        windows: windows after partition with [B * num_windows, window_size, window_size, C].
        (Hp, Wp): padded height and width before partition
    """
    B, H, W, C = x.shape

    pad_h = (window_size - H % window_size) % window_size
    pad_w = (window_size - W % window_size) % window_size
    if pad_h > 0 or pad_w > 0:
        x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
    Hp, Wp = H + pad_h, W + pad_w

    x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows, (Hp, Wp)


def window_unpartition(
    windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int]
) -> torch.Tensor:
    """
    Window unpartition into original sequences and removing padding.
    Args:
        windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
        window_size (int): window size.
        pad_hw (Tuple): padded height and width (Hp, Wp).
        hw (Tuple): original height and width (H, W) before padding.

    Returns:
        x: unpartitioned sequences with [B, H, W, C].
    """
    Hp, Wp = pad_hw
    H, W = hw
    B = windows.shape[0] // (Hp * Wp // window_size // window_size)
    x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)

    if Hp > H or Wp > W:
        x = x[:, :H, :W, :].contiguous()
    return x


def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
    """
    Get relative positional embeddings according to the relative positions of
    query and key sizes.
    Args:
        q_size (int): size of query q.
        k_size (int): size of key k.
        rel_pos (Tensor): relative position embeddings (L, C).

    Returns:
        Extracted positional embeddings according to relative positions.
    """
    max_rel_dist = int(2 * max(q_size, k_size) - 1)
    # Interpolate rel pos if needed.
    if rel_pos.shape[0] != max_rel_dist:
        # Interpolate rel pos.
        rel_pos_resized = F.interpolate(
            rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
            size=max_rel_dist,
            mode="linear",
        )
        rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
    else:
        rel_pos_resized = rel_pos

    # Scale the coords with short length if shapes for q and k are different.
    q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
    k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
    relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)

    return rel_pos_resized[relative_coords.long()]


def add_decomposed_rel_pos(
    attn: torch.Tensor,
    q: torch.Tensor,
    rel_pos_h: torch.Tensor,
    rel_pos_w: torch.Tensor,
    q_size: Tuple[int, int],
    k_size: Tuple[int, int],
) -> torch.Tensor:
    """
    Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
    https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950
    Args:
        attn (Tensor): attention map.
        q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
        rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
        rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
        q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
        k_size (Tuple): spatial sequence size of key k with (k_h, k_w).

    Returns:
        attn (Tensor): attention map with added relative positional embeddings.
    """
    q_h, q_w = q_size
    k_h, k_w = k_size
    Rh = get_rel_pos(q_h, k_h, rel_pos_h)
    Rw = get_rel_pos(q_w, k_w, rel_pos_w)

    B, _, dim = q.shape
    r_q = q.reshape(B, q_h, q_w, dim)
    rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
    rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)

    attn = (
        attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
    ).view(B, q_h * q_w, k_h * k_w)

    return attn


class PatchEmbed(nn.Module):
    """
    Image to Patch Embedding.
    """

    def __init__(
        self,
        kernel_size: Tuple[int, int] = (16, 16),
        stride: Tuple[int, int] = (16, 16),
        padding: Tuple[int, int] = (0, 0),
        in_chans: int = 3,
        embed_dim: int = 768,
    ) -> None:
        """
        Args:
            kernel_size (Tuple): kernel size of the projection layer.
            stride (Tuple): stride of the projection layer.
            padding (Tuple): padding size of the projection layer.
            in_chans (int): Number of input image channels.
            embed_dim (int): Patch embedding dimension.
        """
        super().__init__()

        self.proj = nn.Conv2d(
            in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.proj(x)
        # B C H W -> B H W C
        x = x.permute(0, 2, 3, 1)
        return x
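A minimal sketch (not part of the repository) of how the two window helpers above compose, assuming they are imported from models/spatracker/models/core/spatracker/vit/encoder.py: window_partition pads H and W up to a multiple of window_size before splitting into windows, and window_unpartition strips that padding again, so the round trip is exact.

import torch
from models.spatracker.models.core.spatracker.vit.encoder import (
    window_partition, window_unpartition,
)

x = torch.randn(2, 14, 14, 64)                       # B, H, W, C tokens
windows, (Hp, Wp) = window_partition(x, window_size=8)
print(windows.shape)                                 # (8, 8, 8, 64): 14 is padded to 16, giving 2x2 windows per image
x_back = window_unpartition(windows, 8, (Hp, Wp), (14, 14))
assert torch.allclose(x, x_back)                     # padding is removed exactly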
models/spatracker/predictor.py
ADDED
@@ -0,0 +1,284 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn.functional as F
import time

from tqdm import tqdm
from models.spatracker.models.core.spatracker.spatracker import get_points_on_a_grid
from models.spatracker.models.core.model_utils import smart_cat
from models.spatracker.models.build_spatracker import (
    build_spatracker,
)
from models.spatracker.models.core.model_utils import (
    meshgrid2d, bilinear_sample2d, smart_cat
)


class SpaTrackerPredictor(torch.nn.Module):
    def __init__(
        self, checkpoint="cotracker/checkpoints/cotracker_stride_4_wind_8.pth",
        interp_shape=(384, 512),
        seq_length=16
    ):
        super().__init__()
        self.interp_shape = interp_shape
        self.support_grid_size = 6
        model = build_spatracker(checkpoint, seq_length=seq_length)

        self.model = model
        self.model.eval()

    @torch.no_grad()
    def forward(
        self,
        video,  # (1, T, 3, H, W)
        video_depth=None,  # (T, 1, H, W)
        # input prompt types:
        # - None. Dense tracks are computed in this case. You can adjust *query_frame* to compute tracks starting from a specific frame.
        #   *backward_tracking=True* will compute tracks in both directions.
        # - queries. Queried points of shape (1, N, 3) in format (t, x, y) for frame index and pixel coordinates.
        # - grid_size. Grid of N*N points from the first frame. If segm_mask is provided, then computed only for the mask.
        #   You can adjust *query_frame* and *backward_tracking* for the regular grid in the same way as for dense tracks.
        queries: torch.Tensor = None,
        segm_mask: torch.Tensor = None,  # Segmentation mask of shape (B, 1, H, W)
        grid_size: int = 0,
        grid_query_frame: int = 0,  # only for dense and regular grid tracks
        backward_tracking: bool = False,
        depth_predictor=None,
        wind_length: int = 8,
        progressive_tracking: bool = False,
    ):
        if queries is None and grid_size == 0:
            tracks, visibilities, T_Firsts = self._compute_dense_tracks(
                video,
                grid_query_frame=grid_query_frame,
                backward_tracking=backward_tracking,
                video_depth=video_depth,
                depth_predictor=depth_predictor,
                wind_length=wind_length,
            )
        else:
            tracks, visibilities, T_Firsts = self._compute_sparse_tracks(
                video,
                queries,
                segm_mask,
                grid_size,
                add_support_grid=False,  # (grid_size == 0 or segm_mask is not None),
                grid_query_frame=grid_query_frame,
                backward_tracking=backward_tracking,
                video_depth=video_depth,
                depth_predictor=depth_predictor,
                wind_length=wind_length,
            )

        return tracks, visibilities, T_Firsts

    def _compute_dense_tracks(
        self, video, grid_query_frame, grid_size=30, backward_tracking=False,
        depth_predictor=None, video_depth=None, wind_length=8
    ):
        *_, H, W = video.shape
        grid_step = W // grid_size
        grid_width = W // grid_step
        grid_height = H // grid_step
        tracks = visibilities = T_Firsts = None
        grid_pts = torch.zeros((1, grid_width * grid_height, 3)).to(video.device)
        grid_pts[0, :, 0] = grid_query_frame
        for offset in tqdm(range(grid_step * grid_step)):
            ox = offset % grid_step
            oy = offset // grid_step
            grid_pts[0, :, 1] = (
                torch.arange(grid_width).repeat(grid_height) * grid_step + ox
            )
            grid_pts[0, :, 2] = (
                torch.arange(grid_height).repeat_interleave(grid_width) * grid_step + oy
            )
            tracks_step, visibilities_step, T_First_step = self._compute_sparse_tracks(
                video=video,
                queries=grid_pts,
                backward_tracking=backward_tracking,
                wind_length=wind_length,
                video_depth=video_depth,
                depth_predictor=depth_predictor,
            )
            tracks = smart_cat(tracks, tracks_step, dim=2)
            visibilities = smart_cat(visibilities, visibilities_step, dim=2)
            T_Firsts = smart_cat(T_Firsts, T_First_step, dim=1)

        return tracks, visibilities, T_Firsts

    def _compute_sparse_tracks(
        self,
        video,
        queries,
        segm_mask=None,
        grid_size=0,
        add_support_grid=False,
        grid_query_frame=0,
        backward_tracking=False,
        depth_predictor=None,
        video_depth=None,
        wind_length=8,
    ):
        B, T, C, H, W = video.shape
        assert B == 1

        video = video.reshape(B * T, C, H, W)
        video = F.interpolate(video, tuple(self.interp_shape), mode="bilinear")
        video = video.reshape(B, T, 3, self.interp_shape[0], self.interp_shape[1])

        if queries is not None:
            queries = queries.clone()
            B, N, D = queries.shape
            assert D == 3
            queries[:, :, 1] *= self.interp_shape[1] / W
            queries[:, :, 2] *= self.interp_shape[0] / H
        elif grid_size > 0:
            grid_pts = get_points_on_a_grid(grid_size, self.interp_shape, device=video.device)
            if segm_mask is not None:
                segm_mask = F.interpolate(
                    segm_mask, tuple(self.interp_shape), mode="nearest"
                )
                point_mask = segm_mask[0, 0][
                    (grid_pts[0, :, 1]).round().long().cpu(),
                    (grid_pts[0, :, 0]).round().long().cpu(),
                ].bool()
                grid_pts_extra = grid_pts[:, point_mask]
            else:
                grid_pts_extra = None
            if grid_pts_extra is not None:
                total_num = int(grid_pts_extra.shape[1])
                total_num = min(800, total_num)
                pick_idx = torch.randperm(grid_pts_extra.shape[1])[:total_num]
                grid_pts_extra = grid_pts_extra[:, pick_idx]
                queries_extra = torch.cat(
                    [
                        torch.ones_like(grid_pts_extra[:, :, :1]) * grid_query_frame,
                        grid_pts_extra,
                    ],
                    dim=2,
                )

            queries = torch.cat(
                [torch.zeros_like(grid_pts[:, :, :1]), grid_pts],
                dim=2,
            )

        if add_support_grid:
            grid_pts = get_points_on_a_grid(self.support_grid_size, self.interp_shape, device=video.device)
            grid_pts = torch.cat(
                [torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2
            )
            queries = torch.cat([queries, grid_pts], dim=1)

        ## ----------- estimate the video depth -----------##
        if video_depth is None:
            with torch.no_grad():
                if video[0].shape[0] > 30:
                    vidDepths = []
                    for i in range(video[0].shape[0] // 30 + 1):
                        if (i + 1) * 30 > video[0].shape[0]:
                            end_idx = video[0].shape[0]
                        else:
                            end_idx = (i + 1) * 30
                        if end_idx == i * 30:
                            break
                        video_ = video[0][i * 30:end_idx]
                        vidDepths.append(depth_predictor.infer(video_ / 255))

                    video_depth = torch.cat(vidDepths, dim=0)
                else:
                    video_depth = depth_predictor.infer(video[0] / 255)
        video_depth = F.interpolate(video_depth,
                                    tuple(self.interp_shape), mode="nearest")

        # from PIL import Image
        # import numpy
        # depth_frame = video_depth[0].detach().cpu()
        # depth_frame = depth_frame.squeeze(0)
        # print(depth_frame)
        # print(depth_frame.min(), depth_frame.max())
        # depth_img = (depth_frame * 255).numpy().astype(numpy.uint8)
        # depth_img = Image.fromarray(depth_img, mode='L')
        # depth_img.save('outputs/depth_map.png')

        # frame = video[0, 0].detach().cpu()
        # frame = frame.permute(1, 2, 0)
        # frame = (frame * 255).numpy().astype(numpy.uint8)
        # frame = Image.fromarray(frame, mode='RGB')
        # frame.save('outputs/frame.png')

        depths = video_depth
        rgbds = torch.cat([video, depths[None, ...]], dim=2)
        # get the 3D queries
        depth_interp = []
        for i in range(queries.shape[1]):
            depth_interp_i = bilinear_sample2d(video_depth[queries[:, i:i+1, 0].long()],
                                               queries[:, i:i+1, 1], queries[:, i:i+1, 2])
            depth_interp.append(depth_interp_i)

        depth_interp = torch.cat(depth_interp, dim=1)
        queries = smart_cat(queries, depth_interp, dim=-1)

        # NOTE: free the memory of depth_predictor
        del depth_predictor
        torch.cuda.empty_cache()
        t0 = time.time()
        tracks, __, visibilities = self.model(rgbds=rgbds, queries=queries, iters=6, wind_S=wind_length)
        print("Time taken for inference: ", time.time() - t0)

        if backward_tracking:
            tracks, visibilities = self._compute_backward_tracks(
                rgbds, queries, tracks, visibilities
            )
            if add_support_grid:
                queries[:, -self.support_grid_size ** 2:, 0] = T - 1
        if add_support_grid:
            tracks = tracks[:, :, : -self.support_grid_size ** 2]
            visibilities = visibilities[:, :, : -self.support_grid_size ** 2]
        thr = 0.9
        visibilities = visibilities > thr

        # correct query-point predictions
        # see https://github.com/facebookresearch/co-tracker/issues/28

        # TODO: batchify
        for i in range(len(queries)):
            queries_t = queries[i, :tracks.size(2), 0].to(torch.int64)
            arange = torch.arange(0, len(queries_t))

            # overwrite the predictions with the query points
            tracks[i, queries_t, arange] = queries[i, :tracks.size(2), 1:]

            # correct visibilities, the query points should be visible
            visibilities[i, queries_t, arange] = True

        T_First = queries[..., :tracks.size(2), 0].to(torch.uint8)
        tracks[:, :, :, 0] *= W / float(self.interp_shape[1])
        tracks[:, :, :, 1] *= H / float(self.interp_shape[0])
        return tracks, visibilities, T_First

    def _compute_backward_tracks(self, video, queries, tracks, visibilities):
        inv_video = video.flip(1).clone()
        inv_queries = queries.clone()
        inv_queries[:, :, 0] = inv_video.shape[1] - inv_queries[:, :, 0] - 1

        inv_tracks, __, inv_visibilities = self.model(
            rgbds=inv_video, queries=queries, iters=6
        )

        inv_tracks = inv_tracks.flip(1)
        inv_visibilities = inv_visibilities.flip(1)

        mask = tracks == 0

        tracks[mask] = inv_tracks[mask]
        visibilities[mask[:, :, :, 0]] = inv_visibilities[mask[:, :, :, 0]]
        return tracks, visibilities
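A minimal usage sketch for the predictor above. It is not part of the file; the checkpoint path is hypothetical, a GPU is assumed, and the depth predictor is a stand-in stub for whatever monocular depth model exposes an .infer(frames) method in the rest of the Space.

import torch
from models.spatracker.predictor import SpaTrackerPredictor

class ConstantDepth:
    # hypothetical stand-in for a real depth predictor with .infer()
    def infer(self, frames):  # frames: (T, 3, H, W), values in [0, 1]
        T, _, H, W = frames.shape
        return torch.ones(T, 1, H, W, device=frames.device)

tracker = SpaTrackerPredictor(
    checkpoint="checkpoints/spaT_final.pth",   # hypothetical path
    interp_shape=(384, 512),
    seq_length=12,
).cuda()

video = torch.randint(0, 255, (1, 24, 3, 480, 854)).float().cuda()  # (1, T, 3, H, W), 0-255
tracks, visibilities, t_firsts = tracker(
    video,
    grid_size=20,                    # 20x20 query grid on the first frame
    depth_predictor=ConstantDepth(), # replaced internally if video_depth is passed
    wind_length=12,
)
# tracks carries per-frame (x, y, depth) for every query; visibilities is boolean after the 0.9 threshold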
models/spatracker/utils/__init__.py
ADDED
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
models/spatracker/utils/basic.py
ADDED
@@ -0,0 +1,397 @@
import os
import numpy as np
from os.path import isfile
import torch
import torch.nn.functional as F
EPS = 1e-6
import copy

def sub2ind(height, width, y, x):
    return y*width + x

def ind2sub(height, width, ind):
    y = ind // width
    x = ind % width
    return y, x

def get_lr_str(lr):
    lrn = "%.1e" % lr  # e.g., 5.0e-04
    lrn = lrn[0] + lrn[3:5] + lrn[-1]  # e.g., 5e-4
    return lrn

def strnum(x):
    s = '%g' % x
    if '.' in s:
        if x < 1.0:
            s = s[s.index('.'):]
        s = s[:min(len(s), 4)]
    return s

def assert_same_shape(t1, t2):
    for (x, y) in zip(list(t1.shape), list(t2.shape)):
        assert(x == y)

def print_stats(name, tensor):
    shape = tensor.shape
    tensor = tensor.detach().cpu().numpy()
    print('%s (%s) min = %.2f, mean = %.2f, max = %.2f' % (name, tensor.dtype, np.min(tensor), np.mean(tensor), np.max(tensor)), shape)

def print_stats_py(name, tensor):
    shape = tensor.shape
    print('%s (%s) min = %.2f, mean = %.2f, max = %.2f' % (name, tensor.dtype, np.min(tensor), np.mean(tensor), np.max(tensor)), shape)

def print_(name, tensor):
    tensor = tensor.detach().cpu().numpy()
    print(name, tensor, tensor.shape)

def mkdir(path):
    if not os.path.exists(path):
        os.makedirs(path)

def normalize_single(d):
    # d is a whatever shape torch tensor
    dmin = torch.min(d)
    dmax = torch.max(d)
    d = (d-dmin)/(EPS+(dmax-dmin))
    return d

def normalize(d):
    # d is B x whatever. normalize within each element of the batch
    out = torch.zeros(d.size())
    if d.is_cuda:
        out = out.cuda()
    B = list(d.size())[0]
    for b in list(range(B)):
        out[b] = normalize_single(d[b])
    return out

def hard_argmax2d(tensor):
    B, C, Y, X = list(tensor.shape)
    assert(C == 1)

    # flatten the Tensor along the height and width axes
    flat_tensor = tensor.reshape(B, -1)
    # argmax of the flat tensor
    argmax = torch.argmax(flat_tensor, dim=1)

    # convert the indices into 2d coordinates
    argmax_y = torch.floor(argmax / X)  # row
    argmax_x = argmax % X  # col

    argmax_y = argmax_y.reshape(B)
    argmax_x = argmax_x.reshape(B)
    return argmax_y, argmax_x

def argmax2d(heat, hard=True):
    B, C, Y, X = list(heat.shape)
    assert(C == 1)

    if hard:
        # hard argmax
        loc_y, loc_x = hard_argmax2d(heat)
        loc_y = loc_y.float()
        loc_x = loc_x.float()
    else:
        heat = heat.reshape(B, Y*X)
        prob = torch.nn.functional.softmax(heat, dim=1)

        grid_y, grid_x = meshgrid2d(B, Y, X)

        grid_y = grid_y.reshape(B, -1)
        grid_x = grid_x.reshape(B, -1)

        loc_y = torch.sum(grid_y*prob, dim=1)
        loc_x = torch.sum(grid_x*prob, dim=1)
        # these are B

    return loc_y, loc_x

def reduce_masked_mean(x, mask, dim=None, keepdim=False):
    # x and mask are the same shape, or at least broadcastably so < actually it's safer if you disallow broadcasting
    # returns shape-1
    # axis can be a list of axes
    for (a, b) in zip(x.size(), mask.size()):
        # if not b==1:
        assert(a == b)  # some shape mismatch!
    # assert(x.size() == mask.size())
    prod = x*mask
    if dim is None:
        numer = torch.sum(prod)
        denom = EPS+torch.sum(mask)
    else:
        numer = torch.sum(prod, dim=dim, keepdim=keepdim)
        denom = EPS+torch.sum(mask, dim=dim, keepdim=keepdim)

    mean = numer/denom
    return mean

def reduce_masked_median(x, mask, keep_batch=False):
    # x and mask are the same shape
    assert(x.size() == mask.size())
    device = x.device

    B = list(x.shape)[0]
    x = x.detach().cpu().numpy()
    mask = mask.detach().cpu().numpy()

    if keep_batch:
        x = np.reshape(x, [B, -1])
        mask = np.reshape(mask, [B, -1])
        meds = np.zeros([B], np.float32)
        for b in list(range(B)):
            xb = x[b]
            mb = mask[b]
            if np.sum(mb) > 0:
                xb = xb[mb > 0]
                meds[b] = np.median(xb)
            else:
                meds[b] = np.nan
        meds = torch.from_numpy(meds).to(device)
        return meds.float()
    else:
        x = np.reshape(x, [-1])
        mask = np.reshape(mask, [-1])
        if np.sum(mask) > 0:
            x = x[mask > 0]
            med = np.median(x)
        else:
            med = np.nan
        med = np.array([med], np.float32)
        med = torch.from_numpy(med).to(device)
        return med.float()

def pack_seqdim(tensor, B):
    shapelist = list(tensor.shape)
    B_, S = shapelist[:2]
    assert(B == B_)
    otherdims = shapelist[2:]
    tensor = torch.reshape(tensor, [B*S]+otherdims)
    return tensor

def unpack_seqdim(tensor, B):
    shapelist = list(tensor.shape)
    BS = shapelist[0]
    assert(BS % B == 0)
    otherdims = shapelist[1:]
    S = int(BS/B)
    tensor = torch.reshape(tensor, [B, S]+otherdims)
    return tensor

def meshgrid2d(B, Y, X, stack=False, norm=False, device='cuda', on_chans=False):
    # returns a meshgrid sized B x Y x X

    grid_y = torch.linspace(0.0, Y-1, Y, device=torch.device(device))
    grid_y = torch.reshape(grid_y, [1, Y, 1])
    grid_y = grid_y.repeat(B, 1, X)

    grid_x = torch.linspace(0.0, X-1, X, device=torch.device(device))
    grid_x = torch.reshape(grid_x, [1, 1, X])
    grid_x = grid_x.repeat(B, Y, 1)

    if norm:
        grid_y, grid_x = normalize_grid2d(
            grid_y, grid_x, Y, X)

    if stack:
        # note we stack in xy order
        # (see https://pytorch.org/docs/stable/nn.functional.html#torch.nn.functional.grid_sample)
        if on_chans:
            grid = torch.stack([grid_x, grid_y], dim=1)
        else:
            grid = torch.stack([grid_x, grid_y], dim=-1)
        return grid
    else:
        return grid_y, grid_x

def meshgrid3d(B, Z, Y, X, stack=False, norm=False, device='cuda'):
    # returns a meshgrid sized B x Z x Y x X

    grid_z = torch.linspace(0.0, Z-1, Z, device=device)
    grid_z = torch.reshape(grid_z, [1, Z, 1, 1])
    grid_z = grid_z.repeat(B, 1, Y, X)

    grid_y = torch.linspace(0.0, Y-1, Y, device=device)
    grid_y = torch.reshape(grid_y, [1, 1, Y, 1])
    grid_y = grid_y.repeat(B, Z, 1, X)

    grid_x = torch.linspace(0.0, X-1, X, device=device)
    grid_x = torch.reshape(grid_x, [1, 1, 1, X])
    grid_x = grid_x.repeat(B, Z, Y, 1)

    # if cuda:
    #     grid_z = grid_z.cuda()
    #     grid_y = grid_y.cuda()
    #     grid_x = grid_x.cuda()

    if norm:
        grid_z, grid_y, grid_x = normalize_grid3d(
            grid_z, grid_y, grid_x, Z, Y, X)

    if stack:
        # note we stack in xyz order
        # (see https://pytorch.org/docs/stable/nn.functional.html#torch.nn.functional.grid_sample)
        grid = torch.stack([grid_x, grid_y, grid_z], dim=-1)
        return grid
    else:
        return grid_z, grid_y, grid_x

def normalize_grid2d(grid_y, grid_x, Y, X, clamp_extreme=True):
    # make things in [-1,1]
    grid_y = 2.0*(grid_y / float(Y-1)) - 1.0
    grid_x = 2.0*(grid_x / float(X-1)) - 1.0

    if clamp_extreme:
        grid_y = torch.clamp(grid_y, min=-2.0, max=2.0)
        grid_x = torch.clamp(grid_x, min=-2.0, max=2.0)

    return grid_y, grid_x

def normalize_grid3d(grid_z, grid_y, grid_x, Z, Y, X, clamp_extreme=True):
    # make things in [-1,1]
    grid_z = 2.0*(grid_z / float(Z-1)) - 1.0
    grid_y = 2.0*(grid_y / float(Y-1)) - 1.0
    grid_x = 2.0*(grid_x / float(X-1)) - 1.0

    if clamp_extreme:
        grid_z = torch.clamp(grid_z, min=-2.0, max=2.0)
        grid_y = torch.clamp(grid_y, min=-2.0, max=2.0)
        grid_x = torch.clamp(grid_x, min=-2.0, max=2.0)

    return grid_z, grid_y, grid_x

def gridcloud2d(B, Y, X, norm=False, device='cuda'):
    # we want to sample for each location in the grid
    grid_y, grid_x = meshgrid2d(B, Y, X, norm=norm, device=device)
    x = torch.reshape(grid_x, [B, -1])
    y = torch.reshape(grid_y, [B, -1])
    # these are B x N
    xy = torch.stack([x, y], dim=2)
    # this is B x N x 2
    return xy

def gridcloud3d(B, Z, Y, X, norm=False, device='cuda'):
    # we want to sample for each location in the grid
    grid_z, grid_y, grid_x = meshgrid3d(B, Z, Y, X, norm=norm, device=device)
    x = torch.reshape(grid_x, [B, -1])
    y = torch.reshape(grid_y, [B, -1])
    z = torch.reshape(grid_z, [B, -1])
    # these are B x N
    xyz = torch.stack([x, y, z], dim=2)
    # this is B x N x 3
    return xyz

import re
def readPFM(file):
    file = open(file, 'rb')

    color = None
    width = None
    height = None
    scale = None
    endian = None

    header = file.readline().rstrip()
    if header == b'PF':
        color = True
    elif header == b'Pf':
        color = False
    else:
        raise Exception('Not a PFM file.')

    dim_match = re.match(rb'^(\d+)\s(\d+)\s$', file.readline())
    if dim_match:
        width, height = map(int, dim_match.groups())
    else:
        raise Exception('Malformed PFM header.')

    scale = float(file.readline().rstrip())
    if scale < 0:  # little-endian
        endian = '<'
        scale = -scale
    else:
        endian = '>'  # big-endian

    data = np.fromfile(file, endian + 'f')
    shape = (height, width, 3) if color else (height, width)

    data = np.reshape(data, shape)
    data = np.flipud(data)
    return data

def normalize_boxlist2d(boxlist2d, H, W):
    boxlist2d = boxlist2d.clone()
    ymin, xmin, ymax, xmax = torch.unbind(boxlist2d, dim=2)
    ymin = ymin / float(H)
    ymax = ymax / float(H)
    xmin = xmin / float(W)
    xmax = xmax / float(W)
    boxlist2d = torch.stack([ymin, xmin, ymax, xmax], dim=2)
    return boxlist2d

def unnormalize_boxlist2d(boxlist2d, H, W):
    boxlist2d = boxlist2d.clone()
    ymin, xmin, ymax, xmax = torch.unbind(boxlist2d, dim=2)
    ymin = ymin * float(H)
    ymax = ymax * float(H)
    xmin = xmin * float(W)
    xmax = xmax * float(W)
    boxlist2d = torch.stack([ymin, xmin, ymax, xmax], dim=2)
    return boxlist2d

def unnormalize_box2d(box2d, H, W):
    return unnormalize_boxlist2d(box2d.unsqueeze(1), H, W).squeeze(1)

def normalize_box2d(box2d, H, W):
    return normalize_boxlist2d(box2d.unsqueeze(1), H, W).squeeze(1)

def get_gaussian_kernel_2d(channels, kernel_size=3, sigma=2.0, mid_one=False):
    C = channels
    xy_grid = gridcloud2d(C, kernel_size, kernel_size)  # C x N x 2

    mean = (kernel_size - 1)/2.0
    variance = sigma**2.0

    gaussian_kernel = (1.0/(2.0*np.pi*variance)**1.5) * torch.exp(-torch.sum((xy_grid - mean)**2.0, dim=-1) / (2.0*variance))  # C X N
    gaussian_kernel = gaussian_kernel.view(C, 1, kernel_size, kernel_size)  # C x 1 x 3 x 3
    kernel_sum = torch.sum(gaussian_kernel, dim=(2, 3), keepdim=True)

    gaussian_kernel = gaussian_kernel / kernel_sum  # normalize

    if mid_one:
        # normalize so that the middle element is 1
        maxval = gaussian_kernel[:, :, (kernel_size//2), (kernel_size//2)].reshape(C, 1, 1, 1)
        gaussian_kernel = gaussian_kernel / maxval

    return gaussian_kernel

def gaussian_blur_2d(input, kernel_size=3, sigma=2.0, reflect_pad=False, mid_one=False):
    B, C, Z, X = input.shape
    kernel = get_gaussian_kernel_2d(C, kernel_size, sigma, mid_one=mid_one)
    if reflect_pad:
        pad = (kernel_size - 1)//2
        out = F.pad(input, (pad, pad, pad, pad), mode='reflect')
        out = F.conv2d(out, kernel, padding=0, groups=C)
    else:
        out = F.conv2d(input, kernel, padding=(kernel_size - 1)//2, groups=C)
    return out

def gradient2d(x, absolute=False, square=False, return_sum=False):
    # x should be B x C x H x W
    dh = x[:, :, 1:, :] - x[:, :, :-1, :]
    dw = x[:, :, :, 1:] - x[:, :, :, :-1]

    zeros = torch.zeros_like(x)
    zero_h = zeros[:, :, 0:1, :]
    zero_w = zeros[:, :, :, 0:1]
    dh = torch.cat([dh, zero_h], axis=2)
    dw = torch.cat([dw, zero_w], axis=3)
    if absolute:
        dh = torch.abs(dh)
        dw = torch.abs(dw)
    if square:
        dh = dh ** 2
        dw = dw ** 2
    if return_sum:
        return dh+dw
    else:
        return dh, dw
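A small sketch (not part of the file) of two conventions in the utilities above: reduce_masked_mean averages only the entries selected by the mask, with EPS keeping the division safe when the mask is empty, and meshgrid2d with stack=True returns coordinates in (x, y) order.

import torch
from models.spatracker.utils.basic import reduce_masked_mean, meshgrid2d

x = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
mask = torch.tensor([[1.0, 1.0, 0.0, 0.0]])
print(reduce_masked_mean(x, mask))                  # ~1.5: only the first two entries contribute
print(reduce_masked_mean(x, torch.zeros_like(x)))   # ~0.0 rather than NaN, thanks to EPS

grid = meshgrid2d(B=1, Y=2, X=3, stack=True, device='cpu')  # shape (1, 2, 3, 2)
print(grid[0, 1, 2])                                # tensor([2., 1.]): x first, then y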
models/spatracker/utils/geom.py
ADDED
@@ -0,0 +1,547 @@
import torch
import models.spatracker.utils.basic
import models.spatracker.utils as utils  # bind the package so the utils.basic.* references below resolve
import numpy as np
import torchvision.ops as ops
from models.spatracker.utils.basic import print_

def matmul2(mat1, mat2):
    return torch.matmul(mat1, mat2)

def matmul3(mat1, mat2, mat3):
    return torch.matmul(mat1, torch.matmul(mat2, mat3))

def eye_3x3(B, device='cuda'):
    rt = torch.eye(3, device=torch.device(device)).view(1, 3, 3).repeat([B, 1, 1])
    return rt

def eye_4x4(B, device='cuda'):
    rt = torch.eye(4, device=torch.device(device)).view(1, 4, 4).repeat([B, 1, 1])
    return rt

def safe_inverse(a):  # parallel version
    B, _, _ = list(a.shape)
    inv = a.clone()
    r_transpose = a[:, :3, :3].transpose(1, 2)  # inverse of rotation matrix

    inv[:, :3, :3] = r_transpose
    inv[:, :3, 3:4] = -torch.matmul(r_transpose, a[:, :3, 3:4])

    return inv

def safe_inverse_single(a):
    r, t = split_rt_single(a)
    t = t.view(3, 1)
    r_transpose = r.t()
    inv = torch.cat([r_transpose, -torch.matmul(r_transpose, t)], 1)
    bottom_row = a[3:4, :]  # this is [0, 0, 0, 1]
    # bottom_row = torch.tensor([0.,0.,0.,1.]).view(1,4)
    inv = torch.cat([inv, bottom_row], 0)
    return inv

def split_intrinsics(K):
    # K is B x 3 x 3 or B x 4 x 4
    fx = K[:, 0, 0]
    fy = K[:, 1, 1]
    x0 = K[:, 0, 2]
    y0 = K[:, 1, 2]
    return fx, fy, x0, y0

def apply_pix_T_cam(pix_T_cam, xyz):

    fx, fy, x0, y0 = split_intrinsics(pix_T_cam)

    # xyz is shaped B x H*W x 3
    # returns xy, shaped B x H*W x 2

    B, N, C = list(xyz.shape)
    assert(C == 3)

    x, y, z = torch.unbind(xyz, axis=-1)

    fx = torch.reshape(fx, [B, 1])
    fy = torch.reshape(fy, [B, 1])
    x0 = torch.reshape(x0, [B, 1])
    y0 = torch.reshape(y0, [B, 1])

    EPS = 1e-4
    z = torch.clamp(z, min=EPS)
    x = (x*fx)/(z)+x0
    y = (y*fy)/(z)+y0
    xy = torch.stack([x, y], axis=-1)
    return xy

def apply_pix_T_cam_py(pix_T_cam, xyz):

    fx, fy, x0, y0 = split_intrinsics(pix_T_cam)

    # xyz is shaped B x H*W x 3
    # returns xy, shaped B x H*W x 2

    B, N, C = list(xyz.shape)
    assert(C == 3)

    x, y, z = xyz[:, :, 0], xyz[:, :, 1], xyz[:, :, 2]

    fx = np.reshape(fx, [B, 1])
    fy = np.reshape(fy, [B, 1])
    x0 = np.reshape(x0, [B, 1])
    y0 = np.reshape(y0, [B, 1])

    EPS = 1e-4
    z = np.clip(z, EPS, None)
    x = (x*fx)/(z)+x0
    y = (y*fy)/(z)+y0
    xy = np.stack([x, y], axis=-1)
    return xy

def get_camM_T_camXs(origin_T_camXs, ind=0):
    B, S = list(origin_T_camXs.shape)[0:2]
    camM_T_camXs = torch.zeros_like(origin_T_camXs)
    for b in list(range(B)):
        camM_T_origin = safe_inverse_single(origin_T_camXs[b, ind])
        for s in list(range(S)):
            camM_T_camXs[b, s] = torch.matmul(camM_T_origin, origin_T_camXs[b, s])
    return camM_T_camXs

def apply_4x4(RT, xyz):
    B, N, _ = list(xyz.shape)
    ones = torch.ones_like(xyz[:, :, 0:1])
    xyz1 = torch.cat([xyz, ones], 2)
    xyz1_t = torch.transpose(xyz1, 1, 2)
    # this is B x 4 x N
    xyz2_t = torch.matmul(RT, xyz1_t)
    xyz2 = torch.transpose(xyz2_t, 1, 2)
    xyz2 = xyz2[:, :, :3]
    return xyz2

def apply_4x4_py(RT, xyz):
    # print('RT', RT.shape)
    B, N, _ = list(xyz.shape)
    ones = np.ones_like(xyz[:, :, 0:1])
    xyz1 = np.concatenate([xyz, ones], 2)
    # print('xyz1', xyz1.shape)
    xyz1_t = xyz1.transpose(0, 2, 1)
    # print('xyz1_t', xyz1_t.shape)
    # this is B x 4 x N
    xyz2_t = np.matmul(RT, xyz1_t)
    # print('xyz2_t', xyz2_t.shape)
    xyz2 = xyz2_t.transpose(0, 2, 1)
    # print('xyz2', xyz2.shape)
    xyz2 = xyz2[:, :, :3]
    return xyz2

def apply_3x3(RT, xy):
    B, N, _ = list(xy.shape)
    ones = torch.ones_like(xy[:, :, 0:1])
    xy1 = torch.cat([xy, ones], 2)
    xy1_t = torch.transpose(xy1, 1, 2)
    # this is B x 3 x N
    xy2_t = torch.matmul(RT, xy1_t)
    xy2 = torch.transpose(xy2_t, 1, 2)
    xy2 = xy2[:, :, :2]
    return xy2

def generate_polygon(ctr_x, ctr_y, avg_r, irregularity, spikiness, num_verts):
    '''
    Start with the center of the polygon at ctr_x, ctr_y,
    then create the polygon by sampling points on a circle around the center.
    Random noise is added by varying the angular spacing between sequential points,
    and by varying the radial distance of each point from the centre.

    Params:
        ctr_x, ctr_y - coordinates of the "centre" of the polygon
        avg_r - in px, the average radius of this polygon; this roughly controls how large the polygon is, really only useful for order of magnitude.
        irregularity - [0,1] indicating how much variance there is in the angular spacing of vertices. [0,1] will map to [0, 2pi/num_verts]
        spikiness - [0,1] indicating how much variance there is in each vertex from the circle of radius avg_r. [0,1] will map to [0, avg_r]
        num_verts - number of vertices

    Returns:
        np.array [num_verts, 2] - CCW order.
    '''
    # spikiness
    spikiness = np.clip(spikiness, 0, 1) * avg_r

    # generate n angle steps
    irregularity = np.clip(irregularity, 0, 1) * 2 * np.pi / num_verts
    lower = (2*np.pi / num_verts) - irregularity
    upper = (2*np.pi / num_verts) + irregularity

    # angle steps
    angle_steps = np.random.uniform(lower, upper, num_verts)
    sc = (2 * np.pi) / angle_steps.sum()
    angle_steps *= sc

    # get all radii
    angle = np.random.uniform(0, 2*np.pi)
    radii = np.clip(np.random.normal(avg_r, spikiness, num_verts), 0, 2 * avg_r)

    # compute all points
    points = []
    for i in range(num_verts):
        x = ctr_x + radii[i] * np.cos(angle)
        y = ctr_y + radii[i] * np.sin(angle)
        points.append([x, y])
        angle += angle_steps[i]

    return np.array(points).astype(int)


def get_random_affine_2d(B, rot_min=-5.0, rot_max=5.0, tx_min=-0.1, tx_max=0.1, ty_min=-0.1, ty_max=0.1, sx_min=-0.05, sx_max=0.05, sy_min=-0.05, sy_max=0.05, shx_min=-0.05, shx_max=0.05, shy_min=-0.05, shy_max=0.05):
    '''
    Params:
        rot_min: rotation amount min
        rot_max: rotation amount max

        tx_min: translation x min
        tx_max: translation x max

        ty_min: translation y min
        ty_max: translation y max

        sx_min: scaling x min
        sx_max: scaling x max

        sy_min: scaling y min
        sy_max: scaling y max

        shx_min: shear x min
        shx_max: shear x max

        shy_min: shear y min
        shy_max: shear y max

    Returns:
        transformation matrix: (B, 3, 3)
    '''
    # rotation
    if rot_max - rot_min != 0:
        rot_amount = np.random.uniform(low=rot_min, high=rot_max, size=B)
        rot_amount = np.pi/180.0*rot_amount
    else:
        rot_amount = rot_min
    rotation = np.zeros((B, 3, 3))  # B, 3, 3
    rotation[:, 2, 2] = 1
    rotation[:, 0, 0] = np.cos(rot_amount)
    rotation[:, 0, 1] = -np.sin(rot_amount)
    rotation[:, 1, 0] = np.sin(rot_amount)
    rotation[:, 1, 1] = np.cos(rot_amount)

    # translation
    translation = np.zeros((B, 3, 3))  # B, 3, 3
    translation[:, [0, 1, 2], [0, 1, 2]] = 1
    if (tx_max - tx_min) > 0:
        trans_x = np.random.uniform(low=tx_min, high=tx_max, size=B)
        translation[:, 0, 2] = trans_x
    # else:
    #     translation[:, 0, 2] = tx_max
    if ty_max - ty_min != 0:
        trans_y = np.random.uniform(low=ty_min, high=ty_max, size=B)
        translation[:, 1, 2] = trans_y
    # else:
    #     translation[:, 1, 2] = ty_max

    # scaling
    scaling = np.zeros((B, 3, 3))  # B, 3, 3
    scaling[:, [0, 1, 2], [0, 1, 2]] = 1
    if (sx_max - sx_min) > 0:
        scale_x = 1 + np.random.uniform(low=sx_min, high=sx_max, size=B)
        scaling[:, 0, 0] = scale_x
    # else:
    #     scaling[:, 0, 0] = sx_max
    if (sy_max - sy_min) > 0:
        scale_y = 1 + np.random.uniform(low=sy_min, high=sy_max, size=B)
        scaling[:, 1, 1] = scale_y
    # else:
    #     scaling[:, 1, 1] = sy_max

    # shear
    shear = np.zeros((B, 3, 3))  # B, 3, 3
    shear[:, [0, 1, 2], [0, 1, 2]] = 1
    if (shx_max - shx_min) > 0:
        shear_x = np.random.uniform(low=shx_min, high=shx_max, size=B)
        shear[:, 0, 1] = shear_x
    # else:
    #     shear[:, 0, 1] = shx_max
    if (shy_max - shy_min) > 0:
        shear_y = np.random.uniform(low=shy_min, high=shy_max, size=B)
        shear[:, 1, 0] = shear_y
    # else:
    #     shear[:, 1, 0] = shy_max

    # compose all those
    rt = np.einsum("ijk,ikl->ijl", rotation, translation)
    ss = np.einsum("ijk,ikl->ijl", scaling, shear)
    trans = np.einsum("ijk,ikl->ijl", rt, ss)

    return trans

def get_centroid_from_box2d(box2d):
    ymin = box2d[:, 0]
    xmin = box2d[:, 1]
    ymax = box2d[:, 2]
    xmax = box2d[:, 3]
    x = (xmin+xmax)/2.0
    y = (ymin+ymax)/2.0
    return y, x

def normalize_boxlist2d(boxlist2d, H, W):
    boxlist2d = boxlist2d.clone()
    ymin, xmin, ymax, xmax = torch.unbind(boxlist2d, dim=2)
    ymin = ymin / float(H)
    ymax = ymax / float(H)
    xmin = xmin / float(W)
    xmax = xmax / float(W)
    boxlist2d = torch.stack([ymin, xmin, ymax, xmax], dim=2)
    return boxlist2d

def unnormalize_boxlist2d(boxlist2d, H, W):
    boxlist2d = boxlist2d.clone()
    ymin, xmin, ymax, xmax = torch.unbind(boxlist2d, dim=2)
    ymin = ymin * float(H)
    ymax = ymax * float(H)
    xmin = xmin * float(W)
    xmax = xmax * float(W)
    boxlist2d = torch.stack([ymin, xmin, ymax, xmax], dim=2)
    return boxlist2d

def unnormalize_box2d(box2d, H, W):
    return unnormalize_boxlist2d(box2d.unsqueeze(1), H, W).squeeze(1)

def normalize_box2d(box2d, H, W):
    return normalize_boxlist2d(box2d.unsqueeze(1), H, W).squeeze(1)

def get_size_from_box2d(box2d):
    ymin = box2d[:, 0]
    xmin = box2d[:, 1]
    ymax = box2d[:, 2]
    xmax = box2d[:, 3]
    height = ymax-ymin
    width = xmax-xmin
    return height, width

def crop_and_resize(im, boxlist, PH, PW, boxlist_is_normalized=False):
    B, C, H, W = im.shape
    B2, N, D = boxlist.shape
    assert(B == B2)
    assert(D == 4)
    # PH, PW is the size to resize to

    # output is B,N,C,PH,PW

    # pt wants xy xy, unnormalized
    if boxlist_is_normalized:
        boxlist_unnorm = unnormalize_boxlist2d(boxlist, H, W)
    else:
        boxlist_unnorm = boxlist

    ymin, xmin, ymax, xmax = boxlist_unnorm.unbind(2)
    # boxlist_pt = torch.stack([boxlist_unnorm[:,1], boxlist_unnorm[:,0], boxlist_unnorm[:,3], boxlist_unnorm[:,2]], dim=1)
    boxlist_pt = torch.stack([xmin, ymin, xmax, ymax], dim=2)
    # we want a B-len list of K x 4 arrays

    # print('im', im.shape)
    # print('boxlist', boxlist.shape)
    # print('boxlist_pt', boxlist_pt.shape)

    # boxlist_pt = list(boxlist_pt.unbind(0))

    crops = []
    for b in range(B):
        crops_b = ops.roi_align(im[b:b+1], [boxlist_pt[b]], output_size=(PH, PW))
        crops.append(crops_b)
    # # crops = im

    # print('crops', crops.shape)
    # crops = crops.reshape(B,N,C,PH,PW)

    # crops = []
    # for b in range(B):
    #     crop_b = ops.roi_align(im[b:b+1], [boxlist_pt[b]], output_size=(PH, PW))
    #     print('crop_b', crop_b.shape)
    #     crops.append(crop_b)
    crops = torch.stack(crops, dim=0)

    # print('crops', crops.shape)
    # boxlist_list = boxlist_pt.unbind(0)
    # print('rgb_crop', rgb_crop.shape)

    return crops


# def get_boxlist_from_centroid_and_size(cy, cx, h, w, clip=True):
#     # cy,cx are both B,N
#     ymin = cy - h/2
#     ymax = cy + h/2
#     xmin = cx - w/2
#     xmax = cx + w/2

#     box = torch.stack([ymin, xmin, ymax, xmax], dim=-1)
#     if clip:
#         box = torch.clamp(box, 0, 1)
#     return box


def get_boxlist_from_centroid_and_size(cy, cx, h, w):  # , clip=False):
    # cy,cx are the same shape
    ymin = cy - h/2
    ymax = cy + h/2
    xmin = cx - w/2
    xmax = cx + w/2

    # if clip:
    #     ymin = torch.clamp(ymin, 0, H-1)
    #     ymax = torch.clamp(ymax, 0, H-1)
    #     xmin = torch.clamp(xmin, 0, W-1)
    #     xmax = torch.clamp(xmax, 0, W-1)

    box = torch.stack([ymin, xmin, ymax, xmax], dim=-1)
    return box


def get_box2d_from_mask(mask, normalize=False):
    # mask is B, 1, H, W

    B, C, H, W = mask.shape
    assert(C == 1)
    xy = utils.basic.gridcloud2d(B, H, W, norm=False, device=mask.device)  # B, H*W, 2

    box = torch.zeros((B, 4), dtype=torch.float32, device=mask.device)
    for b in range(B):
        xy_b = xy[b]  # H*W, 2
        mask_b = mask[b].reshape(H*W)
        xy_ = xy_b[mask_b > 0]
        x_ = xy_[:, 0]
        y_ = xy_[:, 1]
        ymin = torch.min(y_)
        ymax = torch.max(y_)
        xmin = torch.min(x_)
        xmax = torch.max(x_)
        box[b] = torch.stack([ymin, xmin, ymax, xmax], dim=0)
    if normalize:
        box = normalize_boxlist2d(box.unsqueeze(1), H, W).squeeze(1)
    return box

def convert_box2d_to_intrinsics(box2d, pix_T_cam, H, W, use_image_aspect_ratio=True, mult_padding=1.0):
    # box2d is B x 4, with ymin, xmin, ymax, xmax in normalized coords
    # ymin, xmin, ymax, xmax = torch.unbind(box2d, dim=1)
    # H, W is the original size of the image
    # mult_padding is relative to object size in pixels

    # i assume we're rendering an image the same size as the original (H, W)

    # NOTE: get_box2d_from_centroid_and_size, pack_intrinsics, and scale_intrinsics
    # are referenced below but not defined in this file.

    if not mult_padding == 1.0:
        y, x = get_centroid_from_box2d(box2d)
        h, w = get_size_from_box2d(box2d)
        box2d = get_box2d_from_centroid_and_size(
            y, x, h*mult_padding, w*mult_padding, clip=False)

    if use_image_aspect_ratio:
        h, w = get_size_from_box2d(box2d)
        y, x = get_centroid_from_box2d(box2d)

        # note h,w are relative right now
        # we need to undo this, to see the real ratio

        h = h*float(H)
        w = w*float(W)
        box_ratio = h/w
        im_ratio = H/float(W)

        # print('box_ratio:', box_ratio)
        # print('im_ratio:', im_ratio)

        if box_ratio >= im_ratio:
            w = h/im_ratio
            # print('setting w:', h/im_ratio)
        else:
            h = w*im_ratio
            # print('setting h:', w*im_ratio)

        box2d = get_box2d_from_centroid_and_size(
            y, x, h/float(H), w/float(W), clip=False)

    assert(h > 1e-4)
    assert(w > 1e-4)

    ymin, xmin, ymax, xmax = torch.unbind(box2d, dim=1)

    fx, fy, x0, y0 = split_intrinsics(pix_T_cam)

    # the topleft of the new image will now have a different offset from the center of projection

    new_x0 = x0 - xmin*W
    new_y0 = y0 - ymin*H

    pix_T_cam = pack_intrinsics(fx, fy, new_x0, new_y0)
    # this alone will give me an image in original resolution,
    # with its topleft at the box corner

    box_h, box_w = get_size_from_box2d(box2d)
    # these are normalized, and shaped B. (e.g., [0.4], [0.3])

    # we are going to scale the image by the inverse of this,
    # since we are zooming into this area

    sy = 1./box_h
    sx = 1./box_w

    pix_T_cam = scale_intrinsics(pix_T_cam, sx, sy)
    return pix_T_cam, box2d

def pixels2camera(x, y, z, fx, fy, x0, y0):
    # x and y are locations in pixel coordinates, z is a depth in meters
    # they can be images or pointclouds
    # fx, fy, x0, y0 are camera intrinsics
    # returns xyz, sized B x N x 3

    B = x.shape[0]

    fx = torch.reshape(fx, [B, 1])
    fy = torch.reshape(fy, [B, 1])
    x0 = torch.reshape(x0, [B, 1])
    y0 = torch.reshape(y0, [B, 1])

    x = torch.reshape(x, [B, -1])
    y = torch.reshape(y, [B, -1])
    z = torch.reshape(z, [B, -1])

    # unproject
    x = (z/fx)*(x-x0)
    y = (z/fy)*(y-y0)

    xyz = torch.stack([x, y, z], dim=2)
    # B x N x 3
    return xyz

def camera2pixels(xyz, pix_T_cam):
    # xyz is shaped B x H*W x 3
    # returns xy, shaped B x H*W x 2

    fx, fy, x0, y0 = split_intrinsics(pix_T_cam)
    x, y, z = torch.unbind(xyz, dim=-1)
    B = list(z.shape)[0]

    fx = torch.reshape(fx, [B, 1])
    fy = torch.reshape(fy, [B, 1])
    x0 = torch.reshape(x0, [B, 1])
    y0 = torch.reshape(y0, [B, 1])
    x = torch.reshape(x, [B, -1])
    y = torch.reshape(y, [B, -1])
    z = torch.reshape(z, [B, -1])

    EPS = 1e-4
    z = torch.clamp(z, min=EPS)
    x = (x*fx)/z + x0
    y = (y*fy)/z + y0
    xy = torch.stack([x, y], dim=-1)
    return xy

def depth2pointcloud(z, pix_T_cam):
    B, C, H, W = list(z.shape)
    device = z.device
    y, x = utils.basic.meshgrid2d(B, H, W, device=device)
    z = torch.reshape(z, [B, H, W])
    fx, fy, x0, y0 = split_intrinsics(pix_T_cam)
    xyz = pixels2camera(x, y, z, fx, fy, x0, y0)
    return xyz
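A small sketch (not part of the file) of the pinhole conventions above: camera2pixels projects with x' = fx*x/z + x0, and pixels2camera inverts that, so projecting an unprojected constant-depth map returns the original pixel grid up to float error.

import torch
from models.spatracker.utils.geom import depth2pointcloud, camera2pixels

B, H, W = 1, 4, 5
pix_T_cam = torch.eye(4).unsqueeze(0)   # B x 4 x 4 intrinsics
pix_T_cam[:, 0, 0] = 100.0              # fx
pix_T_cam[:, 1, 1] = 100.0              # fy
pix_T_cam[:, 0, 2] = W / 2.0            # x0
pix_T_cam[:, 1, 2] = H / 2.0            # y0

z = 2.0 * torch.ones(B, 1, H, W)        # constant 2 m depth
xyz = depth2pointcloud(z, pix_T_cam)    # B x H*W x 3 camera-space points
xy = camera2pixels(xyz, pix_T_cam)      # back to pixel coordinates
print(xy[0, 0], xy[0, -1])              # ~(0, 0) and ~(W-1, H-1)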
models/spatracker/utils/improc.py
ADDED
@@ -0,0 +1,1447 @@
1 |
+
import torch
|
2 |
+
import numpy as np
|
3 |
+
import models.spatracker.utils.basic
from models.spatracker import utils  # bind `utils` locally so the utils.basic.* calls below resolve
|
4 |
+
from sklearn.decomposition import PCA
|
5 |
+
from matplotlib import cm
|
6 |
+
import matplotlib.pyplot as plt
|
7 |
+
import cv2
|
8 |
+
import torch.nn.functional as F
|
9 |
+
import torchvision
|
10 |
+
EPS = 1e-6
|
11 |
+
|
12 |
+
from skimage.color import (
|
13 |
+
rgb2lab, rgb2yuv, rgb2ycbcr, lab2rgb, yuv2rgb, ycbcr2rgb,
|
14 |
+
rgb2hsv, hsv2rgb, rgb2xyz, xyz2rgb, rgb2hed, hed2rgb)
|
15 |
+
|
16 |
+
def _convert(input_, type_):
|
17 |
+
return {
|
18 |
+
'float': input_.float(),
|
19 |
+
'double': input_.double(),
|
20 |
+
}.get(type_, input_)
|
21 |
+
|
22 |
+
def _generic_transform_sk_3d(transform, in_type='', out_type=''):
|
23 |
+
def apply_transform_individual(input_):
|
24 |
+
device = input_.device
|
25 |
+
input_ = input_.cpu()
|
26 |
+
input_ = _convert(input_, in_type)
|
27 |
+
|
28 |
+
input_ = input_.permute(1, 2, 0).detach().numpy()
|
29 |
+
transformed = transform(input_)
|
30 |
+
output = torch.from_numpy(transformed).float().permute(2, 0, 1)
|
31 |
+
output = _convert(output, out_type)
|
32 |
+
return output.to(device)
|
33 |
+
|
34 |
+
def apply_transform(input_):
|
35 |
+
to_stack = []
|
36 |
+
for image in input_:
|
37 |
+
to_stack.append(apply_transform_individual(image))
|
38 |
+
return torch.stack(to_stack)
|
39 |
+
return apply_transform
|
40 |
+
|
41 |
+
hsv_to_rgb = _generic_transform_sk_3d(hsv2rgb)
|
42 |
+
|
43 |
+
def preprocess_color_tf(x):
|
44 |
+
import tensorflow as tf
|
45 |
+
return tf.cast(x,tf.float32) * 1./255 - 0.5
|
46 |
+
|
47 |
+
def preprocess_color(x):
|
48 |
+
if isinstance(x, np.ndarray):
|
49 |
+
return x.astype(np.float32) * 1./255 - 0.5
|
50 |
+
else:
|
51 |
+
return x.float() * 1./255 - 0.5
|
52 |
+
|
53 |
+
def pca_embed(emb, keep, valid=None):
|
54 |
+
## emb -- [S,H/2,W/2,C]
|
55 |
+
## keep is the number of principal components to keep
|
56 |
+
## Helper function for reduce_emb.
|
57 |
+
emb = emb + EPS
|
58 |
+
#emb is B x C x H x W
|
59 |
+
emb = emb.permute(0, 2, 3, 1).cpu().detach().numpy() #this is B x H x W x C
|
60 |
+
|
61 |
+
if valid is not None:
|
62 |
+
valid = valid.cpu().detach().numpy().reshape(-1)  # flatten; H and W are only defined further down
|
63 |
+
|
64 |
+
emb_reduced = list()
|
65 |
+
|
66 |
+
B, H, W, C = np.shape(emb)
|
67 |
+
for img in emb:
|
68 |
+
if np.isnan(img).any():
|
69 |
+
emb_reduced.append(np.zeros([H, W, keep]))
|
70 |
+
continue
|
71 |
+
|
72 |
+
pixels_kd = np.reshape(img, (H*W, C))
|
73 |
+
|
74 |
+
if valid is not None:
|
75 |
+
pixels_kd_pca = pixels_kd[valid > 0]  # fit the PCA on valid pixels only
|
76 |
+
else:
|
77 |
+
pixels_kd_pca = pixels_kd
|
78 |
+
|
79 |
+
P = PCA(keep)
|
80 |
+
P.fit(pixels_kd_pca)
|
81 |
+
|
82 |
+
if valid is not None:
|
83 |
+
pixels3d = P.transform(pixels_kd) * valid[:, None]  # zero out invalid pixels in every component
|
84 |
+
else:
|
85 |
+
pixels3d = P.transform(pixels_kd)
|
86 |
+
|
87 |
+
out_img = np.reshape(pixels3d, [H,W,keep]).astype(np.float32)
|
88 |
+
if np.isnan(out_img).any():
|
89 |
+
emb_reduced.append(np.zeros([H, W, keep]))
|
90 |
+
continue
|
91 |
+
|
92 |
+
emb_reduced.append(out_img)
|
93 |
+
|
94 |
+
emb_reduced = np.stack(emb_reduced, axis=0).astype(np.float32)
|
95 |
+
|
96 |
+
return torch.from_numpy(emb_reduced).permute(0, 3, 1, 2)
|
97 |
+
|
98 |
+
def pca_embed_together(emb, keep):
|
99 |
+
## emb -- [S,H/2,W/2,C]
|
100 |
+
## keep is the number of principal components to keep
|
101 |
+
## Helper function for reduce_emb.
|
102 |
+
emb = emb + EPS
|
103 |
+
#emb is B x C x H x W
|
104 |
+
emb = emb.permute(0, 2, 3, 1).cpu().detach().numpy() #this is B x H x W x C
|
105 |
+
|
106 |
+
B, H, W, C = np.shape(emb)
|
107 |
+
if np.isnan(emb).any():
|
108 |
+
return torch.zeros(B, keep, H, W)
|
109 |
+
|
110 |
+
pixelskd = np.reshape(emb, (B*H*W, C))
|
111 |
+
P = PCA(keep)
|
112 |
+
P.fit(pixelskd)
|
113 |
+
pixels3d = P.transform(pixelskd)
|
114 |
+
out_img = np.reshape(pixels3d, [B,H,W,keep]).astype(np.float32)
|
115 |
+
|
116 |
+
if np.isnan(out_img).any():
|
117 |
+
return torch.zeros(B, keep, H, W)
|
118 |
+
|
119 |
+
return torch.from_numpy(out_img).permute(0, 3, 1, 2)
|
120 |
+
|
121 |
+
def reduce_emb(emb, valid=None, inbound=None, together=False):
|
122 |
+
## emb -- [S,C,H/2,W/2], inbound -- [S,1,H/2,W/2]
|
123 |
+
## Reduce number of chans to 3 with PCA. For vis.
|
124 |
+
# S,H,W,C = emb.shape.as_list()
|
125 |
+
S, C, H, W = list(emb.size())
|
126 |
+
keep = 3
|
127 |
+
|
128 |
+
if together:
|
129 |
+
reduced_emb = pca_embed_together(emb, keep)
|
130 |
+
else:
|
131 |
+
reduced_emb = pca_embed(emb, keep, valid) #not im
|
132 |
+
|
133 |
+
reduced_emb = utils.basic.normalize(reduced_emb) - 0.5
|
134 |
+
if inbound is not None:
|
135 |
+
emb_inbound = emb*inbound
|
136 |
+
else:
|
137 |
+
emb_inbound = None
|
138 |
+
|
139 |
+
return reduced_emb, emb_inbound
|
140 |
+
|
141 |
+
def get_feat_pca(feat, valid=None):
|
142 |
+
B, C, D, W = list(feat.size())
|
143 |
+
# feat is B x C x D x W. If 3D input, average it through Height dimension before passing into this function.
|
144 |
+
|
145 |
+
pca, _ = reduce_emb(feat, valid=valid,inbound=None, together=True)
|
146 |
+
# pca is B x 3 x W x D
|
147 |
+
return pca
|
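A small sketch (not part of the committed file) of how get_feat_pca can collapse a many-channel feature map to 3 channels for visualization. The import path is assumed from this commit's layout, and utils.basic.normalize is assumed to behave as referenced above.
# --- usage sketch (illustrative, not part of this commit) ---
import torch
from models.spatracker.utils.improc import get_feat_pca, back2color

feat = torch.randn(2, 64, 32, 32)      # B x C x H x W feature map
feat_pca = get_feat_pca(feat)          # B x 3 x H x W, roughly in [-0.5, 0.5]
feat_img = back2color(feat_pca)        # uint8 image, ready to log
print(feat_img.shape, feat_img.dtype)  # torch.Size([2, 3, 32, 32]) torch.uint8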
148 |
+
|
149 |
+
def gif_and_tile(ims, just_gif=False):
|
150 |
+
S = len(ims)
|
151 |
+
# each im is B x H x W x C
|
152 |
+
# i want a gif in the left, and the tiled frames on the right
|
153 |
+
# for the gif tool, this means making a B x S x H x W tensor
|
154 |
+
# where the leftmost part is sequential and the rest is tiled
|
155 |
+
gif = torch.stack(ims, dim=1)
|
156 |
+
if just_gif:
|
157 |
+
return gif
|
158 |
+
til = torch.cat(ims, dim=2)
|
159 |
+
til = til.unsqueeze(dim=1).repeat(1, S, 1, 1, 1)
|
160 |
+
im = torch.cat([gif, til], dim=3)
|
161 |
+
return im
|
162 |
+
|
163 |
+
def back2color(i, blacken_zeros=False):
|
164 |
+
if blacken_zeros:
|
165 |
+
const = torch.tensor([-0.5])
|
166 |
+
i = torch.where(i==0.0, const.cuda() if i.is_cuda else const, i)
|
167 |
+
return back2color(i)
|
168 |
+
else:
|
169 |
+
return ((i+0.5)*255).type(torch.ByteTensor)
|
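preprocess_color and back2color define the color convention used throughout this file: uint8 [0, 255] on one side, float [-0.5, 0.5] on the other. A quick round-trip sketch (not part of the committed file), with the import path assumed from this commit's layout.
# --- usage sketch (illustrative, not part of this commit) ---
import torch
from models.spatracker.utils.improc import preprocess_color, back2color

img_u8 = torch.randint(0, 256, (1, 3, 8, 8), dtype=torch.uint8)
img_f = preprocess_color(img_u8)   # float32 in [-0.5, 0.5]
img_back = back2color(img_f)       # uint8 again (off by at most 1 from float rounding)
assert (img_back.int() - img_u8.int()).abs().max() <= 1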
170 |
+
|
171 |
+
def convert_occ_to_height(occ, reduce_axis=3):
|
172 |
+
B, C, D, H, W = list(occ.shape)
|
173 |
+
assert(C==1)
|
174 |
+
# note that height increases DOWNWARD in the tensor
|
175 |
+
# (like pixel/camera coordinates)
|
176 |
+
|
177 |
+
G = list(occ.shape)[reduce_axis]
|
178 |
+
values = torch.linspace(float(G), 1.0, steps=G, dtype=torch.float32, device=occ.device)
|
179 |
+
if reduce_axis==2:
|
180 |
+
# fro view
|
181 |
+
values = values.view(1, 1, G, 1, 1)
|
182 |
+
elif reduce_axis==3:
|
183 |
+
# top view
|
184 |
+
values = values.view(1, 1, 1, G, 1)
|
185 |
+
elif reduce_axis==4:
|
186 |
+
# lateral view
|
187 |
+
values = values.view(1, 1, 1, 1, G)
|
188 |
+
else:
|
189 |
+
assert(False) # you have to reduce one of the spatial dims (2-4)
|
190 |
+
values = torch.max(occ*values, dim=reduce_axis)[0]/float(G)
|
191 |
+
# values = values.view([B, C, D, W])
|
192 |
+
return values
|
193 |
+
|
194 |
+
def xy2heatmap(xy, sigma, grid_xs, grid_ys, norm=False):
|
195 |
+
# xy is B x N x 2, containing float x and y coordinates of N things
|
196 |
+
# grid_xs and grid_ys are B x N x Y x X
|
197 |
+
|
198 |
+
B, N, Y, X = list(grid_xs.shape)
|
199 |
+
|
200 |
+
mu_x = xy[:,:,0].clone()
|
201 |
+
mu_y = xy[:,:,1].clone()
|
202 |
+
|
203 |
+
x_valid = (mu_x>-0.5) & (mu_x<float(X+0.5))
|
204 |
+
y_valid = (mu_y>-0.5) & (mu_y<float(Y+0.5))
|
205 |
+
not_valid = ~(x_valid & y_valid)
|
206 |
+
|
207 |
+
mu_x[not_valid] = -10000
|
208 |
+
mu_y[not_valid] = -10000
|
209 |
+
|
210 |
+
mu_x = mu_x.reshape(B, N, 1, 1).repeat(1, 1, Y, X)
|
211 |
+
mu_y = mu_y.reshape(B, N, 1, 1).repeat(1, 1, Y, X)
|
212 |
+
|
213 |
+
sigma_sq = sigma*sigma
|
214 |
+
# sigma_sq = (sigma*sigma).reshape(B, N, 1, 1)
|
215 |
+
sq_diff_x = (grid_xs - mu_x)**2
|
216 |
+
sq_diff_y = (grid_ys - mu_y)**2
|
217 |
+
|
218 |
+
term1 = 1./2.*np.pi*sigma_sq
|
219 |
+
term2 = torch.exp(-(sq_diff_x+sq_diff_y)/(2.*sigma_sq))
|
220 |
+
gauss = term1*term2
|
221 |
+
|
222 |
+
if norm:
|
223 |
+
# normalize so each gaussian peaks at 1
|
224 |
+
gauss_ = gauss.reshape(B*N, Y, X)
|
225 |
+
gauss_ = utils.basic.normalize(gauss_)
|
226 |
+
gauss = gauss_.reshape(B, N, Y, X)
|
227 |
+
|
228 |
+
return gauss
|
229 |
+
|
230 |
+
def xy2heatmaps(xy, Y, X, sigma=30.0, norm=True):
|
231 |
+
# xy is B x N x 2
|
232 |
+
|
233 |
+
B, N, D = list(xy.shape)
|
234 |
+
assert(D==2)
|
235 |
+
|
236 |
+
device = xy.device
|
237 |
+
|
238 |
+
grid_y, grid_x = utils.basic.meshgrid2d(B, Y, X, device=device)
|
239 |
+
# grid_x and grid_y are B x Y x X
|
240 |
+
grid_xs = grid_x.unsqueeze(1).repeat(1, N, 1, 1)
|
241 |
+
grid_ys = grid_y.unsqueeze(1).repeat(1, N, 1, 1)
|
242 |
+
heat = xy2heatmap(xy, sigma, grid_xs, grid_ys, norm=norm)
|
243 |
+
return heat
|
244 |
+
|
245 |
+
def draw_circles_at_xy(xy, Y, X, sigma=12.5, round=False):
|
246 |
+
B, N, D = list(xy.shape)
|
247 |
+
assert(D==2)
|
248 |
+
prior = xy2heatmaps(xy, Y, X, sigma=sigma)
|
249 |
+
# prior is B x N x Y x X
|
250 |
+
if round:
|
251 |
+
prior = (prior > 0.5).float()
|
252 |
+
return prior
|
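A sketch (not part of the committed file) of rendering Gaussian heatmaps at a few pixel locations with xy2heatmaps. With the default norm=True each map is rescaled to peak near 1, so the constant prefactor in xy2heatmap cancels out; utils.basic.meshgrid2d and normalize are assumed to behave as referenced above.
# --- usage sketch (illustrative, not part of this commit) ---
import torch
from models.spatracker.utils.improc import xy2heatmaps

xy = torch.tensor([[[16.0, 24.0], [40.0, 8.0]]])   # B=1, N=2 points given as (x, y) pixels
heat = xy2heatmaps(xy, Y=64, X=64, sigma=5.0)      # 1 x 2 x 64 x 64; norm=True -> peaks near 1
print(heat.shape, heat.max().item())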
253 |
+
|
254 |
+
def seq2color(im, norm=True, colormap='coolwarm'):
|
255 |
+
B, S, H, W = list(im.shape)
|
256 |
+
# S is sequential
|
257 |
+
|
258 |
+
# prep a mask of the valid pixels, so we can blacken the invalids later
|
259 |
+
mask = torch.max(im, dim=1, keepdim=True)[0]
|
260 |
+
|
261 |
+
# turn the S dim into an explicit sequence
|
262 |
+
coeffs = np.linspace(1.0, float(S), S).astype(np.float32)/float(S)
|
263 |
+
|
264 |
+
# # increase the spacing from the center
|
265 |
+
# coeffs[:int(S/2)] -= 2.0
|
266 |
+
# coeffs[int(S/2)+1:] += 2.0
|
267 |
+
|
268 |
+
coeffs = torch.from_numpy(coeffs).float().cuda()
|
269 |
+
coeffs = coeffs.reshape(1, S, 1, 1).repeat(B, 1, H, W)
|
270 |
+
# scale each channel by the right coeff
|
271 |
+
im = im * coeffs
|
272 |
+
# now im is in [1/S, 1], except for the invalid parts which are 0
|
273 |
+
# keep the highest valid coeff at each pixel
|
274 |
+
im = torch.max(im, dim=1, keepdim=True)[0]
|
275 |
+
|
276 |
+
out = []
|
277 |
+
for b in range(B):
|
278 |
+
im_ = im[b]
|
279 |
+
# move channels out to last dim_
|
280 |
+
im_ = im_.detach().cpu().numpy()
|
281 |
+
im_ = np.squeeze(im_)
|
282 |
+
# im_ is H x W
|
283 |
+
if colormap=='coolwarm':
|
284 |
+
im_ = cm.coolwarm(im_)[:, :, :3]
|
285 |
+
elif colormap=='PiYG':
|
286 |
+
im_ = cm.PiYG(im_)[:, :, :3]
|
287 |
+
elif colormap=='winter':
|
288 |
+
im_ = cm.winter(im_)[:, :, :3]
|
289 |
+
elif colormap=='spring':
|
290 |
+
im_ = cm.spring(im_)[:, :, :3]
|
291 |
+
elif colormap=='onediff':
|
292 |
+
im_ = np.reshape(im_, (-1))
|
293 |
+
im0_ = cm.spring(im_)[:, :3]
|
294 |
+
im1_ = cm.winter(im_)[:, :3]
|
295 |
+
im1_[im_==1/float(S)] = im0_[im_==1/float(S)]
|
296 |
+
im_ = np.reshape(im1_, (H, W, 3))
|
297 |
+
else:
|
298 |
+
assert(False) # invalid colormap
|
299 |
+
# move channels into dim 0
|
300 |
+
im_ = np.transpose(im_, [2, 0, 1])
|
301 |
+
im_ = torch.from_numpy(im_).float().cuda()
|
302 |
+
out.append(im_)
|
303 |
+
out = torch.stack(out, dim=0)
|
304 |
+
|
305 |
+
# blacken the invalid pixels, instead of using the 0-color
|
306 |
+
out = out*mask
|
307 |
+
# out = out*255.0
|
308 |
+
|
309 |
+
# put it in [-0.5, 0.5]
|
310 |
+
out = out - 0.5
|
311 |
+
|
312 |
+
return out
|
313 |
+
|
314 |
+
def colorize(d):
|
315 |
+
# this is actually just grayscale right now
|
316 |
+
|
317 |
+
if d.ndim==2:
|
318 |
+
d = d.unsqueeze(dim=0)
|
319 |
+
else:
|
320 |
+
assert(d.ndim==3)
|
321 |
+
|
322 |
+
# color_map = cm.get_cmap('plasma')
|
323 |
+
color_map = cm.get_cmap('inferno')
|
324 |
+
# S1, D = traj.shape
|
325 |
+
|
326 |
+
# print('d1', d.shape)
|
327 |
+
C,H,W = d.shape
|
328 |
+
assert(C==1)
|
329 |
+
d = d.reshape(-1)
|
330 |
+
d = d.detach().cpu().numpy()
|
331 |
+
# print('d2', d.shape)
|
332 |
+
color = np.array(color_map(d)) * 255 # rgba
|
333 |
+
# print('color1', color.shape)
|
334 |
+
color = np.reshape(color[:,:3], [H*W, 3])
|
335 |
+
# print('color2', color.shape)
|
336 |
+
color = torch.from_numpy(color).permute(1,0).reshape(3,H,W)
|
337 |
+
# # gather
|
338 |
+
# cm = matplotlib.cm.get_cmap(cmap if cmap is not None else 'gray')
|
339 |
+
# if cmap=='RdBu' or cmap=='RdYlGn':
|
340 |
+
# colors = cm(np.arange(256))[:, :3]
|
341 |
+
# else:
|
342 |
+
# colors = cm.colors
|
343 |
+
# colors = np.array(colors).astype(np.float32)
|
344 |
+
# colors = np.reshape(colors, [-1, 3])
|
345 |
+
# colors = tf.constant(colors, dtype=tf.float32)
|
346 |
+
|
347 |
+
# value = tf.gather(colors, indices)
|
348 |
+
# colorize(value, normalize=True, vmin=None, vmax=None, cmap=None, vals=255)
|
349 |
+
|
350 |
+
# copy to the three chans
|
351 |
+
# d = d.repeat(3, 1, 1)
|
352 |
+
return color
|
353 |
+
|
354 |
+
|
355 |
+
def oned2inferno(d, norm=True, do_colorize=False):
|
356 |
+
# convert a 1chan input to a 3chan image output
|
357 |
+
|
358 |
+
# if it's just B x H x W, add a C dim
|
359 |
+
if d.ndim==3:
|
360 |
+
d = d.unsqueeze(dim=1)
|
361 |
+
# d should be B x C x H x W, where C=1
|
362 |
+
B, C, H, W = list(d.shape)
|
363 |
+
assert(C==1)
|
364 |
+
|
365 |
+
if norm:
|
366 |
+
d = utils.basic.normalize(d)
|
367 |
+
|
368 |
+
if do_colorize:
|
369 |
+
rgb = torch.zeros(B, 3, H, W)
|
370 |
+
for b in list(range(B)):
|
371 |
+
rgb[b] = colorize(d[b])
|
372 |
+
else:
|
373 |
+
rgb = d.repeat(1, 3, 1, 1)*255.0
|
374 |
+
# rgb = (255.0*rgb).type(torch.ByteTensor)
|
375 |
+
rgb = rgb.type(torch.ByteTensor)
|
376 |
+
|
377 |
+
# rgb = tf.cast(255.0*rgb, tf.uint8)
|
378 |
+
# rgb = tf.reshape(rgb, [-1, hyp.H, hyp.W, 3])
|
379 |
+
# rgb = tf.expand_dims(rgb, axis=0)
|
380 |
+
return rgb
|
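oned2inferno is the generic single-channel-to-image helper (depth maps, heatmaps, masks). A sketch (not part of the committed file) of both the grayscale and the colorized path, import path assumed from this commit's layout.
# --- usage sketch (illustrative, not part of this commit) ---
import torch
from models.spatracker.utils.improc import oned2inferno

depth = torch.rand(1, 1, 64, 64)                              # B x 1 x H x W
vis_gray = oned2inferno(depth, norm=True)                     # B x 3 x H x W uint8 (channel-repeated)
vis_cmap = oned2inferno(depth, norm=True, do_colorize=True)   # inferno colormap via colorize()
print(vis_gray.shape, vis_cmap.dtype)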
381 |
+
|
382 |
+
def oned2gray(d, norm=True):
|
383 |
+
# convert a 1chan input to a 3chan image output
|
384 |
+
|
385 |
+
# if it's just B x H x W, add a C dim
|
386 |
+
if d.ndim==3:
|
387 |
+
d = d.unsqueeze(dim=1)
|
388 |
+
# d should be B x C x H x W, where C=1
|
389 |
+
B, C, H, W = list(d.shape)
|
390 |
+
assert(C==1)
|
391 |
+
|
392 |
+
if norm:
|
393 |
+
d = utils.basic.normalize(d)
|
394 |
+
|
395 |
+
rgb = d.repeat(1,3,1,1)
|
396 |
+
rgb = (255.0*rgb).type(torch.ByteTensor)
|
397 |
+
return rgb
|
398 |
+
|
399 |
+
|
400 |
+
def draw_frame_id_on_vis(vis, frame_id, scale=0.5, left=5, top=20):
|
401 |
+
|
402 |
+
rgb = vis.detach().cpu().numpy()[0]
|
403 |
+
rgb = np.transpose(rgb, [1, 2, 0]) # put channels last
|
404 |
+
rgb = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)
|
405 |
+
color = (255, 255, 255)
|
406 |
+
# print('putting frame id', frame_id)
|
407 |
+
|
408 |
+
frame_str = utils.basic.strnum(frame_id)
|
409 |
+
|
410 |
+
text_color_bg = (0,0,0)
|
411 |
+
font = cv2.FONT_HERSHEY_SIMPLEX
|
412 |
+
text_size, _ = cv2.getTextSize(frame_str, font, scale, 1)
|
413 |
+
text_w, text_h = text_size
|
414 |
+
cv2.rectangle(rgb, (left, top-text_h), (left + text_w, top+1), text_color_bg, -1)
|
415 |
+
|
416 |
+
cv2.putText(
|
417 |
+
rgb,
|
418 |
+
frame_str,
|
419 |
+
(left, top), # from left, from top
|
420 |
+
font,
|
421 |
+
scale, # font scale (float)
|
422 |
+
color,
|
423 |
+
1) # font thickness (int)
|
424 |
+
rgb = cv2.cvtColor(rgb.astype(np.uint8), cv2.COLOR_BGR2RGB)
|
425 |
+
vis = torch.from_numpy(rgb).permute(2, 0, 1).unsqueeze(0)
|
426 |
+
return vis
|
427 |
+
|
428 |
+
COLORMAP_FILE = "./utils/bremm.png"
|
429 |
+
class ColorMap2d:
|
430 |
+
def __init__(self, filename=None):
|
431 |
+
self._colormap_file = filename or COLORMAP_FILE
|
432 |
+
self._img = plt.imread(self._colormap_file)
|
433 |
+
|
434 |
+
self._height = self._img.shape[0]
|
435 |
+
self._width = self._img.shape[1]
|
436 |
+
|
437 |
+
def __call__(self, X):
|
438 |
+
assert len(X.shape) == 2
|
439 |
+
output = np.zeros((X.shape[0], 3))
|
440 |
+
for i in range(X.shape[0]):
|
441 |
+
x, y = X[i, :]
|
442 |
+
xp = int((self._width-1) * x)
|
443 |
+
yp = int((self._height-1) * y)
|
444 |
+
xp = np.clip(xp, 0, self._width-1)
|
445 |
+
yp = np.clip(yp, 0, self._height-1)
|
446 |
+
output[i, :] = self._img[yp, xp]
|
447 |
+
return output
|
448 |
+
|
449 |
+
def get_n_colors(N, sequential=False):
|
450 |
+
label_colors = []
|
451 |
+
for ii in range(N):
|
452 |
+
if sequential:
|
453 |
+
rgb = cm.winter(ii/(N-1))
|
454 |
+
rgb = (np.array(rgb) * 255).astype(np.uint8)[:3]
|
455 |
+
else:
|
456 |
+
rgb = np.zeros(3)
|
457 |
+
while np.sum(rgb) < 128: # ensure min brightness
|
458 |
+
rgb = np.random.randint(0,256,3)
|
459 |
+
label_colors.append(rgb)
|
460 |
+
return label_colors
|
461 |
+
|
462 |
+
class Summ_writer(object):
|
463 |
+
def __init__(self, writer, global_step, log_freq=10, fps=8, scalar_freq=100, just_gif=False):
|
464 |
+
self.writer = writer
|
465 |
+
self.global_step = global_step
|
466 |
+
self.log_freq = log_freq
|
467 |
+
self.fps = fps
|
468 |
+
self.just_gif = just_gif
|
469 |
+
self.maxwidth = 10000
|
470 |
+
self.save_this = (self.global_step % self.log_freq == 0)
|
471 |
+
self.scalar_freq = max(scalar_freq,1)
|
472 |
+
|
473 |
+
|
474 |
+
def summ_gif(self, name, tensor, blacken_zeros=False):
|
475 |
+
# tensor should be in B x S x C x H x W
|
476 |
+
|
477 |
+
assert tensor.dtype in {torch.uint8,torch.float32}
|
478 |
+
shape = list(tensor.shape)
|
479 |
+
|
480 |
+
if tensor.dtype == torch.float32:
|
481 |
+
tensor = back2color(tensor, blacken_zeros=blacken_zeros)
|
482 |
+
|
483 |
+
video_to_write = tensor[0:1]
|
484 |
+
|
485 |
+
S = video_to_write.shape[1]
|
486 |
+
if S==1:
|
487 |
+
# video_to_write is 1 x 1 x C x H x W
|
488 |
+
self.writer.add_image(name, video_to_write[0,0], global_step=self.global_step)
|
489 |
+
else:
|
490 |
+
self.writer.add_video(name, video_to_write, fps=self.fps, global_step=self.global_step)
|
491 |
+
|
492 |
+
return video_to_write
|
493 |
+
|
494 |
+
def draw_boxlist2d_on_image(self, rgb, boxlist, scores=None, tids=None, linewidth=1):
|
495 |
+
B, C, H, W = list(rgb.shape)
|
496 |
+
assert(C==3)
|
497 |
+
B2, N, D = list(boxlist.shape)
|
498 |
+
assert(B2==B)
|
499 |
+
assert(D==4) # ymin, xmin, ymax, xmax
|
500 |
+
|
501 |
+
rgb = back2color(rgb)
|
502 |
+
if scores is None:
|
503 |
+
scores = torch.ones(B2, N).float()
|
504 |
+
if tids is None:
|
505 |
+
tids = torch.arange(N).reshape(1,N).repeat(B2,1).long()
|
506 |
+
# tids = torch.zeros(B2, N).long()
|
507 |
+
out = self.draw_boxlist2d_on_image_py(
|
508 |
+
rgb[0].cpu().detach().numpy(),
|
509 |
+
boxlist[0].cpu().detach().numpy(),
|
510 |
+
scores[0].cpu().detach().numpy(),
|
511 |
+
tids[0].cpu().detach().numpy(),
|
512 |
+
linewidth=linewidth)
|
513 |
+
out = torch.from_numpy(out).type(torch.ByteTensor).permute(2, 0, 1)
|
514 |
+
out = torch.unsqueeze(out, dim=0)
|
515 |
+
out = preprocess_color(out)
|
516 |
+
out = torch.reshape(out, [1, C, H, W])
|
517 |
+
return out
|
518 |
+
|
519 |
+
def draw_boxlist2d_on_image_py(self, rgb, boxlist, scores, tids, linewidth=1):
|
520 |
+
# all inputs are numpy tensors
|
521 |
+
# rgb is H x W x 3
|
522 |
+
# boxlist is N x 4
|
523 |
+
# scores is N
|
524 |
+
# tids is N
|
525 |
+
|
526 |
+
rgb = np.transpose(rgb, [1, 2, 0]) # put channels last
|
527 |
+
# rgb = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)
|
528 |
+
|
529 |
+
rgb = rgb.astype(np.uint8).copy()
|
530 |
+
|
531 |
+
|
532 |
+
H, W, C = rgb.shape
|
533 |
+
assert(C==3)
|
534 |
+
N, D = boxlist.shape
|
535 |
+
assert(D==4)
|
536 |
+
|
537 |
+
# color_map = cm.get_cmap('tab20')
|
538 |
+
# color_map = cm.get_cmap('set1')
|
539 |
+
color_map = cm.get_cmap('Accent')
|
540 |
+
color_map = color_map.colors
|
541 |
+
# print('color_map', color_map)
|
542 |
+
|
543 |
+
# draw
|
544 |
+
for ind, box in enumerate(boxlist):
|
545 |
+
# box is 4
|
546 |
+
if not np.isclose(scores[ind], 0.0):
|
547 |
+
# box = utils.geom.scale_box2d(box, H, W)
|
548 |
+
ymin, xmin, ymax, xmax = box
|
549 |
+
|
550 |
+
# ymin, ymax = ymin*H, ymax*H
|
551 |
+
# xmin, xmax = xmin*W, xmax*W
|
552 |
+
|
553 |
+
# print 'score = %.2f' % scores[ind]
|
554 |
+
# color_id = tids[ind] % 20
|
555 |
+
color_id = tids[ind]
|
556 |
+
color = color_map[color_id]
|
557 |
+
color = np.array(color)*255.0
|
558 |
+
color = color.round()
|
559 |
+
# color = color.astype(np.uint8)
|
560 |
+
# color = color[::-1]
|
561 |
+
# print('color', color)
|
562 |
+
|
563 |
+
# print 'tid = %d; score = %.3f' % (tids[ind], scores[ind])
|
564 |
+
|
565 |
+
# if False:
|
566 |
+
if scores[ind] < 1.0: # not gt
|
567 |
+
cv2.putText(rgb,
|
568 |
+
# '%d (%.2f)' % (tids[ind], scores[ind]),
|
569 |
+
'%.2f' % (scores[ind]),
|
570 |
+
(int(xmin), int(ymin)),
|
571 |
+
cv2.FONT_HERSHEY_SIMPLEX,
|
572 |
+
0.5, # font size
|
573 |
+
color),
|
574 |
+
#1) # font weight
|
575 |
+
|
576 |
+
xmin = np.clip(int(xmin), 0, W-1)
|
577 |
+
xmax = np.clip(int(xmax), 0, W-1)
|
578 |
+
ymin = np.clip(int(ymin), 0, H-1)
|
579 |
+
ymax = np.clip(int(ymax), 0, H-1)
|
580 |
+
|
581 |
+
cv2.line(rgb, (xmin, ymin), (xmin, ymax), color, linewidth, cv2.LINE_AA)
|
582 |
+
cv2.line(rgb, (xmin, ymin), (xmax, ymin), color, linewidth, cv2.LINE_AA)
|
583 |
+
cv2.line(rgb, (xmax, ymin), (xmax, ymax), color, linewidth, cv2.LINE_AA)
|
584 |
+
cv2.line(rgb, (xmax, ymax), (xmin, ymax), color, linewidth, cv2.LINE_AA)
|
585 |
+
|
586 |
+
# rgb = cv2.cvtColor(rgb.astype(np.uint8), cv2.COLOR_BGR2RGB)
|
587 |
+
return rgb
|
588 |
+
|
589 |
+
def summ_boxlist2d(self, name, rgb, boxlist, scores=None, tids=None, frame_id=None, only_return=False, linewidth=2):
|
590 |
+
B, C, H, W = list(rgb.shape)
|
591 |
+
boxlist_vis = self.draw_boxlist2d_on_image(rgb, boxlist, scores=scores, tids=tids, linewidth=linewidth)
|
592 |
+
return self.summ_rgb(name, boxlist_vis, frame_id=frame_id, only_return=only_return)
|
593 |
+
|
594 |
+
def summ_rgbs(self, name, ims, frame_ids=None, blacken_zeros=False, only_return=False):
|
595 |
+
if self.save_this:
|
596 |
+
|
597 |
+
ims = gif_and_tile(ims, just_gif=self.just_gif)
|
598 |
+
vis = ims
|
599 |
+
|
600 |
+
assert vis.dtype in {torch.uint8,torch.float32}
|
601 |
+
|
602 |
+
if vis.dtype == torch.float32:
|
603 |
+
vis = back2color(vis, blacken_zeros)
|
604 |
+
|
605 |
+
B, S, C, H, W = list(vis.shape)
|
606 |
+
|
607 |
+
if frame_ids is not None:
|
608 |
+
assert(len(frame_ids)==S)
|
609 |
+
for s in range(S):
|
610 |
+
vis[:,s] = draw_frame_id_on_vis(vis[:,s], frame_ids[s])
|
611 |
+
|
612 |
+
if int(W) > self.maxwidth:
|
613 |
+
vis = vis[:,:,:,:self.maxwidth]
|
614 |
+
|
615 |
+
if only_return:
|
616 |
+
return vis
|
617 |
+
else:
|
618 |
+
return self.summ_gif(name, vis, blacken_zeros)
|
619 |
+
|
620 |
+
def summ_rgb(self, name, ims, blacken_zeros=False, frame_id=None, only_return=False, halfres=False):
|
621 |
+
if self.save_this:
|
622 |
+
assert ims.dtype in {torch.uint8,torch.float32}
|
623 |
+
|
624 |
+
if ims.dtype == torch.float32:
|
625 |
+
ims = back2color(ims, blacken_zeros)
|
626 |
+
|
627 |
+
#ims is B x C x H x W
|
628 |
+
vis = ims[0:1] # just the first one
|
629 |
+
B, C, H, W = list(vis.shape)
|
630 |
+
|
631 |
+
if halfres:
|
632 |
+
vis = F.interpolate(vis, scale_factor=0.5)
|
633 |
+
|
634 |
+
if frame_id is not None:
|
635 |
+
vis = draw_frame_id_on_vis(vis, frame_id)
|
636 |
+
|
637 |
+
if int(W) > self.maxwidth:
|
638 |
+
vis = vis[:,:,:,:self.maxwidth]
|
639 |
+
|
640 |
+
if only_return:
|
641 |
+
return vis
|
642 |
+
else:
|
643 |
+
return self.summ_gif(name, vis.unsqueeze(1), blacken_zeros)
|
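A minimal Summ_writer sketch (not part of the committed file) against a TensorBoard SummaryWriter, which is assumed to be available in the Space's environment. log_freq=1 makes save_this true, and the RGB input follows the [-0.5, 0.5] convention above.
# --- usage sketch (illustrative, not part of this commit) ---
import torch
from torch.utils.tensorboard import SummaryWriter
from models.spatracker.utils.improc import Summ_writer

writer = SummaryWriter('runs/improc_demo')
sw = Summ_writer(writer, global_step=0, log_freq=1, fps=8)
rgb = torch.rand(1, 3, 64, 64) - 0.5    # B x C x H x W in [-0.5, 0.5]
sw.summ_rgb('inputs/rgb', rgb)          # S==1, so it is written as a single image
writer.close()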
644 |
+
|
645 |
+
def flow2color(self, flow, clip=50.0):
|
646 |
+
"""
|
647 |
+
:param flow: Optical flow tensor.
|
648 |
+
:return: RGB image normalized between 0 and 1.
|
649 |
+
"""
|
650 |
+
|
651 |
+
# flow is B x C x H x W
|
652 |
+
|
653 |
+
B, C, H, W = list(flow.size())
|
654 |
+
|
655 |
+
flow = flow.clone().detach()
|
656 |
+
|
657 |
+
abs_image = torch.abs(flow)
|
658 |
+
flow_mean = abs_image.mean(dim=[1,2,3])
|
659 |
+
flow_std = abs_image.std(dim=[1,2,3])
|
660 |
+
|
661 |
+
if clip:
|
662 |
+
flow = torch.clamp(flow, -clip, clip)/clip
|
663 |
+
else:
|
664 |
+
# Apply some kind of normalization. Divide by the perceived maximum (mean + std*2)
|
665 |
+
flow_max = flow_mean + flow_std*2 + 1e-10
|
666 |
+
for b in range(B):
|
667 |
+
flow[b] = flow[b].clamp(-flow_max[b].item(), flow_max[b].item()) / flow_max[b].clamp(min=1)
|
668 |
+
|
669 |
+
radius = torch.sqrt(torch.sum(flow**2, dim=1, keepdim=True)) #B x 1 x H x W
|
670 |
+
radius_clipped = torch.clamp(radius, 0.0, 1.0)
|
671 |
+
|
672 |
+
angle = torch.atan2(flow[:, 1:], flow[:, 0:1]) / np.pi #B x 1 x H x W
|
673 |
+
|
674 |
+
hue = torch.clamp((angle + 1.0) / 2.0, 0.0, 1.0)
|
675 |
+
saturation = torch.ones_like(hue) * 0.75
|
676 |
+
value = radius_clipped
|
677 |
+
hsv = torch.cat([hue, saturation, value], dim=1) #B x 3 x H x W
|
678 |
+
|
679 |
+
#flow = tf.image.hsv_to_rgb(hsv)
|
680 |
+
flow = hsv_to_rgb(hsv)
|
681 |
+
flow = (flow*255.0).type(torch.ByteTensor)
|
682 |
+
return flow
|
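A sketch (not part of the committed file) of flow2color on synthetic flow: hue encodes direction, value encodes the clipped magnitude, and summ_flow is just summ_rgb over this conversion. The TensorBoard writer is again an assumption.
# --- usage sketch (illustrative, not part of this commit) ---
import torch
from torch.utils.tensorboard import SummaryWriter
from models.spatracker.utils.improc import Summ_writer

sw = Summ_writer(SummaryWriter('runs/flow_demo'), global_step=0, log_freq=1)
flow = torch.randn(1, 2, 64, 64) * 10.0     # B x 2 x H x W flow in pixels
flow_rgb = sw.flow2color(flow, clip=20.0)   # B x 3 x H x W uint8
sw.summ_flow('flow/demo', flow, clip=20.0)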
683 |
+
|
684 |
+
def summ_flow(self, name, im, clip=0.0, only_return=False, frame_id=None):
|
685 |
+
# flow is B x C x D x W
|
686 |
+
if self.save_this:
|
687 |
+
return self.summ_rgb(name, self.flow2color(im, clip=clip), only_return=only_return, frame_id=frame_id)
|
688 |
+
else:
|
689 |
+
return None
|
690 |
+
|
691 |
+
def summ_oneds(self, name, ims, frame_ids=None, bev=False, fro=False, logvis=False, reduce_max=False, max_val=0.0, norm=True, only_return=False, do_colorize=False):
|
692 |
+
if self.save_this:
|
693 |
+
if bev:
|
694 |
+
B, C, H, _, W = list(ims[0].shape)
|
695 |
+
if reduce_max:
|
696 |
+
ims = [torch.max(im, dim=3)[0] for im in ims]
|
697 |
+
else:
|
698 |
+
ims = [torch.mean(im, dim=3) for im in ims]
|
699 |
+
elif fro:
|
700 |
+
B, C, _, H, W = list(ims[0].shape)
|
701 |
+
if reduce_max:
|
702 |
+
ims = [torch.max(im, dim=2)[0] for im in ims]
|
703 |
+
else:
|
704 |
+
ims = [torch.mean(im, dim=2) for im in ims]
|
705 |
+
|
706 |
+
|
707 |
+
if len(ims) != 1: # sequence
|
708 |
+
im = gif_and_tile(ims, just_gif=self.just_gif)
|
709 |
+
else:
|
710 |
+
im = torch.stack(ims, dim=1) # single frame
|
711 |
+
|
712 |
+
B, S, C, H, W = list(im.shape)
|
713 |
+
|
714 |
+
if logvis and max_val:
|
715 |
+
max_val = np.log(max_val)
|
716 |
+
im = torch.log(torch.clamp(im, 0)+1.0)
|
717 |
+
im = torch.clamp(im, 0, max_val)
|
718 |
+
im = im/max_val
|
719 |
+
norm = False
|
720 |
+
elif max_val:
|
721 |
+
im = torch.clamp(im, 0, max_val)
|
722 |
+
im = im/max_val
|
723 |
+
norm = False
|
724 |
+
|
725 |
+
if norm:
|
726 |
+
# normalize before oned2inferno,
|
727 |
+
# so that the ranges are similar within B across S
|
728 |
+
im = utils.basic.normalize(im)
|
729 |
+
|
730 |
+
im = im.view(B*S, C, H, W)
|
731 |
+
vis = oned2inferno(im, norm=norm, do_colorize=do_colorize)
|
732 |
+
vis = vis.view(B, S, 3, H, W)
|
733 |
+
|
734 |
+
if frame_ids is not None:
|
735 |
+
assert(len(frame_ids)==S)
|
736 |
+
for s in range(S):
|
737 |
+
vis[:,s] = draw_frame_id_on_vis(vis[:,s], frame_ids[s])
|
738 |
+
|
739 |
+
if W > self.maxwidth:
|
740 |
+
vis = vis[...,:self.maxwidth]
|
741 |
+
|
742 |
+
if only_return:
|
743 |
+
return vis
|
744 |
+
else:
|
745 |
+
self.summ_gif(name, vis)
|
746 |
+
|
747 |
+
def summ_oned(self, name, im, bev=False, fro=False, logvis=False, max_val=0, max_along_y=False, norm=True, frame_id=None, only_return=False):
|
748 |
+
if self.save_this:
|
749 |
+
|
750 |
+
if bev:
|
751 |
+
B, C, H, _, W = list(im.shape)
|
752 |
+
if max_along_y:
|
753 |
+
im = torch.max(im, dim=3)[0]
|
754 |
+
else:
|
755 |
+
im = torch.mean(im, dim=3)
|
756 |
+
elif fro:
|
757 |
+
B, C, _, H, W = list(im.shape)
|
758 |
+
if max_along_y:
|
759 |
+
im = torch.max(im, dim=2)[0]
|
760 |
+
else:
|
761 |
+
im = torch.mean(im, dim=2)
|
762 |
+
else:
|
763 |
+
B, C, H, W = list(im.shape)
|
764 |
+
|
765 |
+
im = im[0:1] # just the first one
|
766 |
+
assert(C==1)
|
767 |
+
|
768 |
+
if logvis and max_val:
|
769 |
+
max_val = np.log(max_val)
|
770 |
+
im = torch.log(im)
|
771 |
+
im = torch.clamp(im, 0, max_val)
|
772 |
+
im = im/max_val
|
773 |
+
norm = False
|
774 |
+
elif max_val:
|
775 |
+
im = torch.clamp(im, 0, max_val)/max_val
|
776 |
+
norm = False
|
777 |
+
|
778 |
+
vis = oned2inferno(im, norm=norm)
|
779 |
+
if W > self.maxwidth:
|
780 |
+
vis = vis[...,:self.maxwidth]
|
781 |
+
return self.summ_rgb(name, vis, blacken_zeros=False, frame_id=frame_id, only_return=only_return)
|
782 |
+
|
783 |
+
def summ_feats(self, name, feats, valids=None, pca=True, fro=False, only_return=False, frame_ids=None):
|
784 |
+
if self.save_this:
|
785 |
+
if valids is not None:
|
786 |
+
valids = torch.stack(valids, dim=1)
|
787 |
+
|
788 |
+
feats = torch.stack(feats, dim=1)
|
789 |
+
# feats leads with B x S x C
|
790 |
+
|
791 |
+
if feats.ndim==6:
|
792 |
+
|
793 |
+
# feats is B x S x C x D x H x W
|
794 |
+
if fro:
|
795 |
+
reduce_dim = 3
|
796 |
+
else:
|
797 |
+
reduce_dim = 4
|
798 |
+
|
799 |
+
if valids is None:
|
800 |
+
feats = torch.mean(feats, dim=reduce_dim)
|
801 |
+
else:
|
802 |
+
valids = valids.repeat(1, 1, feats.size()[2], 1, 1, 1)
|
803 |
+
feats = utils.basic.reduce_masked_mean(feats, valids, dim=reduce_dim)
|
804 |
+
|
805 |
+
B, S, C, D, W = list(feats.size())
|
806 |
+
|
807 |
+
if not pca:
|
808 |
+
# feats leads with B x S x C
|
809 |
+
feats = torch.mean(torch.abs(feats), dim=2, keepdims=True)
|
810 |
+
# feats leads with B x S x 1
|
811 |
+
feats = torch.unbind(feats, dim=1)
|
812 |
+
return self.summ_oneds(name=name, ims=feats, norm=True, only_return=only_return, frame_ids=frame_ids)
|
813 |
+
|
814 |
+
else:
|
815 |
+
__p = lambda x: utils.basic.pack_seqdim(x, B)
|
816 |
+
__u = lambda x: utils.basic.unpack_seqdim(x, B)
|
817 |
+
|
818 |
+
feats_ = __p(feats)
|
819 |
+
|
820 |
+
if valids is None:
|
821 |
+
feats_pca_ = get_feat_pca(feats_)
|
822 |
+
else:
|
823 |
+
valids_ = __p(valids)
|
824 |
+
feats_pca_ = get_feat_pca(feats_, valids_)
|
825 |
+
|
826 |
+
feats_pca = __u(feats_pca_)
|
827 |
+
|
828 |
+
return self.summ_rgbs(name=name, ims=torch.unbind(feats_pca, dim=1), only_return=only_return, frame_ids=frame_ids)
|
829 |
+
|
830 |
+
def summ_feat(self, name, feat, valid=None, pca=True, only_return=False, bev=False, fro=False, frame_id=None):
|
831 |
+
if self.save_this:
|
832 |
+
if feat.ndim==5: # B x C x D x H x W
|
833 |
+
|
834 |
+
if bev:
|
835 |
+
reduce_axis = 3
|
836 |
+
elif fro:
|
837 |
+
reduce_axis = 2
|
838 |
+
else:
|
839 |
+
# default to bev
|
840 |
+
reduce_axis = 3
|
841 |
+
|
842 |
+
if valid is None:
|
843 |
+
feat = torch.mean(feat, dim=reduce_axis)
|
844 |
+
else:
|
845 |
+
valid = valid.repeat(1, feat.size()[1], 1, 1, 1)
|
846 |
+
feat = utils.basic.reduce_masked_mean(feat, valid, dim=reduce_axis)
|
847 |
+
|
848 |
+
B, C, D, W = list(feat.shape)
|
849 |
+
|
850 |
+
if not pca:
|
851 |
+
feat = torch.mean(torch.abs(feat), dim=1, keepdims=True)
|
852 |
+
# feat is B x 1 x D x W
|
853 |
+
return self.summ_oned(name=name, im=feat, norm=True, only_return=only_return, frame_id=frame_id)
|
854 |
+
else:
|
855 |
+
feat_pca = get_feat_pca(feat, valid)
|
856 |
+
return self.summ_rgb(name, feat_pca, only_return=only_return, frame_id=frame_id)
|
857 |
+
|
858 |
+
def summ_scalar(self, name, value):
|
859 |
+
if (not (isinstance(value, int) or isinstance(value, float) or isinstance(value, np.float32))) and ('torch' in value.type()):
|
860 |
+
value = value.detach().cpu().numpy()
|
861 |
+
if not np.isnan(value):
|
862 |
+
if (self.log_freq == 1):
|
863 |
+
self.writer.add_scalar(name, value, global_step=self.global_step)
|
864 |
+
elif self.save_this or np.mod(self.global_step, self.scalar_freq)==0:
|
865 |
+
self.writer.add_scalar(name, value, global_step=self.global_step)
|
866 |
+
|
867 |
+
def summ_seg(self, name, seg, only_return=False, frame_id=None, colormap='tab20', label_colors=None):
|
868 |
+
if not self.save_this:
|
869 |
+
return
|
870 |
+
|
871 |
+
B,H,W = seg.shape
|
872 |
+
|
873 |
+
if label_colors is None:
|
874 |
+
custom_label_colors = False
|
875 |
+
# label_colors = get_n_colors(int(torch.max(seg).item()), sequential=True)
|
876 |
+
label_colors = cm.get_cmap(colormap).colors
|
877 |
+
label_colors = [[int(i*255) for i in l] for l in label_colors]
|
878 |
+
else:
|
879 |
+
custom_label_colors = True
|
880 |
+
# label_colors = matplotlib.cm.get_cmap(colormap).colors
|
881 |
+
# label_colors = [[int(i*255) for i in l] for l in label_colors]
|
882 |
+
# print('label_colors', label_colors)
|
883 |
+
|
884 |
+
# label_colors = [
|
885 |
+
# (0, 0, 0), # None
|
886 |
+
# (70, 70, 70), # Buildings
|
887 |
+
# (190, 153, 153), # Fences
|
888 |
+
# (72, 0, 90), # Other
|
889 |
+
# (220, 20, 60), # Pedestrians
|
890 |
+
# (153, 153, 153), # Poles
|
891 |
+
# (157, 234, 50), # RoadLines
|
892 |
+
# (128, 64, 128), # Roads
|
893 |
+
# (244, 35, 232), # Sidewalks
|
894 |
+
# (107, 142, 35), # Vegetation
|
895 |
+
# (0, 0, 255), # Vehicles
|
896 |
+
# (102, 102, 156), # Walls
|
897 |
+
# (220, 220, 0) # TrafficSigns
|
898 |
+
# ]
|
899 |
+
|
900 |
+
r = torch.zeros_like(seg,dtype=torch.uint8)
|
901 |
+
g = torch.zeros_like(seg,dtype=torch.uint8)
|
902 |
+
b = torch.zeros_like(seg,dtype=torch.uint8)
|
903 |
+
|
904 |
+
for label in range(0,len(label_colors)):
|
905 |
+
if (not custom_label_colors):# and (N > 20):
|
906 |
+
label_ = label % 20
|
907 |
+
else:
|
908 |
+
label_ = label
|
909 |
+
|
910 |
+
idx = (seg == label+1)
|
911 |
+
r[idx] = label_colors[label_][0]
|
912 |
+
g[idx] = label_colors[label_][1]
|
913 |
+
b[idx] = label_colors[label_][2]
|
914 |
+
|
915 |
+
rgb = torch.stack([r,g,b],axis=1)
|
916 |
+
return self.summ_rgb(name,rgb,only_return=only_return, frame_id=frame_id)
|
917 |
+
|
918 |
+
def summ_pts_on_rgb(self, name, trajs, rgb, valids=None, frame_id=None, only_return=False, show_dots=True, cmap='coolwarm', linewidth=1):
|
919 |
+
# trajs is B, S, N, 2
|
920 |
+
# rgbs is B, S, C, H, W
|
921 |
+
B, C, H, W = rgb.shape
|
922 |
+
B, S, N, D = trajs.shape
|
923 |
+
|
924 |
+
rgb = rgb[0] # C, H, W
|
925 |
+
trajs = trajs[0] # S, N, 2
|
926 |
+
if valids is None:
|
927 |
+
valids = torch.ones_like(trajs[:,:,0]) # S, N
|
928 |
+
else:
|
929 |
+
valids = valids[0]
|
930 |
+
# print('trajs', trajs.shape)
|
931 |
+
# print('valids', valids.shape)
|
932 |
+
|
933 |
+
rgb = back2color(rgb).detach().cpu().numpy()
|
934 |
+
rgb = np.transpose(rgb, [1, 2, 0]) # put channels last
|
935 |
+
|
936 |
+
trajs = trajs.long().detach().cpu().numpy() # S, N, 2
|
937 |
+
valids = valids.long().detach().cpu().numpy() # S, N
|
938 |
+
|
939 |
+
rgb = rgb.astype(np.uint8).copy()
|
940 |
+
|
941 |
+
for i in range(N):
|
942 |
+
if cmap=='onediff' and i==0:
|
943 |
+
cmap_ = 'spring'
|
944 |
+
elif cmap=='onediff':
|
945 |
+
cmap_ = 'winter'
|
946 |
+
else:
|
947 |
+
cmap_ = cmap
|
948 |
+
traj = trajs[:,i] # S,2
|
949 |
+
valid = valids[:,i] # S
|
950 |
+
|
951 |
+
color_map = cm.get_cmap(cmap)
|
952 |
+
color = np.array(color_map(i)[:3]) * 255 # rgb
|
953 |
+
for s in range(S):
|
954 |
+
if valid[s]:
|
955 |
+
cv2.circle(rgb, (int(traj[s,0]), int(traj[s,1])), linewidth, color, -1)
|
956 |
+
rgb = torch.from_numpy(rgb).permute(2,0,1).unsqueeze(0)
|
957 |
+
rgb = preprocess_color(rgb)
|
958 |
+
return self.summ_rgb(name, rgb, only_return=only_return, frame_id=frame_id)
|
959 |
+
|
960 |
+
def summ_pts_on_rgbs(self, name, trajs, rgbs, valids=None, frame_ids=None, only_return=False, show_dots=True, cmap='coolwarm', linewidth=1):
|
961 |
+
# trajs is B, S, N, 2
|
962 |
+
# rgbs is B, S, C, H, W
|
963 |
+
B, S, C, H, W = rgbs.shape
|
964 |
+
B, S2, N, D = trajs.shape
|
965 |
+
assert(S==S2)
|
966 |
+
|
967 |
+
rgbs = rgbs[0] # S, C, H, W
|
968 |
+
trajs = trajs[0] # S, N, 2
|
969 |
+
if valids is None:
|
970 |
+
valids = torch.ones_like(trajs[:,:,0]) # S, N
|
971 |
+
else:
|
972 |
+
valids = valids[0]
|
973 |
+
# print('trajs', trajs.shape)
|
974 |
+
# print('valids', valids.shape)
|
975 |
+
|
976 |
+
rgbs_color = []
|
977 |
+
for rgb in rgbs:
|
978 |
+
rgb = back2color(rgb).detach().cpu().numpy()
|
979 |
+
rgb = np.transpose(rgb, [1, 2, 0]) # put channels last
|
980 |
+
rgbs_color.append(rgb) # each element 3 x H x W
|
981 |
+
|
982 |
+
trajs = trajs.long().detach().cpu().numpy() # S, N, 2
|
983 |
+
valids = valids.long().detach().cpu().numpy() # S, N
|
984 |
+
|
985 |
+
rgbs_color = [rgb.astype(np.uint8).copy() for rgb in rgbs_color]
|
986 |
+
|
987 |
+
for i in range(N):
|
988 |
+
traj = trajs[:,i] # S,2
|
989 |
+
valid = valids[:,i] # S
|
990 |
+
|
991 |
+
color_map = cm.get_cmap(cmap)
|
992 |
+
color = np.array(color_map(0)[:3]) * 255 # rgb
|
993 |
+
for s in range(S):
|
994 |
+
if valid[s]:
|
995 |
+
cv2.circle(rgbs_color[s], (traj[s,0], traj[s,1]), linewidth, color, -1)
|
996 |
+
rgbs = []
|
997 |
+
for rgb in rgbs_color:
|
998 |
+
rgb = torch.from_numpy(rgb).permute(2, 0, 1).unsqueeze(0)
|
999 |
+
rgbs.append(preprocess_color(rgb))
|
1000 |
+
|
1001 |
+
return self.summ_rgbs(name, rgbs, only_return=only_return, frame_ids=frame_ids)
|
1002 |
+
|
1003 |
+
|
1004 |
+
def summ_traj2ds_on_rgbs(self, name, trajs, rgbs, valids=None, frame_ids=None, only_return=False, show_dots=False, cmap='coolwarm', vals=None, linewidth=1):
|
1005 |
+
# trajs is B, S, N, 2
|
1006 |
+
# rgbs is B, S, C, H, W
|
1007 |
+
B, S, C, H, W = rgbs.shape
|
1008 |
+
B, S2, N, D = trajs.shape
|
1009 |
+
assert(S==S2)
|
1010 |
+
|
1011 |
+
rgbs = rgbs[0] # S, C, H, W
|
1012 |
+
trajs = trajs[0] # S, N, 2
|
1013 |
+
if valids is None:
|
1014 |
+
valids = torch.ones_like(trajs[:,:,0]) # S, N
|
1015 |
+
else:
|
1016 |
+
valids = valids[0]
|
1017 |
+
|
1018 |
+
# print('trajs', trajs.shape)
|
1019 |
+
# print('valids', valids.shape)
|
1020 |
+
|
1021 |
+
if vals is not None:
|
1022 |
+
vals = vals[0] # N
|
1023 |
+
# print('vals', vals.shape)
|
1024 |
+
|
1025 |
+
rgbs_color = []
|
1026 |
+
for rgb in rgbs:
|
1027 |
+
rgb = back2color(rgb).detach().cpu().numpy()
|
1028 |
+
rgb = np.transpose(rgb, [1, 2, 0]) # put channels last
|
1029 |
+
rgbs_color.append(rgb) # each element 3 x H x W
|
1030 |
+
|
1031 |
+
for i in range(N):
|
1032 |
+
if cmap=='onediff' and i==0:
|
1033 |
+
cmap_ = 'spring'
|
1034 |
+
elif cmap=='onediff':
|
1035 |
+
cmap_ = 'winter'
|
1036 |
+
else:
|
1037 |
+
cmap_ = cmap
|
1038 |
+
traj = trajs[:,i].long().detach().cpu().numpy() # S, 2
|
1039 |
+
valid = valids[:,i].long().detach().cpu().numpy() # S
|
1040 |
+
|
1041 |
+
# print('traj', traj.shape)
|
1042 |
+
# print('valid', valid.shape)
|
1043 |
+
|
1044 |
+
if vals is not None:
|
1045 |
+
# val = vals[:,i].float().detach().cpu().numpy() # []
|
1046 |
+
val = vals[i].float().detach().cpu().numpy() # []
|
1047 |
+
# print('val', val.shape)
|
1048 |
+
else:
|
1049 |
+
val = None
|
1050 |
+
|
1051 |
+
for t in range(S):
|
1052 |
+
# if valid[t]:
|
1053 |
+
# traj_seq = traj[max(t-16,0):t+1]
|
1054 |
+
traj_seq = traj[max(t-8,0):t+1]
|
1055 |
+
val_seq = np.linspace(0,1,len(traj_seq))
|
1056 |
+
# if t<2:
|
1057 |
+
# val_seq = np.zeros_like(val_seq)
|
1058 |
+
# print('val_seq', val_seq)
|
1059 |
+
# val_seq = 1.0
|
1060 |
+
# val_seq = np.arange(8)/8.0
|
1061 |
+
# val_seq = val_seq[-len(traj_seq):]
|
1062 |
+
# rgbs_color[t] = self.draw_traj_on_image_py(rgbs_color[t], traj_seq, S=S, show_dots=show_dots, cmap=cmap_, val=val_seq, linewidth=linewidth)
|
1063 |
+
rgbs_color[t] = self.draw_traj_on_image_py(rgbs_color[t], traj_seq, S=S, show_dots=show_dots, cmap=cmap_, val=val_seq, linewidth=linewidth)
|
1064 |
+
# input()
|
1065 |
+
|
1066 |
+
for i in range(N):
|
1067 |
+
if cmap=='onediff' and i==0:
|
1068 |
+
cmap_ = 'spring'
|
1069 |
+
elif cmap=='onediff':
|
1070 |
+
cmap_ = 'winter'
|
1071 |
+
else:
|
1072 |
+
cmap_ = cmap
|
1073 |
+
traj = trajs[:,i] # S,2
|
1074 |
+
# vis = visibles[:,i] # S
|
1075 |
+
vis = torch.ones_like(traj[:,0]) # S
|
1076 |
+
valid = valids[:,i] # S
|
1077 |
+
rgbs_color = self.draw_circ_on_images_py(rgbs_color, traj, vis, S=0, show_dots=show_dots, cmap=cmap_, linewidth=linewidth)
|
1078 |
+
|
1079 |
+
rgbs = []
|
1080 |
+
for rgb in rgbs_color:
|
1081 |
+
rgb = torch.from_numpy(rgb).permute(2, 0, 1).unsqueeze(0)
|
1082 |
+
rgbs.append(preprocess_color(rgb))
|
1083 |
+
|
1084 |
+
return self.summ_rgbs(name, rgbs, only_return=only_return, frame_ids=frame_ids)
|
1085 |
+
|
1086 |
+
def summ_traj2ds_on_rgbs2(self, name, trajs, visibles, rgbs, valids=None, frame_ids=None, only_return=False, show_dots=True, cmap=None, linewidth=1):
|
1087 |
+
# trajs is B, S, N, 2
|
1088 |
+
# rgbs is B, S, C, H, W
|
1089 |
+
B, S, C, H, W = rgbs.shape
|
1090 |
+
B, S2, N, D = trajs.shape
|
1091 |
+
assert(S==S2)
|
1092 |
+
|
1093 |
+
rgbs = rgbs[0] # S, C, H, W
|
1094 |
+
trajs = trajs[0] # S, N, 2
|
1095 |
+
visibles = visibles[0] # S, N
|
1096 |
+
if valids is None:
|
1097 |
+
valids = torch.ones_like(trajs[:,:,0]) # S, N
|
1098 |
+
else:
|
1099 |
+
valids = valids[0]
|
1100 |
+
# print('trajs', trajs.shape)
|
1101 |
+
# print('valids', valids.shape)
|
1102 |
+
|
1103 |
+
rgbs_color = []
|
1104 |
+
for rgb in rgbs:
|
1105 |
+
rgb = back2color(rgb).detach().cpu().numpy()
|
1106 |
+
rgb = np.transpose(rgb, [1, 2, 0]) # put channels last
|
1107 |
+
rgbs_color.append(rgb) # each element 3 x H x W
|
1108 |
+
|
1109 |
+
trajs = trajs.long().detach().cpu().numpy() # S, N, 2
|
1110 |
+
visibles = visibles.float().detach().cpu().numpy() # S, N
|
1111 |
+
valids = valids.long().detach().cpu().numpy() # S, N
|
1112 |
+
|
1113 |
+
for i in range(N):
|
1114 |
+
if cmap=='onediff' and i==0:
|
1115 |
+
cmap_ = 'spring'
|
1116 |
+
elif cmap=='onediff':
|
1117 |
+
cmap_ = 'winter'
|
1118 |
+
else:
|
1119 |
+
cmap_ = cmap
|
1120 |
+
traj = trajs[:,i] # S,2
|
1121 |
+
vis = visibles[:,i] # S
|
1122 |
+
valid = valids[:,i] # S
|
1123 |
+
rgbs_color = self.draw_traj_on_images_py(rgbs_color, traj, S=S, show_dots=show_dots, cmap=cmap_, linewidth=linewidth)
|
1124 |
+
|
1125 |
+
for i in range(N):
|
1126 |
+
if cmap=='onediff' and i==0:
|
1127 |
+
cmap_ = 'spring'
|
1128 |
+
elif cmap=='onediff':
|
1129 |
+
cmap_ = 'winter'
|
1130 |
+
else:
|
1131 |
+
cmap_ = cmap
|
1132 |
+
traj = trajs[:,i] # S,2
|
1133 |
+
vis = visibles[:,i] # S
|
1134 |
+
valid = valids[:,i] # S
|
1135 |
+
if valid[0]:
|
1136 |
+
rgbs_color = self.draw_circ_on_images_py(rgbs_color, traj, vis, S=S, show_dots=show_dots, cmap=None, linewidth=linewidth)
|
1137 |
+
|
1138 |
+
rgbs = []
|
1139 |
+
for rgb in rgbs_color:
|
1140 |
+
rgb = torch.from_numpy(rgb).permute(2, 0, 1).unsqueeze(0)
|
1141 |
+
rgbs.append(preprocess_color(rgb))
|
1142 |
+
|
1143 |
+
return self.summ_rgbs(name, rgbs, only_return=only_return, frame_ids=frame_ids)
|
1144 |
+
|
1145 |
+
def summ_traj2ds_on_rgb(self, name, trajs, rgb, valids=None, show_dots=False, show_lines=True, frame_id=None, only_return=False, cmap='coolwarm', linewidth=1):
|
1146 |
+
# trajs is B, S, N, 2
|
1147 |
+
# rgb is B, C, H, W
|
1148 |
+
B, C, H, W = rgb.shape
|
1149 |
+
B, S, N, D = trajs.shape
|
1150 |
+
|
1151 |
+
rgb = rgb[0] # S, C, H, W
|
1152 |
+
trajs = trajs[0] # S, N, 2
|
1153 |
+
|
1154 |
+
if valids is None:
|
1155 |
+
valids = torch.ones_like(trajs[:,:,0])
|
1156 |
+
else:
|
1157 |
+
valids = valids[0]
|
1158 |
+
|
1159 |
+
rgb_color = back2color(rgb).detach().cpu().numpy()
|
1160 |
+
rgb_color = np.transpose(rgb_color, [1, 2, 0]) # put channels last
|
1161 |
+
|
1162 |
+
# using maxdist will dampen the colors for short motions
|
1163 |
+
norms = torch.sqrt(1e-4 + torch.sum((trajs[-1] - trajs[0])**2, dim=1)) # N
|
1164 |
+
maxdist = torch.quantile(norms, 0.95).detach().cpu().numpy()
|
1165 |
+
maxdist = None
|
1166 |
+
trajs = trajs.long().detach().cpu().numpy() # S, N, 2
|
1167 |
+
valids = valids.long().detach().cpu().numpy() # S, N
|
1168 |
+
|
1169 |
+
for i in range(N):
|
1170 |
+
if cmap=='onediff' and i==0:
|
1171 |
+
cmap_ = 'spring'
|
1172 |
+
elif cmap=='onediff':
|
1173 |
+
cmap_ = 'winter'
|
1174 |
+
else:
|
1175 |
+
cmap_ = cmap
|
1176 |
+
traj = trajs[:,i] # S, 2
|
1177 |
+
valid = valids[:,i] # S
|
1178 |
+
if valid[0]==1:
|
1179 |
+
traj = traj[valid>0]
|
1180 |
+
rgb_color = self.draw_traj_on_image_py(
|
1181 |
+
rgb_color, traj, S=S, show_dots=show_dots, show_lines=show_lines, cmap=cmap_, maxdist=maxdist, linewidth=linewidth)
|
1182 |
+
|
1183 |
+
rgb_color = torch.from_numpy(rgb_color).permute(2, 0, 1).unsqueeze(0)
|
1184 |
+
rgb = preprocess_color(rgb_color)
|
1185 |
+
return self.summ_rgb(name, rgb, only_return=only_return, frame_id=frame_id)
|
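A sketch (not part of the committed file) of overlaying a few 2D tracks on one frame with summ_traj2ds_on_rgb; coordinates are in pixels and the frame follows the [-0.5, 0.5] color convention. Writer setup is assumed as in the earlier sketches.
# --- usage sketch (illustrative, not part of this commit) ---
import torch
from torch.utils.tensorboard import SummaryWriter
from models.spatracker.utils.improc import Summ_writer

sw = Summ_writer(SummaryWriter('runs/traj_demo'), global_step=0, log_freq=1)
S, N = 8, 3
trajs = torch.rand(1, S, N, 2) * 64.0   # B x S x N x 2 pixel trajectories
rgb = torch.rand(1, 3, 64, 64) - 0.5    # B x C x H x W
sw.summ_traj2ds_on_rgb('tracks/on_rgb', trajs, rgb, cmap='coolwarm', linewidth=1)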
1186 |
+
|
1187 |
+
def draw_traj_on_image_py(self, rgb, traj, S=50, linewidth=1, show_dots=False, show_lines=True, cmap='coolwarm', val=None, maxdist=None):
|
1188 |
+
# all inputs are numpy tensors
|
1189 |
+
# rgb is 3 x H x W
|
1190 |
+
# traj is S x 2
|
1191 |
+
|
1192 |
+
H, W, C = rgb.shape
|
1193 |
+
assert(C==3)
|
1194 |
+
|
1195 |
+
rgb = rgb.astype(np.uint8).copy()
|
1196 |
+
|
1197 |
+
S1, D = traj.shape
|
1198 |
+
assert(D==2)
|
1199 |
+
|
1200 |
+
color_map = cm.get_cmap(cmap)
|
1201 |
+
S1, D = traj.shape
|
1202 |
+
|
1203 |
+
for s in range(S1):
|
1204 |
+
if val is not None:
|
1205 |
+
# if len(val) == S1:
|
1206 |
+
color = np.array(color_map(val[s])[:3]) * 255 # rgb
|
1207 |
+
# else:
|
1208 |
+
# color = np.array(color_map(val)[:3]) * 255 # rgb
|
1209 |
+
else:
|
1210 |
+
if maxdist is not None:
|
1211 |
+
val = (np.sqrt(np.sum((traj[s]-traj[0])**2))/maxdist).clip(0,1)
|
1212 |
+
color = np.array(color_map(val)[:3]) * 255 # rgb
|
1213 |
+
else:
|
1214 |
+
color = np.array(color_map((s)/max(1,float(S-2)))[:3]) * 255 # rgb
|
1215 |
+
|
1216 |
+
if show_lines and s<(S1-1):
|
1217 |
+
cv2.line(rgb,
|
1218 |
+
(int(traj[s,0]), int(traj[s,1])),
|
1219 |
+
(int(traj[s+1,0]), int(traj[s+1,1])),
|
1220 |
+
color,
|
1221 |
+
linewidth,
|
1222 |
+
cv2.LINE_AA)
|
1223 |
+
if show_dots:
|
1224 |
+
cv2.circle(rgb, (int(traj[s,0]), int(traj[s,1])), linewidth, np.array(color_map(1)[:3])*255, -1)
|
1225 |
+
|
1226 |
+
# if maxdist is not None:
|
1227 |
+
# val = (np.sqrt(np.sum((traj[-1]-traj[0])**2))/maxdist).clip(0,1)
|
1228 |
+
# color = np.array(color_map(val)[:3]) * 255 # rgb
|
1229 |
+
# else:
|
1230 |
+
# # draw the endpoint of traj, using the next color (which may be the last color)
|
1231 |
+
# color = np.array(color_map((S1-1)/max(1,float(S-2)))[:3]) * 255 # rgb
|
1232 |
+
|
1233 |
+
# # emphasize endpoint
|
1234 |
+
# cv2.circle(rgb, (traj[-1,0], traj[-1,1]), linewidth*2, color, -1)
|
1235 |
+
|
1236 |
+
return rgb
|
1237 |
+
|
1238 |
+
|
1239 |
+
|
1240 |
+
def draw_traj_on_images_py(self, rgbs, traj, S=50, linewidth=1, show_dots=False, cmap='coolwarm', maxdist=None):
|
1241 |
+
# all inputs are numpy tensors
|
1242 |
+
# rgbs is a list of H,W,3
|
1243 |
+
# traj is S,2
|
1244 |
+
H, W, C = rgbs[0].shape
|
1245 |
+
assert(C==3)
|
1246 |
+
|
1247 |
+
rgbs = [rgb.astype(np.uint8).copy() for rgb in rgbs]
|
1248 |
+
|
1249 |
+
S1, D = traj.shape
|
1250 |
+
assert(D==2)
|
1251 |
+
|
1252 |
+
x = int(np.clip(traj[0,0], 0, W-1))
|
1253 |
+
y = int(np.clip(traj[0,1], 0, H-1))
|
1254 |
+
color = rgbs[0][y,x]
|
1255 |
+
color = (int(color[0]),int(color[1]),int(color[2]))
|
1256 |
+
for s in range(S):
|
1257 |
+
# bak_color = np.array(color_map(1.0)[:3]) * 255 # rgb
|
1258 |
+
# cv2.circle(rgbs[s], (traj[s,0], traj[s,1]), linewidth*4, bak_color, -1)
|
1259 |
+
cv2.polylines(rgbs[s],
|
1260 |
+
[traj[:s+1]],
|
1261 |
+
False,
|
1262 |
+
color,
|
1263 |
+
linewidth,
|
1264 |
+
cv2.LINE_AA)
|
1265 |
+
return rgbs
|
1266 |
+
|
1267 |
+
def draw_circs_on_image_py(self, rgb, xy, colors=None, linewidth=10, radius=3, show_dots=False, maxdist=None):
|
1268 |
+
# all inputs are numpy tensors
|
1269 |
+
# rgbs is a list of 3,H,W
|
1270 |
+
# xy is N,2
|
1271 |
+
H, W, C = rgb.shape
|
1272 |
+
assert(C==3)
|
1273 |
+
|
1274 |
+
rgb = rgb.astype(np.uint8).copy()
|
1275 |
+
|
1276 |
+
N, D = xy.shape
|
1277 |
+
assert(D==2)
|
1278 |
+
|
1279 |
+
|
1280 |
+
xy = xy.astype(np.float32)
|
1281 |
+
xy[:,0] = np.clip(xy[:,0], 0, W-1)
|
1282 |
+
xy[:,1] = np.clip(xy[:,1], 0, H-1)
|
1283 |
+
xy = xy.astype(np.int32)
|
1284 |
+
|
1285 |
+
|
1286 |
+
|
1287 |
+
if colors is None:
|
1288 |
+
colors = get_n_colors(N)
|
1289 |
+
|
1290 |
+
for n in range(N):
|
1291 |
+
color = colors[n]
|
1292 |
+
# print('color', color)
|
1293 |
+
# color = (color[0]*255).astype(np.uint8)
|
1294 |
+
color = (int(color[0]),int(color[1]),int(color[2]))
|
1295 |
+
|
1296 |
+
# x = int(np.clip(xy[0,0], 0, W-1))
|
1297 |
+
# y = int(np.clip(xy[0,1], 0, H-1))
|
1298 |
+
# color_ = rgbs[0][y,x]
|
1299 |
+
# color_ = (int(color_[0]),int(color_[1]),int(color_[2]))
|
1300 |
+
# color_ = (int(color_[0]),int(color_[1]),int(color_[2]))
|
1301 |
+
|
1302 |
+
cv2.circle(rgb, (xy[n,0], xy[n,1]), linewidth, color, 3)
|
1303 |
+
# vis_color = int(np.squeeze(vis[s])*255)
|
1304 |
+
# vis_color = (vis_color,vis_color,vis_color)
|
1305 |
+
# cv2.circle(rgbs[s], (traj[s,0], traj[s,1]), linewidth+1, vis_color, -1)
|
1306 |
+
return rgb
|
1307 |
+
|
1308 |
+
    def draw_circ_on_images_py(self, rgbs, traj, vis, S=50, linewidth=1, show_dots=False, cmap=None, maxdist=None):
        # all inputs are numpy tensors
        # rgbs is a list of 3,H,W
        # traj is S,2
        H, W, C = rgbs[0].shape
        assert(C==3)

        rgbs = [rgb.astype(np.uint8).copy() for rgb in rgbs]

        S1, D = traj.shape
        assert(D==2)

        if cmap is None:
            bremm = ColorMap2d()
            traj_ = traj[0:1].astype(np.float32)
            traj_[:,0] /= float(W)
            traj_[:,1] /= float(H)
            color = bremm(traj_)
            # print('color', color)
            color = (color[0]*255).astype(np.uint8)
            # color = (int(color[0]),int(color[1]),int(color[2]))
            color = (int(color[2]),int(color[1]),int(color[0]))

        for s in range(S1):
            if cmap is not None:
                color_map = cm.get_cmap(cmap)
                # color = np.array(color_map(s/(S-1))[:3]) * 255 # rgb
                color = np.array(color_map((s+1)/max(1,float(S-1)))[:3]) * 255 # rgb
                # color = color.astype(np.uint8)
                # color = (color[0], color[1], color[2])
                # print('color', color)
                # import ipdb; ipdb.set_trace()

            cv2.circle(rgbs[s], (int(traj[s,0]), int(traj[s,1])), linewidth+1, color, -1)
            # vis_color = int(np.squeeze(vis[s])*255)
            # vis_color = (vis_color,vis_color,vis_color)
            # cv2.circle(rgbs[s], (int(traj[s,0]), int(traj[s,1])), linewidth+1, vis_color, -1)

        return rgbs

    def summ_traj_as_crops(self, name, trajs_e, rgbs, frame_id=None, only_return=False, show_circ=False, trajs_g=None, is_g=False):
        B, S, N, D = trajs_e.shape
        assert(N==1)
        assert(D==2)

        rgbs_vis = []
        n = 0
        pad_amount = 100
        trajs_e_py = trajs_e[0].detach().cpu().numpy()
        # trajs_e_py = np.clip(trajs_e_py, min=pad_amount/2, max=pad_amoun
        trajs_e_py = trajs_e_py + pad_amount

        if trajs_g is not None:
            trajs_g_py = trajs_g[0].detach().cpu().numpy()
            trajs_g_py = trajs_g_py + pad_amount

        for s in range(S):
            rgb = rgbs[0,s].detach().cpu().numpy()
            # print('orig rgb', rgb.shape)
            rgb = np.transpose(rgb,(1,2,0)) # H, W, 3

            rgb = np.pad(rgb, ((pad_amount,pad_amount),(pad_amount,pad_amount),(0,0)))
            # print('pad rgb', rgb.shape)
            H, W, C = rgb.shape

            if trajs_g is not None:
                xy_g = trajs_g_py[s,n]
                xy_g[0] = np.clip(xy_g[0], pad_amount, W-pad_amount)
                xy_g[1] = np.clip(xy_g[1], pad_amount, H-pad_amount)
                rgb = self.draw_circs_on_image_py(rgb, xy_g.reshape(1,2), colors=[(0,255,0)], linewidth=2, radius=3)

            xy_e = trajs_e_py[s,n]
            xy_e[0] = np.clip(xy_e[0], pad_amount, W-pad_amount)
            xy_e[1] = np.clip(xy_e[1], pad_amount, H-pad_amount)

            if show_circ:
                if is_g:
                    rgb = self.draw_circs_on_image_py(rgb, xy_e.reshape(1,2), colors=[(0,255,0)], linewidth=2, radius=3)
                else:
                    rgb = self.draw_circs_on_image_py(rgb, xy_e.reshape(1,2), colors=[(255,0,255)], linewidth=2, radius=3)

            xmin = int(xy_e[0])-pad_amount//2
            xmax = int(xy_e[0])+pad_amount//2
            ymin = int(xy_e[1])-pad_amount//2
            ymax = int(xy_e[1])+pad_amount//2

            rgb_ = rgb[ymin:ymax, xmin:xmax]

            H_, W_ = rgb_.shape[:2]
            # if np.any(rgb_.shape==0):
            #     input()
            if H_==0 or W_==0:
                import ipdb; ipdb.set_trace()

            rgb_ = rgb_.transpose(2,0,1)
            rgb_ = torch.from_numpy(rgb_)

            rgbs_vis.append(rgb_)

        # nrow = int(np.sqrt(S)*(16.0/9)/2.0)
        nrow = int(np.sqrt(S)*1.5)
        grid_img = torchvision.utils.make_grid(torch.stack(rgbs_vis, dim=0), nrow=nrow).unsqueeze(0)
        # print('grid_img', grid_img.shape)
        return self.summ_rgb(name, grid_img.byte(), frame_id=frame_id, only_return=only_return)

    def summ_occ(self, name, occ, reduce_axes=[3], bev=False, fro=False, pro=False, frame_id=None, only_return=False):
        if self.save_this:
            B, C, D, H, W = list(occ.shape)
            if bev:
                reduce_axes = [3]
            elif fro:
                reduce_axes = [2]
            elif pro:
                reduce_axes = [4]
            for reduce_axis in reduce_axes:
                height = convert_occ_to_height(occ, reduce_axis=reduce_axis)
                if reduce_axis == reduce_axes[-1]:
                    return self.summ_oned(name=('%s_ax%d' % (name, reduce_axis)), im=height, norm=False, frame_id=frame_id, only_return=only_return)
                else:
                    self.summ_oned(name=('%s_ax%d' % (name, reduce_axis)), im=height, norm=False, frame_id=frame_id, only_return=only_return)

def erode2d(im, times=1, device='cuda'):
    weights2d = torch.ones(1, 1, 3, 3, device=device)
    for time in range(times):
        im = 1.0 - F.conv2d(1.0 - im, weights2d, padding=1).clamp(0, 1)
    return im

def dilate2d(im, times=1, device='cuda', mode='square'):
    weights2d = torch.ones(1, 1, 3, 3, device=device)
    if mode=='cross':
        weights2d[:,:,0,0] = 0.0
        weights2d[:,:,0,2] = 0.0
        weights2d[:,:,2,0] = 0.0
        weights2d[:,:,2,2] = 0.0
    for time in range(times):
        im = F.conv2d(im, weights2d, padding=1).clamp(0, 1)
    return im
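The two morphology helpers at the end of improc.py implement dilation and erosion as clamped 3x3 convolutions. A minimal sketch (not part of the committed file) of that behaviour on a toy binary mask, run on CPU rather than the functions' default device='cuda':

import torch
import torch.nn.functional as F

mask = torch.zeros(1, 1, 7, 7)
mask[:, :, 3, 3] = 1.0  # single "on" pixel

weights = torch.ones(1, 1, 3, 3)
# dilation: any pixel with an "on" neighbour turns on
dilated = F.conv2d(mask, weights, padding=1).clamp(0, 1)
# erosion: dilate the complement, then complement again
eroded = 1.0 - F.conv2d(1.0 - dilated, weights, padding=1).clamp(0, 1)

print(dilated[0, 0].int())  # 3x3 block of ones around the centre
print(eroded[0, 0].int())   # shrinks back to the single centre pixel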
models/spatracker/utils/misc.py
ADDED
@@ -0,0 +1,166 @@
import torch
import numpy as np
import math
from prettytable import PrettyTable

def count_parameters(model):
    table = PrettyTable(["Modules", "Parameters"])
    total_params = 0
    for name, parameter in model.named_parameters():
        if not parameter.requires_grad:
            continue
        param = parameter.numel()
        if param > 100000:
            table.add_row([name, param])
        total_params+=param
    print(table)
    print('total params: %.2f M' % (total_params/1000000.0))
    return total_params

def posemb_sincos_2d_xy(xy, C, temperature=10000, dtype=torch.float32, cat_coords=False):
    device = xy.device
    dtype = xy.dtype
    B, S, D = xy.shape
    assert(D==2)
    x = xy[:,:,0]
    y = xy[:,:,1]
    assert (C % 4) == 0, 'feature dimension must be multiple of 4 for sincos emb'
    omega = torch.arange(C // 4, device=device) / (C // 4 - 1)
    omega = 1. / (temperature ** omega)

    y = y.flatten()[:, None] * omega[None, :]
    x = x.flatten()[:, None] * omega[None, :]
    pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1)
    pe = pe.reshape(B,S,C).type(dtype)
    if cat_coords:
        pe = torch.cat([pe, xy], dim=2) # B,N,C+2
    return pe

class SimplePool():
    def __init__(self, pool_size, version='pt'):
        self.pool_size = pool_size
        self.version = version
        self.items = []

        if not (version=='pt' or version=='np'):
            print('version = %s; please choose pt or np' % version)
            assert(False) # please choose pt or np

    def __len__(self):
        return len(self.items)

    def mean(self, min_size=1):
        if min_size=='half':
            pool_size_thresh = self.pool_size/2
        else:
            pool_size_thresh = min_size

        if self.version=='np':
            if len(self.items) >= pool_size_thresh:
                return np.sum(self.items)/float(len(self.items))
            else:
                return np.nan
        if self.version=='pt':
            if len(self.items) >= pool_size_thresh:
                return torch.sum(self.items)/float(len(self.items))
            else:
                return torch.tensor(float('nan'))

    def sample(self, with_replacement=True):
        idx = np.random.randint(len(self.items))
        if with_replacement:
            return self.items[idx]
        else:
            return self.items.pop(idx)

    def fetch(self, num=None):
        if self.version=='pt':
            item_array = torch.stack(self.items)
        elif self.version=='np':
            item_array = np.stack(self.items)
        if num is not None:
            # there better be some items
            assert(len(self.items) >= num)

            # if there are not that many elements just return however many there are
            if len(self.items) < num:
                return item_array
            else:
                idxs = np.random.randint(len(self.items), size=num)
                return item_array[idxs]
        else:
            return item_array

    def is_full(self):
        full = len(self.items)==self.pool_size
        return full

    def empty(self):
        self.items = []

    def update(self, items):
        for item in items:
            if len(self.items) < self.pool_size:
                # the pool is not full, so let's add this in
                self.items.append(item)
            else:
                # the pool is full
                # pop from the front
                self.items.pop(0)
                # add to the back
                self.items.append(item)
        return self.items

def farthest_point_sample(xyz, npoint, include_ends=False, deterministic=False):
    """
    Input:
        xyz: pointcloud data, [B, N, C], where C is probably 3
        npoint: number of samples
    Return:
        inds: sampled pointcloud index, [B, npoint]
    """
    device = xyz.device
    B, N, C = xyz.shape
    xyz = xyz.float()
    inds = torch.zeros(B, npoint, dtype=torch.long).to(device)
    distance = torch.ones(B, N).to(device) * 1e10
    if deterministic:
        farthest = torch.randint(0, 1, (B,), dtype=torch.long).to(device)
    else:
        farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device)
    batch_indices = torch.arange(B, dtype=torch.long).to(device)
    for i in range(npoint):
        if include_ends:
            if i==0:
                farthest = 0
            elif i==1:
                farthest = N-1
        inds[:, i] = farthest
        centroid = xyz[batch_indices, farthest, :].view(B, 1, C)
        dist = torch.sum((xyz - centroid) ** 2, -1)
        mask = dist < distance
        distance[mask] = dist[mask]
        farthest = torch.max(distance, -1)[1]

        if npoint > N:
            # if we need more samples, make them random
            distance += torch.randn_like(distance)
    return inds

def farthest_point_sample_py(xyz, npoint):
    N,C = xyz.shape
    inds = np.zeros(npoint, dtype=np.int32)
    distance = np.ones(N) * 1e10
    farthest = np.random.randint(0, N, dtype=np.int32)
    for i in range(npoint):
        inds[i] = farthest
        centroid = xyz[farthest, :].reshape(1,C)
        dist = np.sum((xyz - centroid) ** 2, -1)
        mask = dist < distance
        distance[mask] = dist[mask]
        farthest = np.argmax(distance, -1)
        if npoint > N:
            # if we need more samples, make them random
            distance += np.random.randn(*distance.shape)
    return inds
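farthest_point_sample_py above greedily picks, at each step, the point farthest from everything selected so far, which is useful for spreading query points over a point cloud. A minimal usage sketch (not part of the committed file), assuming the repo root is on PYTHONPATH and prettytable (imported at the top of misc.py) is installed:

import numpy as np
from models.spatracker.utils.misc import farthest_point_sample_py

pts = np.random.rand(1000, 3)             # N=1000 points in 3D, shape [N, C]
inds = farthest_point_sample_py(pts, 16)  # indices of 16 well-spread points
print(inds.shape, pts[inds].shape)        # (16,) (16, 3)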
models/spatracker/utils/samp.py
ADDED
@@ -0,0 +1,152 @@
import torch
import utils.basic
import torch.nn.functional as F

def bilinear_sample2d(im, x, y, return_inbounds=False):
    # x and y are each B, N
    # output is B, C, N
    B, C, H, W = list(im.shape)
    N = list(x.shape)[1]

    x = x.float()
    y = y.float()
    H_f = torch.tensor(H, dtype=torch.float32)
    W_f = torch.tensor(W, dtype=torch.float32)

    # inbound_mask = (x>-0.5).float()*(y>-0.5).float()*(x<W_f+0.5).float()*(y<H_f+0.5).float()

    max_y = (H_f - 1).int()
    max_x = (W_f - 1).int()

    x0 = torch.floor(x).int()
    x1 = x0 + 1
    y0 = torch.floor(y).int()
    y1 = y0 + 1

    x0_clip = torch.clamp(x0, 0, max_x)
    x1_clip = torch.clamp(x1, 0, max_x)
    y0_clip = torch.clamp(y0, 0, max_y)
    y1_clip = torch.clamp(y1, 0, max_y)
    dim2 = W
    dim1 = W * H

    base = torch.arange(0, B, dtype=torch.int64, device=x.device)*dim1
    base = torch.reshape(base, [B, 1]).repeat([1, N])

    base_y0 = base + y0_clip * dim2
    base_y1 = base + y1_clip * dim2

    idx_y0_x0 = base_y0 + x0_clip
    idx_y0_x1 = base_y0 + x1_clip
    idx_y1_x0 = base_y1 + x0_clip
    idx_y1_x1 = base_y1 + x1_clip

    # use the indices to lookup pixels in the flat image
    # im is B x C x H x W
    # move C out to last dim
    im_flat = (im.permute(0, 2, 3, 1)).reshape(B*H*W, C)
    i_y0_x0 = im_flat[idx_y0_x0.long()]
    i_y0_x1 = im_flat[idx_y0_x1.long()]
    i_y1_x0 = im_flat[idx_y1_x0.long()]
    i_y1_x1 = im_flat[idx_y1_x1.long()]

    # Finally calculate interpolated values.
    x0_f = x0.float()
    x1_f = x1.float()
    y0_f = y0.float()
    y1_f = y1.float()

    w_y0_x0 = ((x1_f - x) * (y1_f - y)).unsqueeze(2)
    w_y0_x1 = ((x - x0_f) * (y1_f - y)).unsqueeze(2)
    w_y1_x0 = ((x1_f - x) * (y - y0_f)).unsqueeze(2)
    w_y1_x1 = ((x - x0_f) * (y - y0_f)).unsqueeze(2)

    output = w_y0_x0 * i_y0_x0 + w_y0_x1 * i_y0_x1 + \
             w_y1_x0 * i_y1_x0 + w_y1_x1 * i_y1_x1
    # output is B*N x C
    output = output.view(B, -1, C)
    output = output.permute(0, 2, 1)
    # output is B x C x N

    if return_inbounds:
        x_valid = (x > -0.5).byte() & (x < float(W_f - 0.5)).byte()
        y_valid = (y > -0.5).byte() & (y < float(H_f - 0.5)).byte()
        inbounds = (x_valid & y_valid).float()
        inbounds = inbounds.reshape(B, N) # something seems wrong here for B>1; i'm getting an error here (or downstream if i put -1)
        return output, inbounds

    return output # B, C, N

def paste_crop_on_canvas(crop, box2d_unnorm, H, W, fast=True, mask=None, canvas=None):
    # this is the inverse of crop_and_resize_box2d
    B, C, Y, X = list(crop.shape)
    B2, D = list(box2d_unnorm.shape)
    assert(B == B2)
    assert(D == 4)

    # here, we want to place the crop into a bigger image,
    # at the location specified by the box2d.

    if canvas is None:
        canvas = torch.zeros((B, C, H, W), device=crop.device)
    else:
        B2, C2, H2, W2 = canvas.shape
        assert(B==B2)
        assert(C==C2)
        assert(H==H2)
        assert(W==W2)

    # box2d_unnorm = utils.geom.unnormalize_box2d(box2d, H, W)

    if fast:
        ymin = box2d_unnorm[:, 0].long()
        xmin = box2d_unnorm[:, 1].long()
        ymax = box2d_unnorm[:, 2].long()
        xmax = box2d_unnorm[:, 3].long()
        w = (xmax - xmin).float()
        h = (ymax - ymin).float()

        grids = utils.basic.gridcloud2d(B, H, W)
        grids_flat = grids.reshape(B, -1, 2)
        # grids_flat[:, :, 0] = (grids_flat[:, :, 0] - xmin.float().unsqueeze(1)) / w.unsqueeze(1) * X
        # grids_flat[:, :, 1] = (grids_flat[:, :, 1] - ymin.float().unsqueeze(1)) / h.unsqueeze(1) * Y

        # for each pixel in the main image,
        # grids_flat tells us where to sample in the crop image

        # print('grids_flat', grids_flat.shape)
        # print('crop', crop.shape)

        grids_flat[:, :, 0] = (grids_flat[:, :, 0] - xmin.float().unsqueeze(1)) / w.unsqueeze(1) * 2.0 - 1.0
        grids_flat[:, :, 1] = (grids_flat[:, :, 1] - ymin.float().unsqueeze(1)) / h.unsqueeze(1) * 2.0 - 1.0

        grid = grids_flat.reshape(B,H,W,2)

        canvas = F.grid_sample(crop, grid, align_corners=False)
        # print('canvas', canvas.shape)

        # if mask is None:
        #     crop_resamp, inb = bilinear_sample2d(crop, grids_flat[:, :, 0], grids_flat[:, :, 1], return_inbounds=True)
        #     crop_resamp = crop_resamp.reshape(B, C, H, W)
        #     inb = inb.reshape(B, 1, H, W)
        #     canvas = canvas * (1 - inb) + crop_resamp * inb
        # else:
        #     full_resamp = bilinear_sample2d(torch.cat([crop, mask], dim=1), grids_flat[:, :, 0], grids_flat[:, :, 1])
        #     full_resamp = full_resamp.reshape(B, C+1, H, W)
        #     crop_resamp = full_resamp[:,:3]
        #     mask_resamp = full_resamp[:,3:4]
        #     canvas = canvas * (1 - mask_resamp) + crop_resamp * mask_resamp
    else:
        for b in range(B):
            ymin = box2d_unnorm[b, 0].long()
            xmin = box2d_unnorm[b, 1].long()
            ymax = box2d_unnorm[b, 2].long()
            xmax = box2d_unnorm[b, 3].long()

            crop_b = F.interpolate(crop[b:b + 1], (ymax - ymin, xmax - xmin)).squeeze(0)

            # print('canvas[b,:,...', canvas[b,:,ymin:ymax,xmin:xmax].shape)
            # print('crop_b', crop_b.shape)

            canvas[b, :, ymin:ymax, xmin:xmax] = crop_b
    return canvas
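bilinear_sample2d above gathers the four neighbouring pixels of each query location and blends them with the standard bilinear weights. A hand-check (not part of the committed file) of that weighting on a 2x2 image: sampling at (0.5, 0.5) should give the mean of the four pixels, which is also what the function returns for this query.

import torch

im = torch.tensor([[[[0., 1.],
                     [2., 3.]]]])  # B=1, C=1, H=2, W=2
x, y = 0.5, 0.5
x0, y0, x1, y1 = 0, 0, 1, 1        # the four surrounding pixel corners
w00 = (x1 - x) * (y1 - y)          # each weight is 0.25 here
w01 = (x - x0) * (y1 - y)
w10 = (x1 - x) * (y - y0)
w11 = (x - x0) * (y - y0)
val = (w00 * im[0, 0, y0, x0] + w01 * im[0, 0, y0, x1] +
       w10 * im[0, 0, y1, x0] + w11 * im[0, 0, y1, x1])
print(val)  # tensor(1.5000), the mean of the four pixel values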
models/spatracker/utils/visualizer.py
ADDED
@@ -0,0 +1,409 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import os
import numpy as np
import cv2
import torch
import flow_vis

from matplotlib import cm
import torch.nn.functional as F
import torchvision.transforms as transforms
from moviepy.editor import ImageSequenceClip
import matplotlib.pyplot as plt
from tqdm import tqdm

def read_video_from_path(path):
    cap = cv2.VideoCapture(path)
    if not cap.isOpened():
        print("Error opening video file")
    else:
        frames = []
        while cap.isOpened():
            ret, frame = cap.read()
            if ret == True:
                frames.append(np.array(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
            else:
                break
        cap.release()
        return np.stack(frames)


class Visualizer:
    def __init__(
        self,
        save_dir: str = "./results",
        grayscale: bool = False,
        pad_value: int = 0,
        fps: int = 10,
        mode: str = "rainbow",  # 'cool', 'optical_flow'
        linewidth: int = 1,
        show_first_frame: int = 10,
        tracks_leave_trace: int = 0,  # -1 for infinite
    ):
        self.mode = mode
        self.save_dir = save_dir
        self.vtxt_path = os.path.join(save_dir, "videos.txt")
        self.ttxt_path = os.path.join(save_dir, "trackings.txt")
        if mode == "rainbow":
            self.color_map = cm.get_cmap("gist_rainbow")
        elif mode == "cool":
            self.color_map = cm.get_cmap(mode)
        self.show_first_frame = show_first_frame
        self.grayscale = grayscale
        self.tracks_leave_trace = tracks_leave_trace
        self.pad_value = pad_value
        self.linewidth = linewidth
        self.fps = fps

    def visualize(
        self,
        video: torch.Tensor,  # (B,T,C,H,W)
        tracks: torch.Tensor,  # (B,T,N,2)
        visibility: torch.Tensor = None,  # (B, T, N, 1) bool
        gt_tracks: torch.Tensor = None,  # (B,T,N,2)
        segm_mask: torch.Tensor = None,  # (B,1,H,W)
        filename: str = "video",
        writer=None,  # tensorboard Summary Writer, used for visualization during training
        step: int = 0,
        query_frame: int = 0,
        save_video: bool = True,
        compensate_for_camera_motion: bool = False,
        rigid_part = None,
        video_depth = None  # (B,T,C,H,W)
    ):
        if compensate_for_camera_motion:
            assert segm_mask is not None
        if segm_mask is not None:
            coords = tracks[0, query_frame].round().long()
            segm_mask = segm_mask[0, query_frame][coords[:, 1], coords[:, 0]].long()

        video = F.pad(
            video,
            (self.pad_value, self.pad_value, self.pad_value, self.pad_value),
            "constant",
            255,
        )

        if video_depth is not None:
            video_depth = (video_depth*255).cpu().numpy().astype(np.uint8)
            video_depth = ([cv2.applyColorMap(video_depth[0,i,0], cv2.COLORMAP_INFERNO)
                            for i in range(video_depth.shape[1])])
            video_depth = np.stack(video_depth, axis=0)
            video_depth = torch.from_numpy(video_depth).permute(0, 3, 1, 2)[None]

        tracks = tracks + self.pad_value

        if self.grayscale:
            transform = transforms.Grayscale()
            video = transform(video)
            video = video.repeat(1, 1, 3, 1, 1)

        tracking_video = self.draw_tracks_on_video(
            video=video,
            tracks=tracks,
            visibility=visibility,
            segm_mask=segm_mask,
            gt_tracks=gt_tracks,
            query_frame=query_frame,
            compensate_for_camera_motion=compensate_for_camera_motion,
            rigid_part=rigid_part
        )

        if save_video:
            # import ipdb; ipdb.set_trace()
            tracking_dir = os.path.join(self.save_dir, "tracking")
            if not os.path.exists(tracking_dir):
                os.makedirs(tracking_dir)
            self.save_video(tracking_video, filename=filename+"_tracking",
                            savedir=tracking_dir, writer=writer, step=step)
            # with open(self.ttxt_path, 'a') as file:
            #     file.write(f"tracking/{filename}_tracking.mp4\n")

            videos_dir = os.path.join(self.save_dir, "videos")
            if not os.path.exists(videos_dir):
                os.makedirs(videos_dir)
            self.save_video(video, filename=filename,
                            savedir=videos_dir, writer=writer, step=step)
            # with open(self.vtxt_path, 'a') as file:
            #     file.write(f"videos/{filename}.mp4\n")
            if video_depth is not None:
                self.save_video(video_depth, filename=filename+"_depth",
                                savedir=os.path.join(self.save_dir, "depth"), writer=writer, step=step)
        return tracking_video

    def save_video(self, video, filename, savedir=None, writer=None, step=0):
        if writer is not None:
            writer.add_video(
                f"{filename}",
                video.to(torch.uint8),
                global_step=step,
                fps=self.fps,
            )
        else:
            os.makedirs(self.save_dir, exist_ok=True)
            wide_list = list(video.unbind(1))
            wide_list = [wide[0].permute(1, 2, 0).cpu().numpy() for wide in wide_list]
            # clip = ImageSequenceClip(wide_list[2:-1], fps=self.fps)
            clip = ImageSequenceClip(wide_list, fps=self.fps)

            # Write the video file
            if savedir is None:
                save_path = os.path.join(self.save_dir, f"{filename}.mp4")
            else:
                save_path = os.path.join(savedir, f"{filename}.mp4")
            clip.write_videofile(save_path, codec="libx264", fps=self.fps, logger=None)

            print(f"Video saved to {save_path}")

    def draw_tracks_on_video(
        self,
        video: torch.Tensor,
        tracks: torch.Tensor,
        visibility: torch.Tensor = None,
        segm_mask: torch.Tensor = None,
        gt_tracks=None,
        query_frame: int = 0,
        compensate_for_camera_motion=False,
        rigid_part=None,
    ):
        B, T, C, H, W = video.shape
        _, _, N, D = tracks.shape

        assert D == 3
        assert C == 3
        video = video[0].permute(0, 2, 3, 1).byte().detach().cpu().numpy()  # S, H, W, C
        tracks = tracks[0].detach().cpu().numpy()  # S, N, 2
        if gt_tracks is not None:
            gt_tracks = gt_tracks[0].detach().cpu().numpy()

        res_video = []

        # process input video
        # for rgb in video:
        #     res_video.append(rgb.copy())

        # create a blank tensor with the same shape as the video
        for rgb in video:
            black_frame = np.zeros_like(rgb.copy(), dtype=rgb.dtype)
            res_video.append(black_frame)

        vector_colors = np.zeros((T, N, 3))

        if self.mode == "optical_flow":

            vector_colors = flow_vis.flow_to_color(tracks - tracks[query_frame][None])

        elif segm_mask is None:
            if self.mode == "rainbow":
                x_min, x_max = tracks[0, :, 0].min(), tracks[0, :, 0].max()
                y_min, y_max = tracks[0, :, 1].min(), tracks[0, :, 1].max()

                z_inv = 1/tracks[0, :, 2]
                z_min, z_max = np.percentile(z_inv, [2, 98])

                norm_x = plt.Normalize(x_min, x_max)
                norm_y = plt.Normalize(y_min, y_max)
                norm_z = plt.Normalize(z_min, z_max)

                for n in range(N):
                    r = norm_x(tracks[0, n, 0])
                    g = norm_y(tracks[0, n, 1])
                    # r = 0
                    # g = 0
                    b = norm_z(1/tracks[0, n, 2])
                    color = np.array([r, g, b])[None] * 255
                    vector_colors[:, n] = np.repeat(color, T, axis=0)
            else:
                # color changes with time
                for t in range(T):
                    color = np.array(self.color_map(t / T)[:3])[None] * 255
                    vector_colors[t] = np.repeat(color, N, axis=0)
        else:
            if self.mode == "rainbow":
                vector_colors[:, segm_mask <= 0, :] = 255

                x_min, x_max = tracks[0, :, 0].min(), tracks[0, :, 0].max()
                y_min, y_max = tracks[0, :, 1].min(), tracks[0, :, 1].max()
                z_min, z_max = tracks[0, :, 2].min(), tracks[0, :, 2].max()

                norm_x = plt.Normalize(x_min, x_max)
                norm_y = plt.Normalize(y_min, y_max)
                norm_z = plt.Normalize(z_min, z_max)

                for n in range(N):
                    r = norm_x(tracks[0, n, 0])
                    g = norm_y(tracks[0, n, 1])
                    b = norm_z(tracks[0, n, 2])
                    color = np.array([r, g, b])[None] * 255
                    vector_colors[:, n] = np.repeat(color, T, axis=0)

            else:
                # color changes with segm class
                segm_mask = segm_mask.cpu()
                color = np.zeros((segm_mask.shape[0], 3), dtype=np.float32)
                color[segm_mask > 0] = np.array(self.color_map(1.0)[:3]) * 255.0
                color[segm_mask <= 0] = np.array(self.color_map(0.0)[:3]) * 255.0
                vector_colors = np.repeat(color[None], T, axis=0)

        # Draw tracks
        if self.tracks_leave_trace != 0:
            for t in range(1, T):
                first_ind = (
                    max(0, t - self.tracks_leave_trace)
                    if self.tracks_leave_trace >= 0
                    else 0
                )
                curr_tracks = tracks[first_ind : t + 1]
                curr_colors = vector_colors[first_ind : t + 1]
                if compensate_for_camera_motion:
                    diff = (
                        tracks[first_ind : t + 1, segm_mask <= 0]
                        - tracks[t : t + 1, segm_mask <= 0]
                    ).mean(1)[:, None]

                    curr_tracks = curr_tracks - diff
                    curr_tracks = curr_tracks[:, segm_mask > 0]
                    curr_colors = curr_colors[:, segm_mask > 0]

                res_video[t] = self._draw_pred_tracks(
                    res_video[t],
                    curr_tracks,
                    curr_colors,
                )
                if gt_tracks is not None:
                    res_video[t] = self._draw_gt_tracks(
                        res_video[t], gt_tracks[first_ind : t + 1]
                    )

        if rigid_part is not None:
            cls_label = torch.unique(rigid_part)
            cls_num = len(torch.unique(rigid_part))
            # visualize the clustering results
            cmap = plt.get_cmap('jet')  # get the color mapping
            colors = cmap(np.linspace(0, 1, cls_num))
            colors = (colors[:, :3] * 255)
            color_map = {lable.item(): color for lable, color in zip(cls_label, colors)}

        # Draw points
        for t in tqdm(range(T)):
            # Create a list to store information for each point
            points_info = []
            for i in range(N):
                coord = (tracks[t, i, 0], tracks[t, i, 1])
                depth = tracks[t, i, 2]  # assume the third dimension is depth
                visibile = True
                if visibility is not None:
                    visibile = visibility[0, t, i]
                if coord[0] != 0 and coord[1] != 0:
                    if not compensate_for_camera_motion or (
                        compensate_for_camera_motion and segm_mask[i] > 0
                    ):
                        points_info.append((i, coord, depth, visibile))

            # Sort points by depth, points with smaller depth (closer) will be drawn later
            points_info.sort(key=lambda x: x[2], reverse=True)

            for i, coord, _, visibile in points_info:
                if rigid_part is not None:
                    color = color_map[rigid_part.squeeze()[i].item()]
                    cv2.circle(
                        res_video[t],
                        coord,
                        int(self.linewidth * 2),
                        color.tolist(),
                        thickness=-1 if visibile else 2 - 1,
                    )
                else:
                    # Determine rectangle width based on the distance between adjacent tracks in the first frame
                    if t == 0:
                        distances = np.linalg.norm(tracks[0] - tracks[0, i], axis=1)
                        distances = distances[distances > 0]
                        rect_size = int(np.min(distances))/2

                    # Define coordinates for top-left and bottom-right corners of the rectangle
                    top_left = (int(coord[0] - rect_size), int(coord[1] - rect_size/1.5))  # Rectangle width is 1.5x (video aspect ratio is 1.5:1)
                    bottom_right = (int(coord[0] + rect_size), int(coord[1] + rect_size/1.5))

                    # Draw rectangle
                    cv2.rectangle(
                        res_video[t],
                        top_left,
                        bottom_right,
                        vector_colors[t, i].tolist(),
                        thickness=-1 if visibile else 0 - 1,
                    )

        # Construct the final rgb sequence
        return torch.from_numpy(np.stack(res_video)).permute(0, 3, 1, 2)[None].byte()

    def _draw_pred_tracks(
        self,
        rgb: np.ndarray,  # H x W x 3
        tracks: np.ndarray,  # T x 2
        vector_colors: np.ndarray,
        alpha: float = 0.5,
    ):
        T, N, _ = tracks.shape

        for s in range(T - 1):
            vector_color = vector_colors[s]
            original = rgb.copy()
            alpha = (s / T) ** 2
            for i in range(N):
                coord_y = (int(tracks[s, i, 0]), int(tracks[s, i, 1]))
                coord_x = (int(tracks[s + 1, i, 0]), int(tracks[s + 1, i, 1]))
                if coord_y[0] != 0 and coord_y[1] != 0:
                    cv2.line(
                        rgb,
                        coord_y,
                        coord_x,
                        vector_color[i].tolist(),
                        self.linewidth,
                        cv2.LINE_AA,
                    )
            if self.tracks_leave_trace > 0:
                rgb = cv2.addWeighted(rgb, alpha, original, 1 - alpha, 0)
        return rgb

    def _draw_gt_tracks(
        self,
        rgb: np.ndarray,  # H x W x 3,
        gt_tracks: np.ndarray,  # T x 2
    ):
        T, N, _ = gt_tracks.shape
        color = np.array((211.0, 0.0, 0.0))

        for t in range(T):
            for i in range(N):
                gt_track = gt_tracks[t][i]
                # draw a red cross
                if gt_track[0] > 0 and gt_track[1] > 0:
                    length = self.linewidth * 3
                    coord_y = (int(gt_track[0]) + length, int(gt_track[1]) + length)
                    coord_x = (int(gt_track[0]) - length, int(gt_track[1]) - length)
                    cv2.line(
                        rgb,
                        coord_y,
                        coord_x,
                        color,
                        self.linewidth,
                        cv2.LINE_AA,
                    )
                    coord_y = (int(gt_track[0]) - length, int(gt_track[1]) + length)
                    coord_x = (int(gt_track[0]) + length, int(gt_track[1]) - length)
                    cv2.line(
                        rgb,
                        coord_y,
                        coord_x,
                        color,
                        self.linewidth,
                        cv2.LINE_AA,
                    )
        return rgb
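A minimal usage sketch (not part of the committed file) for the Visualizer above. It assumes the repo root is on PYTHONPATH, that the imports at the top of visualizer.py (flow_vis, moviepy, matplotlib, tqdm) are installed, and that tracks carry (x, y, depth) per point, since draw_tracks_on_video asserts D == 3 despite the (B,T,N,2) comment in the signature.

import torch
from models.spatracker.utils.visualizer import Visualizer

T, N, H, W = 8, 16, 96, 128
video = torch.randint(0, 255, (1, T, 3, H, W)).float()
tracks = torch.rand(1, T, N, 3)
tracks[..., 0] *= W      # x in pixels
tracks[..., 1] *= H      # y in pixels
tracks[..., 2] += 1.0    # keep depth positive (rainbow mode uses 1/z)

vis = Visualizer(save_dir="./results", fps=10, mode="rainbow")
frames = vis.visualize(video, tracks, save_video=False)  # uint8, (1, T, 3, H, W)
print(frames.shape)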
models/spatracker/utils/vox.py
ADDED
@@ -0,0 +1,500 @@
import numpy as np
import torch
import torch.nn.functional as F

import utils.geom

class Vox_util(object):
    def __init__(self, Z, Y, X, scene_centroid, bounds, pad=None, assert_cube=False):
        self.XMIN, self.XMAX, self.YMIN, self.YMAX, self.ZMIN, self.ZMAX = bounds
        B, D = list(scene_centroid.shape)
        self.Z, self.Y, self.X = Z, Y, X

        scene_centroid = scene_centroid.detach().cpu().numpy()
        x_centroid, y_centroid, z_centroid = scene_centroid[0]
        self.XMIN += x_centroid
        self.XMAX += x_centroid
        self.YMIN += y_centroid
        self.YMAX += y_centroid
        self.ZMIN += z_centroid
        self.ZMAX += z_centroid

        self.default_vox_size_X = (self.XMAX-self.XMIN)/float(X)
        self.default_vox_size_Y = (self.YMAX-self.YMIN)/float(Y)
        self.default_vox_size_Z = (self.ZMAX-self.ZMIN)/float(Z)

        if pad:
            Z_pad, Y_pad, X_pad = pad
            self.ZMIN -= self.default_vox_size_Z * Z_pad
            self.ZMAX += self.default_vox_size_Z * Z_pad
            self.YMIN -= self.default_vox_size_Y * Y_pad
            self.YMAX += self.default_vox_size_Y * Y_pad
            self.XMIN -= self.default_vox_size_X * X_pad
            self.XMAX += self.default_vox_size_X * X_pad

        if assert_cube:
            # we assume cube voxels
            if (not np.isclose(self.default_vox_size_X, self.default_vox_size_Y)) or (not np.isclose(self.default_vox_size_X, self.default_vox_size_Z)):
                print('Z, Y, X', Z, Y, X)
                print('bounds for this iter:',
                      'X = %.2f to %.2f' % (self.XMIN, self.XMAX),
                      'Y = %.2f to %.2f' % (self.YMIN, self.YMAX),
                      'Z = %.2f to %.2f' % (self.ZMIN, self.ZMAX),
                )
                print('self.default_vox_size_X', self.default_vox_size_X)
                print('self.default_vox_size_Y', self.default_vox_size_Y)
                print('self.default_vox_size_Z', self.default_vox_size_Z)
            assert(np.isclose(self.default_vox_size_X, self.default_vox_size_Y))
            assert(np.isclose(self.default_vox_size_X, self.default_vox_size_Z))

    def Ref2Mem(self, xyz, Z, Y, X, assert_cube=False):
        # xyz is B x N x 3, in ref coordinates
        # transforms ref coordinates into mem coordinates
        B, N, C = list(xyz.shape)
        device = xyz.device
        assert(C==3)
        mem_T_ref = self.get_mem_T_ref(B, Z, Y, X, assert_cube=assert_cube, device=device)
        xyz = utils.geom.apply_4x4(mem_T_ref, xyz)
        return xyz

    def Mem2Ref(self, xyz_mem, Z, Y, X, assert_cube=False):
        # xyz is B x N x 3, in mem coordinates
        # transforms mem coordinates into ref coordinates
        B, N, C = list(xyz_mem.shape)
        ref_T_mem = self.get_ref_T_mem(B, Z, Y, X, assert_cube=assert_cube, device=xyz_mem.device)
        xyz_ref = utils.geom.apply_4x4(ref_T_mem, xyz_mem)
        return xyz_ref

    def get_mem_T_ref(self, B, Z, Y, X, assert_cube=False, device='cuda'):
        vox_size_X = (self.XMAX-self.XMIN)/float(X)
        vox_size_Y = (self.YMAX-self.YMIN)/float(Y)
        vox_size_Z = (self.ZMAX-self.ZMIN)/float(Z)

        if assert_cube:
            if (not np.isclose(vox_size_X, vox_size_Y)) or (not np.isclose(vox_size_X, vox_size_Z)):
                print('Z, Y, X', Z, Y, X)
                print('bounds for this iter:',
                      'X = %.2f to %.2f' % (self.XMIN, self.XMAX),
                      'Y = %.2f to %.2f' % (self.YMIN, self.YMAX),
                      'Z = %.2f to %.2f' % (self.ZMIN, self.ZMAX),
                )
                print('vox_size_X', vox_size_X)
                print('vox_size_Y', vox_size_Y)
                print('vox_size_Z', vox_size_Z)
            assert(np.isclose(vox_size_X, vox_size_Y))
            assert(np.isclose(vox_size_X, vox_size_Z))

        # translation
        # (this makes the left edge of the leftmost voxel correspond to XMIN)
        center_T_ref = utils.geom.eye_4x4(B, device=device)
        center_T_ref[:,0,3] = -self.XMIN-vox_size_X/2.0
        center_T_ref[:,1,3] = -self.YMIN-vox_size_Y/2.0
        center_T_ref[:,2,3] = -self.ZMIN-vox_size_Z/2.0

        # scaling
        # (this makes the right edge of the rightmost voxel correspond to XMAX)
        mem_T_center = utils.geom.eye_4x4(B, device=device)
        mem_T_center[:,0,0] = 1./vox_size_X
        mem_T_center[:,1,1] = 1./vox_size_Y
        mem_T_center[:,2,2] = 1./vox_size_Z
        mem_T_ref = utils.geom.matmul2(mem_T_center, center_T_ref)

        return mem_T_ref

    def get_ref_T_mem(self, B, Z, Y, X, assert_cube=False, device='cuda'):
        mem_T_ref = self.get_mem_T_ref(B, Z, Y, X, assert_cube=assert_cube, device=device)
        # note safe_inverse is inapplicable here,
        # since the transform is nonrigid
        ref_T_mem = mem_T_ref.inverse()
        return ref_T_mem

    def get_inbounds(self, xyz, Z, Y, X, already_mem=False, padding=0.0, assert_cube=False):
        # xyz is B x N x 3
        # padding should be 0 unless you are trying to account for some later cropping
        if not already_mem:
            xyz = self.Ref2Mem(xyz, Z, Y, X, assert_cube=assert_cube)

        x = xyz[:,:,0]
        y = xyz[:,:,1]
        z = xyz[:,:,2]

        x_valid = ((x-padding)>-0.5).byte() & ((x+padding)<float(X-0.5)).byte()
        y_valid = ((y-padding)>-0.5).byte() & ((y+padding)<float(Y-0.5)).byte()
        z_valid = ((z-padding)>-0.5).byte() & ((z+padding)<float(Z-0.5)).byte()
        nonzero = (~(z==0.0)).byte()

        inbounds = x_valid & y_valid & z_valid & nonzero
        return inbounds.bool()

    def voxelize_xyz(self, xyz_ref, Z, Y, X, already_mem=False, assert_cube=False, clean_eps=0):
        B, N, D = list(xyz_ref.shape)
        assert(D==3)
        if already_mem:
            xyz_mem = xyz_ref
        else:
            xyz_mem = self.Ref2Mem(xyz_ref, Z, Y, X, assert_cube=assert_cube)
            xyz_zero = self.Ref2Mem(xyz_ref[:,0:1]*0, Z, Y, X, assert_cube=assert_cube)
        vox = self.get_occupancy(xyz_mem, Z, Y, X, clean_eps=clean_eps, xyz_zero=xyz_zero)
        return vox

    def voxelize_xyz_and_feats(self, xyz_ref, feats, Z, Y, X, already_mem=False, assert_cube=False, clean_eps=0):
        B, N, D = list(xyz_ref.shape)
        B2, N2, D2 = list(feats.shape)
        assert(D==3)
        assert(B==B2)
        assert(N==N2)
        if already_mem:
            xyz_mem = xyz_ref
        else:
            xyz_mem = self.Ref2Mem(xyz_ref, Z, Y, X, assert_cube=assert_cube)
            xyz_zero = self.Ref2Mem(xyz_ref[:,0:1]*0, Z, Y, X, assert_cube=assert_cube)
        feats = self.get_feat_occupancy(xyz_mem, feats, Z, Y, X, clean_eps=clean_eps, xyz_zero=xyz_zero)
        return feats

    def get_occupancy(self, xyz, Z, Y, X, clean_eps=0, xyz_zero=None):
        # xyz is B x N x 3 and in mem coords
        # we want to fill a voxel tensor with 1's at these inds
        B, N, C = list(xyz.shape)
        assert(C==3)

        # these papers say simple 1/0 occupancy is ok:
        # http://openaccess.thecvf.com/content_cvpr_2018/papers/Yang_PIXOR_Real-Time_3d_CVPR_2018_paper.pdf
        # http://openaccess.thecvf.com/content_cvpr_2018/papers/Luo_Fast_and_Furious_CVPR_2018_paper.pdf
        # cont fusion says they do 8-neighbor interp
        # voxelnet does occupancy but with a bit of randomness in terms of the reflectance value i think

        inbounds = self.get_inbounds(xyz, Z, Y, X, already_mem=True)
        x, y, z = xyz[:,:,0], xyz[:,:,1], xyz[:,:,2]
        mask = torch.zeros_like(x)
        mask[inbounds] = 1.0

        if xyz_zero is not None:
            # only take points that are beyond a thresh of zero
            dist = torch.norm(xyz_zero-xyz, dim=2)
            mask[dist < 0.1] = 0

        if clean_eps > 0:
            # only take points that are already near centers
            xyz_round = torch.round(xyz) # B, N, 3
            dist = torch.norm(xyz_round - xyz, dim=2)
            mask[dist > clean_eps] = 0

        # set the invalid guys to zero
        # we then need to zero out 0,0,0
        # (this method seems a bit clumsy)
        x = x*mask
        y = y*mask
        z = z*mask

        x = torch.round(x)
        y = torch.round(y)
        z = torch.round(z)
        x = torch.clamp(x, 0, X-1).int()
        y = torch.clamp(y, 0, Y-1).int()
        z = torch.clamp(z, 0, Z-1).int()

        x = x.view(B*N)
        y = y.view(B*N)
        z = z.view(B*N)

        dim3 = X
        dim2 = X * Y
        dim1 = X * Y * Z

        base = torch.arange(0, B, dtype=torch.int32, device=xyz.device)*dim1
        base = torch.reshape(base, [B, 1]).repeat([1, N]).view(B*N)

        vox_inds = base + z * dim2 + y * dim3 + x
        voxels = torch.zeros(B*Z*Y*X, device=xyz.device).float()
        voxels[vox_inds.long()] = 1.0
        # zero out the singularity
        voxels[base.long()] = 0.0
        voxels = voxels.reshape(B, 1, Z, Y, X)
        # B x 1 x Z x Y x X
        return voxels

    def get_feat_occupancy(self, xyz, feat, Z, Y, X, clean_eps=0, xyz_zero=None):
        # xyz is B x N x 3 and in mem coords
        # feat is B x N x D
        # we want to fill a voxel tensor with 1's at these inds
        B, N, C = list(xyz.shape)
        B2, N2, D2 = list(feat.shape)
        assert(C==3)
        assert(B==B2)
        assert(N==N2)

        # these papers say simple 1/0 occupancy is ok:
        # http://openaccess.thecvf.com/content_cvpr_2018/papers/Yang_PIXOR_Real-Time_3d_CVPR_2018_paper.pdf
        # http://openaccess.thecvf.com/content_cvpr_2018/papers/Luo_Fast_and_Furious_CVPR_2018_paper.pdf
        # cont fusion says they do 8-neighbor interp
        # voxelnet does occupancy but with a bit of randomness in terms of the reflectance value i think

        inbounds = self.get_inbounds(xyz, Z, Y, X, already_mem=True)
        x, y, z = xyz[:,:,0], xyz[:,:,1], xyz[:,:,2]
        mask = torch.zeros_like(x)
        mask[inbounds] = 1.0

        if xyz_zero is not None:
            # only take points that are beyond a thresh of zero
            dist = torch.norm(xyz_zero-xyz, dim=2)
            mask[dist < 0.1] = 0

        if clean_eps > 0:
            # only take points that are already near centers
            xyz_round = torch.round(xyz) # B, N, 3
            dist = torch.norm(xyz_round - xyz, dim=2)
            mask[dist > clean_eps] = 0

        # set the invalid guys to zero
        # we then need to zero out 0,0,0
        # (this method seems a bit clumsy)
        x = x*mask # B, N
        y = y*mask
        z = z*mask
        feat = feat*mask.unsqueeze(-1) # B, N, D

        x = torch.round(x)
        y = torch.round(y)
        z = torch.round(z)
        x = torch.clamp(x, 0, X-1).int()
        y = torch.clamp(y, 0, Y-1).int()
        z = torch.clamp(z, 0, Z-1).int()

        # permute point orders
        perm = torch.randperm(N)
        x = x[:, perm]
        y = y[:, perm]
        z = z[:, perm]
        feat = feat[:, perm]

        x = x.view(B*N)
        y = y.view(B*N)
        z = z.view(B*N)
        feat = feat.view(B*N, -1)

        dim3 = X
        dim2 = X * Y
        dim1 = X * Y * Z

        base = torch.arange(0, B, dtype=torch.int32, device=xyz.device)*dim1
        base = torch.reshape(base, [B, 1]).repeat([1, N]).view(B*N)

        vox_inds = base + z * dim2 + y * dim3 + x
        feat_voxels = torch.zeros((B*Z*Y*X, D2), device=xyz.device).float()
        feat_voxels[vox_inds.long()] = feat
        # zero out the singularity
        feat_voxels[base.long()] = 0.0
        feat_voxels = feat_voxels.reshape(B, Z, Y, X, D2).permute(0, 4, 1, 2, 3)
        # B x C x Z x Y x X
        return feat_voxels

    def unproject_image_to_mem(self, rgb_camB, pixB_T_camA, camB_T_camA, Z, Y, X, assert_cube=False, xyz_camA=None):
        # rgb_camB is B x C x H x W
        # pixB_T_camA is B x 4 x 4

        # rgb lives in B pixel coords
        # we want everything in A memory coords

        # this puts each C-dim pixel in the rgb_camB
        # along a ray in the voxelgrid
        B, C, H, W = list(rgb_camB.shape)

        if xyz_camA is None:
            xyz_memA = utils.basic.gridcloud3d(B, Z, Y, X, norm=False, device=pixB_T_camA.device)
            xyz_camA = self.Mem2Ref(xyz_memA, Z, Y, X, assert_cube=assert_cube)

        xyz_camB = utils.geom.apply_4x4(camB_T_camA, xyz_camA)
        z = xyz_camB[:,:,2]

        xyz_pixB = utils.geom.apply_4x4(pixB_T_camA, xyz_camA)
        normalizer = torch.unsqueeze(xyz_pixB[:,:,2], 2)
        EPS=1e-6
        # z = xyz_pixB[:,:,2]
        xy_pixB = xyz_pixB[:,:,:2]/torch.clamp(normalizer, min=EPS)
        # this is B x N x 2
        # this is the (floating point) pixel coordinate of each voxel
        x, y = xy_pixB[:,:,0], xy_pixB[:,:,1]
        # these are B x N

        x_valid = (x>-0.5).bool() & (x<float(W-0.5)).bool()
        y_valid = (y>-0.5).bool() & (y<float(H-0.5)).bool()
        z_valid = (z>0.0).bool()
        valid_mem = (x_valid & y_valid & z_valid).reshape(B, 1, Z, Y, X).float()

        if (0):
            # handwritten version
            values = torch.zeros([B, C, Z*Y*X], dtype=torch.float32)
            for b in list(range(B)):
                values[b] = utils.samp.bilinear_sample_single(rgb_camB[b], x_pixB[b], y_pixB[b])
        else:
            # native pytorch version
            y_pixB, x_pixB = utils.basic.normalize_grid2d(y, x, H, W)
            # since we want a 3d output, we need 5d tensors
            z_pixB = torch.zeros_like(x)
            xyz_pixB = torch.stack([x_pixB, y_pixB, z_pixB], axis=2)
            rgb_camB = rgb_camB.unsqueeze(2)
            xyz_pixB = torch.reshape(xyz_pixB, [B, Z, Y, X, 3])
            values = F.grid_sample(rgb_camB, xyz_pixB, align_corners=False)

        values = torch.reshape(values, (B, C, Z, Y, X))
        values = values * valid_mem
        return values

    def warp_tiled_to_mem(self, rgb_tileB, pixB_T_camA, camB_T_camA, Z, Y, X, DMIN, DMAX, assert_cube=False):
        # rgb_tileB is B,C,D,H,W
        # pixB_T_camA is B,4,4
        # camB_T_camA is B,4,4

        # rgb_tileB lives in B pixel coords but it has been tiled across the Z dimension
        # we want everything in A memory coords

        # this resamples the so that each C-dim pixel in rgb_tilB
        # is put into its correct place in the voxelgrid
        # (using the pinhole camera model)

        B, C, D, H, W = list(rgb_tileB.shape)

        xyz_memA = utils.basic.gridcloud3d(B, Z, Y, X, norm=False, device=pixB_T_camA.device)

        xyz_camA = self.Mem2Ref(xyz_memA, Z, Y, X, assert_cube=assert_cube)

        xyz_camB = utils.geom.apply_4x4(camB_T_camA, xyz_camA)
        z_camB = xyz_camB[:,:,2]

        # rgb_tileB has depth=DMIN in tile 0, and depth=DMAX in tile D-1
        z_tileB = (D-1.0) * (z_camB-float(DMIN)) / float(DMAX-DMIN)

        xyz_pixB = utils.geom.apply_4x4(pixB_T_camA, xyz_camA)
        normalizer = torch.unsqueeze(xyz_pixB[:,:,2], 2)
        EPS=1e-6
        # z = xyz_pixB[:,:,2]
        xy_pixB = xyz_pixB[:,:,:2]/torch.clamp(normalizer, min=EPS)
        # this is B x N x 2
        # this is the (floating point) pixel coordinate of each voxel
        x, y = xy_pixB[:,:,0], xy_pixB[:,:,1]
        # these are B x N

        x_valid = (x>-0.5).bool() & (x<float(W-0.5)).bool()
        y_valid = (y>-0.5).bool() & (y<float(H-0.5)).bool()
        z_valid = (z_camB>0.0).bool()
        valid_mem = (x_valid & y_valid & z_valid).reshape(B, 1, Z, Y, X).float()

        z_tileB, y_pixB, x_pixB = utils.basic.normalize_grid3d(z_tileB, y, x, D, H, W)
        xyz_pixB = torch.stack([x_pixB, y_pixB, z_tileB], axis=2)
        xyz_pixB = torch.reshape(xyz_pixB, [B, Z, Y, X, 3])
        values = F.grid_sample(rgb_tileB, xyz_pixB, align_corners=False)

        values = torch.reshape(values, (B, C, Z, Y, X))
        values = values * valid_mem
        return values


    def apply_mem_T_ref_to_lrtlist(self, lrtlist_cam, Z, Y, X, assert_cube=False):
        # lrtlist is B x N x 19, in cam coordinates
        # transforms them into mem coordinates, including a scale change for the lengths
        B, N, C = list(lrtlist_cam.shape)
        assert(C==19)
        mem_T_cam = self.get_mem_T_ref(B, Z, Y, X, assert_cube=assert_cube, device=lrtlist_cam.device)

    def xyz2circles(self, xyz, radius, Z, Y, X, soft=True, already_mem=True, also_offset=False, grid=None):
        # xyz is B x N x 3
        # radius is B x N or broadcastably so
        # output is B x N x Z x Y x X
        B, N, D = list(xyz.shape)
        assert(D==3)
        if not already_mem:
            xyz = self.Ref2Mem(xyz, Z, Y, X)

        if grid is None:
            grid_z, grid_y, grid_x = utils.basic.meshgrid3d(B, Z, Y, X, stack=False, norm=False, device=xyz.device)
            # note the default stack is on -1
            grid = torch.stack([grid_x, grid_y, grid_z], dim=1)
            # this is B x 3 x Z x Y x X

        xyz = xyz.reshape(B, N, 3, 1, 1, 1)
        grid = grid.reshape(B, 1, 3, Z, Y, X)
        # this is B x N x Z x Y x X

        # round the xyzs, so that at least one value matches the grid perfectly,
        # and we get a value of 1 there (since exp(0)==1)
        xyz = xyz.round()

        if torch.is_tensor(radius):
            radius = radius.clamp(min=0.01)

        if soft:
            off = grid - xyz # B,N,3,Z,Y,X
            # interpret radius as sigma
            dist_grid = torch.sum(off**2, dim=2, keepdim=False)
            # this is B x N x Z x Y x X
            if torch.is_tensor(radius):
                radius = radius.reshape(B, N, 1, 1, 1)
            mask = torch.exp(-dist_grid/(2*radius*radius))
            # zero out near zero
            mask[mask < 0.001] = 0.0
            # h = np.exp(-(x * x + y * y) / (2 * sigma * sigma))
            # h[h < np.finfo(h.dtype).eps * h.max()] = 0
            # return h
            if also_offset:
                return mask, off
            else:
                return mask
        else:
            assert(False) # something is wrong with this. come back later to debug

            dist_grid = torch.norm(grid - xyz, dim=2, keepdim=False)
            # this is 0 at/near the xyz, and increases by 1 for each voxel away

            radius = radius.reshape(B, N, 1, 1, 1)

            within_radius_mask = (dist_grid < radius).float()
            within_radius_mask = torch.sum(within_radius_mask, dim=1, keepdim=True).clamp(0, 1)
            return within_radius_mask

    def xyz2circles_bev(self, xyz, radius, Z, Y, X, already_mem=True, also_offset=False):
        # xyz is B x N x 3
        # radius is B x N or broadcastably so
        # output is B x N x Z x Y x X
        B, N, D = list(xyz.shape)
        assert(D==3)
        if not already_mem:
            xyz = self.Ref2Mem(xyz, Z, Y, X)

        xz = torch.stack([xyz[:,:,0], xyz[:,:,2]], dim=2)

        grid_z, grid_x = utils.basic.meshgrid2d(B, Z, X, stack=False, norm=False, device=xyz.device)
        # note the default stack is on -1
        grid = torch.stack([grid_x, grid_z], dim=1)
        # this is B x 2 x Z x X

        xz = xz.reshape(B, N, 2, 1, 1)
        grid = grid.reshape(B, 1, 2, Z, X)
        # these are ready to broadcast to B x N x Z x X

        # round the points, so that at least one value matches the grid perfectly,
        # and we get a value of 1 there (since exp(0)==1)
        xz = xz.round()

        if torch.is_tensor(radius):
            radius = radius.clamp(min=0.01)

        off = grid - xz # B,N,2,Z,X
        # interpret radius as sigma
        dist_grid = torch.sum(off**2, dim=2, keepdim=False)
        # this is B x N x Z x X
        if torch.is_tensor(radius):
            radius = radius.reshape(B, N, 1, 1, 1)
        mask = torch.exp(-dist_grid/(2*radius*radius))
        # zero out near zero
        mask[mask < 0.001] = 0.0

        # add a Y dim
        mask = mask.unsqueeze(-2)
        off = off.unsqueeze(-2)
        # # B,N,2,Z,1,X

        if also_offset:
            return mask, off
        else:
            return mask
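Vox_util above maps metric "ref" coordinates into voxel-grid "mem" coordinates via a translation followed by a 1/vox_size scaling (get_mem_T_ref). A hand-check (not part of the committed file) of that mapping along one axis, with made-up bounds:

import numpy as np

XMIN, XMAX, X = -2.0, 2.0, 8           # hypothetical bounds and grid size along x
vox_size_X = (XMAX - XMIN) / float(X)  # 0.5 metres per voxel

x_ref = 0.0                            # a point at the centre of the scene
# same arithmetic as center_T_ref followed by mem_T_center in get_mem_T_ref
x_mem = (x_ref - XMIN - vox_size_X / 2.0) / vox_size_X
print(x_mem)  # 3.5, i.e. halfway between voxel 3 and voxel 4 of the 8 voxels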
requirements.txt
ADDED
@@ -0,0 +1,32 @@
# spatrack
easydict==1.13
opencv-python==4.9.0.80
moviepy==1.0.3
flow-vis==0.1
matplotlib==3.8.3
einops==0.7.0
timm==0.6.7
scikit-image==0.22.0
scikit-learn==1.4.1.post1
cupy-cuda11x
accelerate
yt-dlp
pandas

# cogvideox
bitsandbytes
diffusers>=0.31.2
transformers>=4.45.2
hf_transfer>=0.1.8
peft>=0.12.0
decord>=0.6.0
wandb
torchao>=0.5.0
sentencepiece>=0.2.0
imageio-ffmpeg>=0.5.1
numpy>=1.26.4
git+https://github.com/asomoza/image_gen_aux.git
deepspeed

# submodules
-r submodules/MoGe/requirements.txt
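A small sanity check (not part of the committed file) that a few of the pins above resolve in the current environment; the distribution names are taken directly from this requirements file.

from importlib.metadata import version, PackageNotFoundError

for pkg in ["opencv-python", "moviepy", "flow-vis", "diffusers", "transformers"]:
    try:
        print(pkg, version(pkg))
    except PackageNotFoundError:
        print(pkg, "not installed")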
submodules/MoGe/.gitignore
ADDED
@@ -0,0 +1,425 @@
+## Ignore Visual Studio temporary files, build results, and
+## files generated by popular Visual Studio add-ons.
+##
+## Get latest from https://github.com/github/gitignore/blob/main/VisualStudio.gitignore
+
+# User-specific files
+*.rsuser
+*.suo
+*.user
+*.userosscache
+*.sln.docstates
+
+# User-specific files (MonoDevelop/Xamarin Studio)
+*.userprefs
+
+# Mono auto generated files
+mono_crash.*
+
+# Build results
+[Dd]ebug/
+[Dd]ebugPublic/
+[Rr]elease/
+[Rr]eleases/
+x64/
+x86/
+[Ww][Ii][Nn]32/
+[Aa][Rr][Mm]/
+[Aa][Rr][Mm]64/
+bld/
+[Bb]in/
+[Oo]bj/
+[Ll]og/
+[Ll]ogs/
+
+# Visual Studio 2015/2017 cache/options directory
+.vs/
+# Uncomment if you have tasks that create the project's static files in wwwroot
+#wwwroot/
+
+# Visual Studio 2017 auto generated files
+Generated\ Files/
+
+# MSTest test Results
+[Tt]est[Rr]esult*/
+[Bb]uild[Ll]og.*
+
+# NUnit
+*.VisualState.xml
+TestResult.xml
+nunit-*.xml
+
+# Build Results of an ATL Project
+[Dd]ebugPS/
+[Rr]eleasePS/
+dlldata.c
+
+# Benchmark Results
+BenchmarkDotNet.Artifacts/
+
+# .NET Core
+project.lock.json
+project.fragment.lock.json
+artifacts/
+
+# ASP.NET Scaffolding
+ScaffoldingReadMe.txt
+
+# StyleCop
+StyleCopReport.xml
+
+# Files built by Visual Studio
+*_i.c
+*_p.c
+*_h.h
+*.ilk
+*.meta
+*.obj
+*.iobj
+*.pch
+*.pdb
+*.ipdb
+*.pgc
+*.pgd
+*.rsp
+*.sbr
+*.tlb
+*.tli
+*.tlh
+*.tmp
+*.tmp_proj
+*_wpftmp.csproj
+*.log
+*.tlog
+*.vspscc
+*.vssscc
+.builds
+*.pidb
+*.svclog
+*.scc
+
+# Chutzpah Test files
+_Chutzpah*
+
+# Visual C++ cache files
+ipch/
+*.aps
+*.ncb
+*.opendb
+*.opensdf
+*.sdf
+*.cachefile
+*.VC.db
+*.VC.VC.opendb
+
+# Visual Studio profiler
+*.psess
+*.vsp
+*.vspx
+*.sap
+
+# Visual Studio Trace Files
+*.e2e
+
+# TFS 2012 Local Workspace
+$tf/
+
+# Guidance Automation Toolkit
+*.gpState
+
+# ReSharper is a .NET coding add-in
+_ReSharper*/
+*.[Rr]e[Ss]harper
+*.DotSettings.user
+
+# TeamCity is a build add-in
+_TeamCity*
+
+# DotCover is a Code Coverage Tool
+*.dotCover
+
+# AxoCover is a Code Coverage Tool
+.axoCover/*
+!.axoCover/settings.json
+
+# Coverlet is a free, cross platform Code Coverage Tool
+coverage*.json
+coverage*.xml
+coverage*.info
+
+# Visual Studio code coverage results
+*.coverage
+*.coveragexml
+
+# NCrunch
+_NCrunch_*
+.*crunch*.local.xml
+nCrunchTemp_*
+
+# MightyMoose
+*.mm.*
+AutoTest.Net/
+
+# Web workbench (sass)
+.sass-cache/
+
+# Installshield output folder
+[Ee]xpress/
+
+# DocProject is a documentation generator add-in
+DocProject/buildhelp/
+DocProject/Help/*.HxT
+DocProject/Help/*.HxC
+DocProject/Help/*.hhc
+DocProject/Help/*.hhk
+DocProject/Help/*.hhp
+DocProject/Help/Html2
+DocProject/Help/html
+
+# Click-Once directory
+publish/
+
+# Publish Web Output
+*.[Pp]ublish.xml
+*.azurePubxml
+# Note: Comment the next line if you want to checkin your web deploy settings,
+# but database connection strings (with potential passwords) will be unencrypted
+*.pubxml
+*.publishproj
+
+# Microsoft Azure Web App publish settings. Comment the next line if you want to
+# checkin your Azure Web App publish settings, but sensitive information contained
+# in these scripts will be unencrypted
+PublishScripts/
+
+# NuGet Packages
+*.nupkg
+# NuGet Symbol Packages
+*.snupkg
+# The packages folder can be ignored because of Package Restore
+**/[Pp]ackages/*
+# except build/, which is used as an MSBuild target.
+!**/[Pp]ackages/build/
+# Uncomment if necessary however generally it will be regenerated when needed
+#!**/[Pp]ackages/repositories.config
+# NuGet v3's project.json files produces more ignorable files
+*.nuget.props
+*.nuget.targets
+
+# Microsoft Azure Build Output
+csx/
+*.build.csdef
+
+# Microsoft Azure Emulator
+ecf/
+rcf/
+
+# Windows Store app package directories and files
+AppPackages/
+BundleArtifacts/
+Package.StoreAssociation.xml
+_pkginfo.txt
+*.appx
+*.appxbundle
+*.appxupload
+
+# Visual Studio cache files
+# files ending in .cache can be ignored
+*.[Cc]ache
+# but keep track of directories ending in .cache
+!?*.[Cc]ache/
+
+# Others
+ClientBin/
+~$*
+*~
+*.dbmdl
+*.dbproj.schemaview
+*.jfm
+*.pfx
+*.publishsettings
+orleans.codegen.cs
+
+# Including strong name files can present a security risk
+# (https://github.com/github/gitignore/pull/2483#issue-259490424)
+#*.snk
+
+# Since there are multiple workflows, uncomment next line to ignore bower_components
+# (https://github.com/github/gitignore/pull/1529#issuecomment-104372622)
+#bower_components/
+
+# RIA/Silverlight projects
+Generated_Code/
+
+# Backup & report files from converting an old project file
+# to a newer Visual Studio version. Backup files are not needed,
+# because we have git ;-)
+_UpgradeReport_Files/
+Backup*/
+UpgradeLog*.XML
+UpgradeLog*.htm
+ServiceFabricBackup/
+*.rptproj.bak
+
+# SQL Server files
+*.mdf
+*.ldf
+*.ndf
+
+# Business Intelligence projects
+*.rdl.data
+*.bim.layout
+*.bim_*.settings
+*.rptproj.rsuser
+*- [Bb]ackup.rdl
+*- [Bb]ackup ([0-9]).rdl
+*- [Bb]ackup ([0-9][0-9]).rdl
+
+# Microsoft Fakes
+FakesAssemblies/
+
+# GhostDoc plugin setting file
+*.GhostDoc.xml
+
+# Node.js Tools for Visual Studio
+.ntvs_analysis.dat
+node_modules/
+
+# Visual Studio 6 build log
+*.plg
+
+# Visual Studio 6 workspace options file
+*.opt
+
+# Visual Studio 6 auto-generated workspace file (contains which files were open etc.)
+*.vbw
+
+# Visual Studio 6 auto-generated project file (contains which files were open etc.)
+*.vbp
+
+# Visual Studio 6 workspace and project file (working project files containing files to include in project)
+*.dsw
+*.dsp
+
+# Visual Studio 6 technical files
+*.ncb
+*.aps
+
+# Visual Studio LightSwitch build output
+**/*.HTMLClient/GeneratedArtifacts
+**/*.DesktopClient/GeneratedArtifacts
+**/*.DesktopClient/ModelManifest.xml
+**/*.Server/GeneratedArtifacts
+**/*.Server/ModelManifest.xml
+_Pvt_Extensions
+
+# Paket dependency manager
+.paket/paket.exe
+paket-files/
+
+# FAKE - F# Make
+.fake/
+
+# CodeRush personal settings
+.cr/personal
+
+# Python Tools for Visual Studio (PTVS)
+__pycache__/
+*.pyc
+
+# Cake - Uncomment if you are using it
+# tools/**
+# !tools/packages.config
+
+# Tabs Studio
+*.tss
+
+# Telerik's JustMock configuration file
+*.jmconfig
+
+# BizTalk build output
+*.btp.cs
+*.btm.cs
+*.odx.cs
+*.xsd.cs
+
+# OpenCover UI analysis results
+OpenCover/
+
+# Azure Stream Analytics local run output
+ASALocalRun/
+
+# MSBuild Binary and Structured Log
+*.binlog
+
+# NVidia Nsight GPU debugger configuration file
+*.nvuser
+
+# MFractors (Xamarin productivity tool) working folder
+.mfractor/
+
+# Local History for Visual Studio
+.localhistory/
+
+# Visual Studio History (VSHistory) files
+.vshistory/
+
+# BeatPulse healthcheck temp database
+healthchecksdb
+
+# Backup folder for Package Reference Convert tool in Visual Studio 2017
+MigrationBackup/
+
+# Ionide (cross platform F# VS Code tools) working folder
+.ionide/
+
+# Fody - auto-generated XML schema
+FodyWeavers.xsd
+
+# VS Code files for those working on multiple tools
+.vscode/*
+!.vscode/settings.json
+!.vscode/tasks.json
+!.vscode/launch.json
+!.vscode/extensions.json
+*.code-workspace
+
+# Local History for Visual Studio Code
+.history/
+
+# Windows Installer files from build outputs
+*.cab
+*.msi
+*.msix
+*.msm
+*.msp
+
+# JetBrains Rider
+*.sln.iml
+
+# MoGe
+/data
+/download
+/extract
+/view_point_cloud
+/view_depth_map
+/blobcache
+/snapshot
+/reference_embeddings
+/.msra_intern_s_toolkit
+/debug
+/workspace
+/mlruns
+/infer_output
+/video_output
+/eval_output
+/.blobcache
+/test_images
+/test_videos
+/vis
+/videos
+/raid
+/blobmnt
+/eval_dump
+/pretrained
+/.gradio
submodules/MoGe/CHANGELOG.md
ADDED
@@ -0,0 +1,15 @@
+## 2024-11-28
+### Added
+- Supported user-provided camera FOV. See [scripts/infer.py](scripts/infer.py) --fov_x.
+  - Related issues: [#25](https://github.com/microsoft/MoGe/issues/25) and [#24](https://github.com/microsoft/MoGe/issues/24).
+- Added inference scripts for panorama images. See [scripts/infer_panorama.py](scripts/infer_panorama.py).
+  - Related issue: [#19](https://github.com/microsoft/MoGe/issues/19).
+
+### Fixed
+- Suppressed unnecessary numpy runtime warnings.
+- Specified recommended versions of requirements.
+  - Related issue: [#21](https://github.com/microsoft/MoGe/issues/21).
+
+### Changed
+- Moved `app.py` and `infer.py` to [scripts/](scripts/)
+- Improved edge removal.
submodules/MoGe/CODE_OF_CONDUCT.md
ADDED
@@ -0,0 +1,9 @@
+# Microsoft Open Source Code of Conduct
+
+This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
+
+Resources:
+
+- [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/)
+- [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/)
+- Contact [[email protected]](mailto:[email protected]) with questions or concerns
submodules/MoGe/LICENSE
ADDED
@@ -0,0 +1,224 @@
+MIT License
+
+Copyright (c) Microsoft Corporation.
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE
+
+
+Apache License
+Version 2.0, January 2004
+http://www.apache.org/licenses/
+
+TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+1. Definitions.
+
+"License" shall mean the terms and conditions for use, reproduction,
+and distribution as defined by Sections 1 through 9 of this document.
+
+"Licensor" shall mean the copyright owner or entity authorized by
+the copyright owner that is granting the License.
+
+"Legal Entity" shall mean the union of the acting entity and all
+other entities that control, are controlled by, or are under common
+control with that entity. For the purposes of this definition,
+"control" means (i) the power, direct or indirect, to cause the
+direction or management of such entity, whether by contract or
+otherwise, or (ii) ownership of fifty percent (50%) or more of the
+outstanding shares, or (iii) beneficial ownership of such entity.
+
+"You" (or "Your") shall mean an individual or Legal Entity
+exercising permissions granted by this License.
+
+"Source" form shall mean the preferred form for making modifications,
+including but not limited to software source code, documentation
+source, and configuration files.
+
+"Object" form shall mean any form resulting from mechanical
+transformation or translation of a Source form, including but
+not limited to compiled object code, generated documentation,
+and conversions to other media types.
+
+"Work" shall mean the work of authorship, whether in Source or
+Object form, made available under the License, as indicated by a
+copyright notice that is included in or attached to the work
+(an example is provided in the Appendix below).
+
+"Derivative Works" shall mean any work, whether in Source or Object
+form, that is based on (or derived from) the Work and for which the
+editorial revisions, annotations, elaborations, or other modifications
+represent, as a whole, an original work of authorship. For the purposes
+of this License, Derivative Works shall not include works that remain
+separable from, or merely link (or bind by name) to the interfaces of,
+the Work and Derivative Works thereof.
+
+"Contribution" shall mean any work of authorship, including
+the original version of the Work and any modifications or additions
+to that Work or Derivative Works thereof, that is intentionally
+submitted to Licensor for inclusion in the Work by the copyright owner
+or by an individual or Legal Entity authorized to submit on behalf of
+the copyright owner. For the purposes of this definition, "submitted"
+means any form of electronic, verbal, or written communication sent
+to the Licensor or its representatives, including but not limited to
+communication on electronic mailing lists, source code control systems,
+and issue tracking systems that are managed by, or on behalf of, the
+Licensor for the purpose of discussing and improving the Work, but
+excluding communication that is conspicuously marked or otherwise
+designated in writing by the copyright owner as "Not a Contribution."
+
+"Contributor" shall mean Licensor and any individual or Legal Entity
+on behalf of whom a Contribution has been received by Licensor and
+subsequently incorporated within the Work.
+
+2. Grant of Copyright License. Subject to the terms and conditions of
+this License, each Contributor hereby grants to You a perpetual,
+worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+copyright license to reproduce, prepare Derivative Works of,
+publicly display, publicly perform, sublicense, and distribute the
+Work and such Derivative Works in Source or Object form.
+
+3. Grant of Patent License. Subject to the terms and conditions of
+this License, each Contributor hereby grants to You a perpetual,
+worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+(except as stated in this section) patent license to make, have made,
+use, offer to sell, sell, import, and otherwise transfer the Work,
+where such license applies only to those patent claims licensable
+by such Contributor that are necessarily infringed by their
+Contribution(s) alone or by combination of their Contribution(s)
+with the Work to which such Contribution(s) was submitted. If You
+institute patent litigation against any entity (including a
+cross-claim or counterclaim in a lawsuit) alleging that the Work
+or a Contribution incorporated within the Work constitutes direct
+or contributory patent infringement, then any patent licenses
+granted to You under this License for that Work shall terminate
+as of the date such litigation is filed.
+
+4. Redistribution. You may reproduce and distribute copies of the
+Work or Derivative Works thereof in any medium, with or without
+modifications, and in Source or Object form, provided that You
+meet the following conditions:
+
+(a) You must give any other recipients of the Work or
+Derivative Works a copy of this License; and
+
+(b) You must cause any modified files to carry prominent notices
+stating that You changed the files; and
+
+(c) You must retain, in the Source form of any Derivative Works
+that You distribute, all copyright, patent, trademark, and
+attribution notices from the Source form of the Work,
+excluding those notices that do not pertain to any part of
+the Derivative Works; and
+
+(d) If the Work includes a "NOTICE" text file as part of its
+distribution, then any Derivative Works that You distribute must
+include a readable copy of the attribution notices contained
+within such NOTICE file, excluding those notices that do not
+pertain to any part of the Derivative Works, in at least one
+of the following places: within a NOTICE text file distributed
+as part of the Derivative Works; within the Source form or
+documentation, if provided along with the Derivative Works; or,
+within a display generated by the Derivative Works, if and
+wherever such third-party notices normally appear. The contents
+of the NOTICE file are for informational purposes only and
+do not modify the License. You may add Your own attribution
+notices within Derivative Works that You distribute, alongside
+or as an addendum to the NOTICE text from the Work, provided
+that such additional attribution notices cannot be construed
+as modifying the License.
+
+You may add Your own copyright statement to Your modifications and
+may provide additional or different license terms and conditions
+for use, reproduction, or distribution of Your modifications, or
+for any such Derivative Works as a whole, provided Your use,
+reproduction, and distribution of the Work otherwise complies with
+the conditions stated in this License.
+
+5. Submission of Contributions. Unless You explicitly state otherwise,
+any Contribution intentionally submitted for inclusion in the Work
+by You to the Licensor shall be under the terms and conditions of
+this License, without any additional terms or conditions.
+Notwithstanding the above, nothing herein shall supersede or modify
+the terms of any separate license agreement you may have executed
+with Licensor regarding such Contributions.
+
+6. Trademarks. This License does not grant permission to use the trade
+names, trademarks, service marks, or product names of the Licensor,
+except as required for reasonable and customary use in describing the
+origin of the Work and reproducing the content of the NOTICE file.
+
+7. Disclaimer of Warranty. Unless required by applicable law or
+agreed to in writing, Licensor provides the Work (and each
+Contributor provides its Contributions) on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+implied, including, without limitation, any warranties or conditions
+of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+PARTICULAR PURPOSE. You are solely responsible for determining the
+appropriateness of using or redistributing the Work and assume any
+risks associated with Your exercise of permissions under this License.
+
+8. Limitation of Liability. In no event and under no legal theory,
+whether in tort (including negligence), contract, or otherwise,
+unless required by applicable law (such as deliberate and grossly
+negligent acts) or agreed to in writing, shall any Contributor be
+liable to You for damages, including any direct, indirect, special,
+incidental, or consequential damages of any character arising as a
+result of this License or out of the use or inability to use the
+Work (including but not limited to damages for loss of goodwill,
+work stoppage, computer failure or malfunction, or any and all
+other commercial damages or losses), even if such Contributor
+has been advised of the possibility of such damages.
+
+9. Accepting Warranty or Additional Liability. While redistributing
+the Work or Derivative Works thereof, You may choose to offer,
+and charge a fee for, acceptance of support, warranty, indemnity,
+or other liability obligations and/or rights consistent with this
+License. However, in accepting such obligations, You may act only
+on Your own behalf and on Your sole responsibility, not on behalf
+of any other Contributor, and only if You agree to indemnify,
+defend, and hold each Contributor harmless for any liability
+incurred by, or claims asserted against, such Contributor by reason
+of your accepting any such warranty or additional liability.
+
+END OF TERMS AND CONDITIONS
+
+APPENDIX: How to apply the Apache License to your work.
+
+To apply the Apache License to your work, attach the following
+boilerplate notice, with the fields enclosed by brackets "[]"
+replaced with your own identifying information. (Don't include
+the brackets!) The text should be enclosed in the appropriate
+comment syntax for the file format. We also recommend that a
+file or class name and description of purpose be included on the
+same "printed page" as the copyright notice for easier
+identification within third-party archives.
+
+Copyright [yyyy] [name of copyright owner]
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.