Spaces: EDGS · Running on Zero
Olga committed · Commit 5f9d349
Parent(s): Initial commit
This view is limited to 50 files because it contains too many changes. See raw diff.
- .gitattributes +37 -0
- .gitignore +159 -0
- LICENSE.txt +12 -0
- README.md +14 -0
- app.py +434 -0
- assets/examples/video_bakery.mp4 +3 -0
- assets/examples/video_flowers.mp4 +3 -0
- assets/examples/video_fruits.mp4 +3 -0
- assets/examples/video_plant.mp4 +3 -0
- assets/examples/video_salad.mp4 +3 -0
- assets/examples/video_tram.mp4 +3 -0
- assets/examples/video_tulips.mp4 +3 -0
- assets/video_fruits_ours_full.mp4 +3 -0
- configs/gs/base.yaml +51 -0
- configs/train.yaml +38 -0
- requirements.txt +32 -0
- source/EDGS.code-workspace +11 -0
- source/__init__.py +0 -0
- source/corr_init.py +682 -0
- source/corr_init_new.py +904 -0
- source/data_utils.py +28 -0
- source/losses.py +100 -0
- source/networks.py +52 -0
- source/timer.py +24 -0
- source/trainer.py +262 -0
- source/utils_aux.py +92 -0
- source/utils_preprocess.py +334 -0
- source/vggt_to_colmap.py +598 -0
- source/visualization.py +1072 -0
- submodules/RoMa/.gitignore +11 -0
- submodules/RoMa/LICENSE +21 -0
- submodules/RoMa/README.md +123 -0
- submodules/RoMa/data/.gitignore +2 -0
- submodules/RoMa/demo/demo_3D_effect.py +47 -0
- submodules/RoMa/demo/demo_fundamental.py +34 -0
- submodules/RoMa/demo/demo_match.py +50 -0
- submodules/RoMa/demo/demo_match_opencv_sift.py +43 -0
- submodules/RoMa/demo/demo_match_tiny.py +77 -0
- submodules/RoMa/demo/gif/.gitignore +2 -0
- submodules/RoMa/experiments/eval_roma_outdoor.py +57 -0
- submodules/RoMa/experiments/eval_tiny_roma_v1_outdoor.py +84 -0
- submodules/RoMa/experiments/roma_indoor.py +320 -0
- submodules/RoMa/experiments/train_roma_outdoor.py +307 -0
- submodules/RoMa/experiments/train_tiny_roma_v1_outdoor.py +498 -0
- submodules/RoMa/requirements.txt +14 -0
- submodules/RoMa/romatch/__init__.py +8 -0
- submodules/RoMa/romatch/benchmarks/__init__.py +6 -0
- submodules/RoMa/romatch/benchmarks/hpatches_sequences_homog_benchmark.py +113 -0
- submodules/RoMa/romatch/benchmarks/megadepth_dense_benchmark.py +106 -0
- submodules/RoMa/romatch/benchmarks/megadepth_pose_estimation_benchmark.py +118 -0
.gitattributes
ADDED
@@ -0,0 +1,37 @@
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tar filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
*.whl filter=lfs diff=lfs merge=lfs -text
*.mp4 filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1,159 @@
# Local notebooks used for local debug
notebooks_local/
wandb/

# Gradio files

served_files/

# hidden folders
.*/**

# The rest is taken from https://github.com/Anttwo/SuGaR
*.pt
*.pth
output*
*.slurm
*.pyc
*.ply
*.obj
sugar_tests.ipynb
sugar_sh_tests.ipynb

# To remove
frosting*
extract_shell.py
train_frosting_refined.py
train_frosting.py
run_frosting_viewer.py
slurm_a100.sh

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/
learnableearthparser/fast_sampler/_sampler.c
LICENSE.txt
ADDED
@@ -0,0 +1,12 @@
Copyright 2025, Dmytro Kotovenko, Olga Grebenkova, Björn Ommer

Redistribution and use in source and binary forms, with or without modification, are permitted for non-commercial academic research and/or non-commercial personal use only provided that the following conditions are met:

1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.

2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.

3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.

Any use of this software beyond the above specified conditions requires a separate license. Please contact the copyright holders to discuss license terms.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
README.md
ADDED
@@ -0,0 +1,14 @@
---
title: EDGS
emoji: 🎥
colorFrom: pink
colorTo: blue
sdk: gradio
sdk_version: 5.25.2
app_file: app.py
pinned: false
python_version: "3.10"
license: other
---

Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py
ADDED
@@ -0,0 +1,434 @@
import spaces
import torch
import os
import shutil
import tempfile
import argparse
import gradio as gr
import sys
import io
import contextlib  # needed for redirect_stdout in capture_logs below
import subprocess
from PIL import Image
import numpy as np
from hydra import initialize, compose
import hydra
from omegaconf import OmegaConf
import time


def install_submodules():
    subprocess.check_call(['pip', 'install', './submodules/RoMa'])


STATIC_FILE_SERVING_FOLDER = "./served_files"
MODEL_PATH = None
os.makedirs(STATIC_FILE_SERVING_FOLDER, exist_ok=True)

trainer = None


class Tee(io.TextIOBase):

    def __init__(self, *streams):
        self.streams = streams

    def write(self, data):
        for stream in self.streams:
            stream.write(data)
        return len(data)

    def flush(self):
        for stream in self.streams:
            stream.flush()


def capture_logs(func, *args, **kwargs):
    log_capture_string = io.StringIO()
    tee = Tee(sys.__stdout__, log_capture_string)

    with contextlib.redirect_stdout(tee):
        result = func(*args, **kwargs)

    return result, log_capture_string.getvalue()


@spaces.GPU(duration=350)
# Training Pipeline
def run_training_pipeline(scene_dir,
                          num_ref_views=16,
                          num_corrs_per_view=20000,
                          num_steps=1_000,
                          mode_toggle="Ours (EDGS)"):
    with initialize(config_path="./configs", version_base="1.1"):
        cfg = compose(config_name="train")

    scene_name = os.path.basename(scene_dir)
    model_output_dir = f"./outputs/{scene_name}_trained"

    cfg.wandb.mode = "disabled"
    cfg.gs.dataset.model_path = model_output_dir
    cfg.gs.dataset.source_path = scene_dir
    cfg.gs.dataset.images = "images"

    cfg.gs.opt.TEST_CAM_IDX_TO_LOG = 12
    cfg.train.gs_epochs = 30000

    if mode_toggle == "Ours (EDGS)":
        cfg.gs.opt.opacity_reset_interval = 1_000_000
        cfg.train.reduce_opacity = True
        cfg.train.no_densify = True
        cfg.train.max_lr = True

        cfg.init_wC.use = True
        cfg.init_wC.matches_per_ref = num_corrs_per_view
        cfg.init_wC.nns_per_ref = 1
        cfg.init_wC.num_refs = num_ref_views
        cfg.init_wC.add_SfM_init = False
        cfg.init_wC.scaling_factor = 0.00077 * 2.

    set_seed(cfg.seed)
    os.makedirs(cfg.gs.dataset.model_path, exist_ok=True)

    global trainer
    global MODEL_PATH
    generator3dgs = hydra.utils.instantiate(cfg.gs, do_train_test_split=False)
    trainer = EDGSTrainer(GS=generator3dgs, training_config=cfg.gs.opt, device=cfg.device, log_wandb=cfg.wandb.mode != 'disabled')

    # Disable evaluation and saving
    trainer.saving_iterations = []
    trainer.evaluate_iterations = []

    # Initialize
    trainer.timer.start()
    start_time = time.time()
    trainer.init_with_corr(cfg.init_wC, roma_model=roma_model)
    time_for_init = time.time() - start_time

    viewpoint_cams = trainer.GS.scene.getTrainCameras()
    path_cameras = generate_fully_smooth_cameras_with_tsp(existing_cameras=viewpoint_cams,
                                                          n_selected=6,
                                                          n_points_per_segment=30,
                                                          closed=False)
    path_cameras = path_cameras + path_cameras[::-1]

    path_renderings = []
    idx = 0
    # Visualize after init
    for _ in range(120):
        with torch.no_grad():
            viewpoint_cam = path_cameras[idx]
            idx = (idx + 1) % len(path_cameras)
            render_pkg = trainer.GS(viewpoint_cam)
            image = render_pkg["render"]
            image_np = np.clip(image.detach().cpu().numpy().transpose(1, 2, 0), 0, 1)
            image_np = (image_np * 255).astype(np.uint8)
            path_renderings.append(put_text_on_image(img=image_np,
                                                     text=f"Init stage.\nTime:{time_for_init:.3f}s. "))
    path_renderings = path_renderings + [put_text_on_image(img=image_np, text=f"Start fitting.\nTime:{time_for_init:.3f}s. ")] * 30

    # Train and save visualizations during training.
    start_time = time.time()
    for _ in range(int(num_steps // 10)):
        with torch.no_grad():
            viewpoint_cam = path_cameras[idx]
            idx = (idx + 1) % len(path_cameras)
            render_pkg = trainer.GS(viewpoint_cam)
            image = render_pkg["render"]
            image_np = np.clip(image.detach().cpu().numpy().transpose(1, 2, 0), 0, 1)
            image_np = (image_np * 255).astype(np.uint8)
            path_renderings.append(put_text_on_image(
                img=image_np,
                text=f"Fitting stage.\nTime:{time_for_init + time.time()-start_time:.3f}s. "))

        cfg.train.gs_epochs = 10
        trainer.train(cfg.train)
        print(f"Time elapsed: {(time_for_init + time.time()-start_time):.2f}s.")
        # if (cfg.init_wC.use == False) and (time_for_init + time.time()-start_time) > 60:
        #     break
    final_time = time.time()

    # Add static frame. To highlight we're done
    path_renderings += [put_text_on_image(
        img=image_np, text=f"Done.\nTime:{time_for_init + final_time - start_time:.3f}s. ")] * 30
    # Final rendering at the end.
    for _ in range(len(path_cameras)):
        with torch.no_grad():
            viewpoint_cam = path_cameras[idx]
            idx = (idx + 1) % len(path_cameras)
            render_pkg = trainer.GS(viewpoint_cam)
            image = render_pkg["render"]
            image_np = np.clip(image.detach().cpu().numpy().transpose(1, 2, 0), 0, 1)
            image_np = (image_np * 255).astype(np.uint8)
            path_renderings.append(put_text_on_image(img=image_np,
                                                     text=f"Final result.\nTime:{time_for_init + final_time - start_time:.3f}s. "))

    trainer.save_model()
    final_video_path = os.path.join(STATIC_FILE_SERVING_FOLDER, f"{scene_name}_final.mp4")
    save_numpy_frames_as_mp4(frames=path_renderings, output_path=final_video_path, fps=30, center_crop=0.85)
    MODEL_PATH = cfg.gs.dataset.model_path
    ply_path = os.path.join(cfg.gs.dataset.model_path, f"point_cloud/iteration_{trainer.gs_step}/point_cloud.ply")
    shutil.copy(ply_path, os.path.join(STATIC_FILE_SERVING_FOLDER, "point_cloud_final.ply"))

    return final_video_path, ply_path


# Gradio Interface
def gradio_interface(input_path, num_ref_views, num_corrs, num_steps):
    images, scene_dir = run_full_pipeline(input_path, num_ref_views, num_corrs, max_size=1024)
    shutil.copytree(scene_dir, STATIC_FILE_SERVING_FOLDER + '/scene_colmaped', dirs_exist_ok=True)
    (final_video_path, ply_path), log_output = capture_logs(run_training_pipeline,
                                                            scene_dir,
                                                            num_ref_views,
                                                            num_corrs,
                                                            num_steps
                                                            )
    images_rgb = [img[:, :, ::-1] for img in images]
    return images_rgb, final_video_path, scene_dir, ply_path, log_output


# Render Functions
@spaces.GPU(duration=60)
def render_all_views(scene_dir):
    viewpoint_cams = trainer.GS.scene.getTrainCameras()
    path_cameras = generate_fully_smooth_cameras_with_tsp(existing_cameras=viewpoint_cams,
                                                          n_selected=8,
                                                          n_points_per_segment=60,
                                                          closed=False)
    path_cameras = path_cameras + path_cameras[::-1]

    path_renderings = []
    with torch.no_grad():
        for viewpoint_cam in path_cameras:
            render_pkg = trainer.GS(viewpoint_cam)
            image = render_pkg["render"]
            image_np = np.clip(image.detach().cpu().numpy().transpose(1, 2, 0), 0, 1)
            image_np = (image_np * 255).astype(np.uint8)
            path_renderings.append(image_np)
    save_numpy_frames_as_mp4(frames=path_renderings,
                             output_path=os.path.join(STATIC_FILE_SERVING_FOLDER, "render_all_views.mp4"),
                             fps=30,
                             center_crop=0.85)

    return os.path.join(STATIC_FILE_SERVING_FOLDER, "render_all_views.mp4")


@spaces.GPU(duration=60)
def render_circular_path(scene_dir):
    viewpoint_cams = trainer.GS.scene.getTrainCameras()
    path_cameras = generate_circular_camera_path(existing_cameras=viewpoint_cams,
                                                 N=240,
                                                 radius_scale=0.65,
                                                 d=0)

    path_renderings = []
    with torch.no_grad():
        for viewpoint_cam in path_cameras:
            render_pkg = trainer.GS(viewpoint_cam)
            image = render_pkg["render"]
            image_np = np.clip(image.detach().cpu().numpy().transpose(1, 2, 0), 0, 1)
            image_np = (image_np * 255).astype(np.uint8)
            path_renderings.append(image_np)
    save_numpy_frames_as_mp4(frames=path_renderings,
                             output_path=os.path.join(STATIC_FILE_SERVING_FOLDER, "render_circular_path.mp4"),
                             fps=30,
                             center_crop=0.85)

    return os.path.join(STATIC_FILE_SERVING_FOLDER, "render_circular_path.mp4")


# Download Functions
def download_cameras():
    path = os.path.join(MODEL_PATH, "cameras.json")
    return f"[📥 Download Cameras.json](file={path})"


def download_model():
    path = os.path.join(STATIC_FILE_SERVING_FOLDER, "point_cloud_final.ply")
    return f"[📥 Download Pretrained Model (.ply)](file={path})"


# Full pipeline helpers
def run_full_pipeline(input_path, num_ref_views, num_corrs, max_size=1024):
    tmpdirname = tempfile.mkdtemp()
    scene_dir = os.path.join(tmpdirname, "scene")
    os.makedirs(scene_dir, exist_ok=True)

    selected_frames = process_input(input_path, num_ref_views, scene_dir, max_size)
    run_colmap_on_scene(scene_dir)

    return selected_frames, scene_dir


# Preprocess Input
def process_input(input_path, num_ref_views, output_dir, max_size=1024):
    if isinstance(input_path, (str, os.PathLike)):
        if os.path.isdir(input_path):
            frames = []
            for img_file in sorted(os.listdir(input_path)):
                if img_file.lower().endswith(('jpg', 'jpeg', 'png')):
                    # Read from the input directory (the listing above iterates over input_path).
                    img = Image.open(os.path.join(input_path, img_file)).convert('RGB')
                    img.thumbnail((1024, 1024))
                    frames.append(np.array(img))
        else:
            frames = read_video_frames(video_input=input_path, max_size=max_size)
    else:
        frames = read_video_frames(video_input=input_path, max_size=max_size)

    frames_scores = preprocess_frames(frames)
    selected_frames_indices = select_optimal_frames(scores=frames_scores, k=min(num_ref_views, len(frames)))
    selected_frames = [frames[frame_idx] for frame_idx in selected_frames_indices]

    save_frames_to_scene_dir(frames=selected_frames, scene_dir=output_dir)
    return selected_frames


@spaces.GPU(duration=150)
def preprocess_input(input_path, num_ref_views, max_size=1024):
    tmpdirname = tempfile.mkdtemp()
    scene_dir = os.path.join(tmpdirname, "scene")
    os.makedirs(scene_dir, exist_ok=True)
    selected_frames = process_input(input_path, num_ref_views, scene_dir, max_size)
    run_colmap_on_scene(scene_dir)
    return selected_frames, scene_dir


def start_training(scene_dir, num_ref_views, num_corrs, num_steps):
    return capture_logs(run_training_pipeline, scene_dir, num_ref_views, num_corrs, num_steps)


# Gradio App
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column(scale=6):
            gr.Markdown("""
## <span style='font-size: 20px;'>📄 EDGS: Eliminating Densification for Efficient Convergence of 3DGS</span>
🔗 <a href='https://compvis.github.io/EDGS' target='_blank'>Project Page</a>
""", elem_id="header")

            gr.Markdown("""
### <span style='font-size: 22px;'>🛠️ How to Use This Demo</span>

1. Upload a **front-facing video** or **a folder of images** of a **static** scene.
2. Use the sliders to configure the number of reference views, correspondences, and optimization steps.
3. First press **📸 Preprocess Input** to extract frames (for videos) and run COLMAP on them.
4. Then click **🚀 Start Reconstruction** to launch the reconstruction pipeline.
5. Watch the training visualization and explore the 3D model.
‼️ **If you see nothing in the 3D model viewer**, try rotating or zooming — sometimes the initial camera orientation is off.


✅ Best for scenes with small camera motion.
❗ For full 360° or large-scale scenes, we recommend the Colab version (see project page).
""", elem_id="quickstart")
    scene_dir_state = gr.State()
    ply_model_state = gr.State()
    with gr.Row():
        with gr.Column(scale=2):
            input_file = gr.File(label="Upload Video or Images",
                                 file_types=[".mp4", ".avi", ".mov", ".png", ".jpg", ".jpeg"],
                                 file_count="multiple")
            gr.Examples(
                examples=[
                    [["assets/examples/video_bakery.mp4"]],
                    [["assets/examples/video_flowers.mp4"]],
                    [["assets/examples/video_fruits.mp4"]],
                    [["assets/examples/video_plant.mp4"]],
                    [["assets/examples/video_salad.mp4"]],
                    [["assets/examples/video_tram.mp4"]],
                    [["assets/examples/video_tulips.mp4"]]
                ],
                inputs=[input_file],
                label="🎞️ Alternatively, try an Example Video",
                examples_per_page=4
            )
            ref_slider = gr.Slider(4, 32, value=16, step=1, label="Number of Reference Views")
            corr_slider = gr.Slider(5000, 30000, value=20000, step=1000, label="Correspondences per Reference View")
            fit_steps_slider = gr.Slider(100, 5000, value=400, step=100, label="Number of optimization steps")
            preprocess_button = gr.Button("📸 Preprocess Input")
            start_button = gr.Button("🚀 Start Reconstruction", interactive=False)
            gallery = gr.Gallery(label="Selected Reference Views", columns=4, height=300)

        with gr.Column(scale=3):
            gr.Markdown("### 🏋️ Training Visualization")
            video_output = gr.Video(label="Training Video", autoplay=True)
            render_all_views_button = gr.Button("🎥 Render All-Views Path")
            render_circular_path_button = gr.Button("🎥 Render Circular Path")
            rendered_video_output = gr.Video(label="Rendered Video", autoplay=True)
        with gr.Column(scale=5):
            gr.Markdown("### 🌐 Final 3D Model")
            model3d_viewer = gr.Model3D(label="3D Model Viewer")

    gr.Markdown("### 📦 Output Files")
    with gr.Row(height=50):
        with gr.Column():
            #gr.Markdown(value=f"[📥 Download .ply](file/point_cloud_final.ply)")
            download_cameras_button = gr.Button("📥 Download Cameras.json")
            download_cameras_file = gr.File(label="📄 Cameras.json")
        with gr.Column():
            download_model_button = gr.Button("📥 Download Pretrained Model (.ply)")
            download_model_file = gr.File(label="📄 Pretrained Model (.ply)")

    log_output_box = gr.Textbox(label="🖥️ Log", lines=10, interactive=False)

    def on_preprocess_click(input_file, num_ref_views):
        images, scene_dir = preprocess_input(input_file, num_ref_views)
        return gr.update(value=[x[..., ::-1] for x in images]), scene_dir, gr.update(interactive=True)

    def on_start_click(scene_dir, num_ref_views, num_corrs, num_steps):
        (video_path, ply_path), logs = start_training(scene_dir, num_ref_views, num_corrs, num_steps)
        return video_path, ply_path, logs

    preprocess_button.click(
        fn=on_preprocess_click,
        inputs=[input_file, ref_slider],
        outputs=[gallery, scene_dir_state, start_button]
    )

    start_button.click(
        fn=on_start_click,
        inputs=[scene_dir_state, ref_slider, corr_slider, fit_steps_slider],
        outputs=[video_output, model3d_viewer, log_output_box]
    )

    render_all_views_button.click(fn=render_all_views, inputs=[scene_dir_state], outputs=[rendered_video_output])
    render_circular_path_button.click(fn=render_circular_path, inputs=[scene_dir_state], outputs=[rendered_video_output])

    download_cameras_button.click(fn=lambda: os.path.join(MODEL_PATH, "cameras.json"), inputs=[], outputs=[download_cameras_file])
    download_model_button.click(fn=lambda: os.path.join(STATIC_FILE_SERVING_FOLDER, "point_cloud_final.ply"), inputs=[], outputs=[download_model_file])


    gr.Markdown("""
---
### <span style='font-size: 20px;'>📖 Detailed Overview</span>

If you uploaded a video, it will be automatically cut into a smaller number of frames (default: 16).

The model pipeline:
1. 🧠 Runs PyCOLMAP to estimate camera intrinsics & poses (~3–7 seconds for <16 images).
2. 🔁 Computes 2D-2D correspondences between views. More correspondences generally improve quality.
3. 🔧 Optimizes a 3D Gaussian Splatting model for several steps.

### 🎥 Training Visualization
You will see a visualization of the entire training process in the "Training Video" pane.

### 🌀 Rendering & 3D Model
- Render the scene from a circular path of novel views.
- Or from camera views close to the original input.

The 3D model is shown in the right viewer. You can explore it interactively:
- On PC: WASD keys, arrow keys, and mouse clicks
- On mobile: pan and pinch to zoom

🕒 Note: the 3D viewer takes a few extra seconds (~5s) to display after training ends.

---
Preloaded models coming soon. (TODO)
""", elem_id="details")


if __name__ == "__main__":
    install_submodules()
    from source.utils_aux import set_seed
    from source.utils_preprocess import read_video_frames, preprocess_frames, select_optimal_frames, save_frames_to_scene_dir, run_colmap_on_scene
    from source.trainer import EDGSTrainer
    from source.visualization import generate_circular_camera_path, save_numpy_frames_as_mp4, generate_fully_smooth_cameras_with_tsp, put_text_on_image
    # Init RoMa model:
    sys.path.append('../submodules/RoMa')
    from romatch import roma_outdoor, roma_indoor

    roma_model = roma_indoor(device="cpu")
    roma_model = roma_model.to("cuda")
    roma_model.upsample_preds = False
    roma_model.symmetric = False

    demo.launch(share=True)
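The `Tee`/`capture_logs` pair in app.py above is what feeds the 🖥️ Log textbox: stdout is duplicated into an in-memory buffer while the wrapped function runs, and the buffer's contents are returned alongside the result. Below is a minimal, self-contained sketch of the same pattern using only the standard library; the example call at the bottom is hypothetical and just stands in for the Space's training call.

```python
import contextlib
import io
import sys


class Tee(io.TextIOBase):
    """Write-through stream that copies every write to several underlying streams."""

    def __init__(self, *streams):
        self.streams = streams

    def write(self, data):
        for stream in self.streams:
            stream.write(data)
        return len(data)

    def flush(self):
        for stream in self.streams:
            stream.flush()


def capture_logs(func, *args, **kwargs):
    """Run func while mirroring its stdout to both the real console and a string buffer."""
    buffer = io.StringIO()
    with contextlib.redirect_stdout(Tee(sys.__stdout__, buffer)):
        result = func(*args, **kwargs)
    return result, buffer.getvalue()


if __name__ == "__main__":
    # Hypothetical usage: plain print stands in for run_training_pipeline here.
    _, logs = capture_logs(print, "hello from the pipeline")
    assert "hello from the pipeline" in logs
```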
assets/examples/video_bakery.mp4
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b65c813f2ef9637579350e145fdceed333544d933278be5613c6d49468f4eab0
size 6362238
assets/examples/video_flowers.mp4
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c81a7c28e0d59bad38d5a45f6bfa83b80990d1fb78a06f82137e8b57ec38e62b
size 6466943
assets/examples/video_fruits.mp4
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ac3b937e155d1965a314b478e940f6fd93e90371d3bf7d62d3225d840fe8e126
size 3356915
assets/examples/video_plant.mp4
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e67a40e62de9aacf0941d8cf33a2dd08d256fe60d0fc58f60426c251f9d8abd8
size 13023885
assets/examples/video_salad.mp4
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:367d86071a201c124383f0a076ce644b7f86b9c2fbce6aa595c4989ebf259bfb
size 8774427
assets/examples/video_tram.mp4
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:298c297155d3f52edcffcb6b6f9910c992b98ecfb93cfaf8fb64fb340aba1dae
size 4697915
assets/examples/video_tulips.mp4
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c19942319b8f7e33bb03cb4a39b11797fe38572bf7157af37471b7c8573fb495
size 7298210
assets/video_fruits_ours_full.mp4
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:9c5b113566d3a083b81360b549ed89f70d5e81739f83e182518f6906811311a2
size 14839197
configs/gs/base.yaml
ADDED
@@ -0,0 +1,51 @@
_target_: source.networks.Warper3DGS

verbose: True
viewpoint_stack: !!null
sh_degree: 3

opt:
  iterations: 30000
  position_lr_init: 0.00016
  position_lr_final: 1.6e-06
  position_lr_delay_mult: 0.01
  position_lr_max_steps: 30000
  feature_lr: 0.0025
  opacity_lr: 0.025
  scaling_lr: 0.005
  rotation_lr: 0.001
  percent_dense: 0.01
  lambda_dssim: 0.2
  densification_interval: 100
  opacity_reset_interval: 30000
  densify_from_iter: 500
  densify_until_iter: 15000
  densify_grad_threshold: 0.0002
  random_background: false
  save_iterations: [3000, 7000, 15000, 30000]
  batch_size: 64
  exposure_lr_init: 0.01
  exposure_lr_final: 0.0001
  exposure_lr_delay_steps: 0
  exposure_lr_delay_mult: 0.0

  TRAIN_CAM_IDX_TO_LOG: 50
  TEST_CAM_IDX_TO_LOG: 10

pipe:
  convert_SHs_python: False
  compute_cov3D_python: False
  debug: False
  antialiasing: False

dataset:
  densify_until_iter: 15000
  source_path: '' # path to dataset
  model_path: '' # path to logs
  images: images
  resolution: -1
  white_background: false
  data_device: cuda
  eval: false
  depths: ""
  train_test_exp: False
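The `_target_: source.networks.Warper3DGS` field above is what app.py's `hydra.utils.instantiate(cfg.gs, ...)` resolves: Hydra imports the dotted path and calls it with the remaining config fields as keyword arguments. A tiny illustration of that mechanism, using a standard-library callable as a stand-in target rather than the repository's own class:

```python
import hydra
from omegaconf import OmegaConf

# `_target_` names any importable callable; the rest of the node becomes kwargs.
# collections.Counter is only a stand-in for source.networks.Warper3DGS here.
cfg = OmegaConf.create({"_target_": "collections.Counter", "red": 2, "blue": 1})
obj = hydra.utils.instantiate(cfg)
print(type(obj), obj)  # <class 'collections.Counter'> Counter({'red': 2, 'blue': 1})
```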
configs/train.yaml
ADDED
@@ -0,0 +1,38 @@
defaults:
  - gs: base
  - _self_

seed: 228

wandb:
  mode: "online" # "disabled" for no logging
  entity: "3dcorrespondence"
  project: "Adv3DGS"
  group: null
  name: null
  tag: "debug"

train:
  gs_epochs: 0 # number of 3dgs iterations
  reduce_opacity: True
  no_densify: False # if True, the model will not be densified
  max_lr: True

load:
  gs: null # path to 3dgs checkpoint
  gs_step: null # number of iterations, e.g. 7000

device: "cuda:0"
verbose: true

init_wC:
  use: True # use EDGS
  matches_per_ref: 15_000 # number of matches per reference
  num_refs: 180 # number of reference images
  nns_per_ref: 3 # number of nearest neighbors per reference
  scaling_factor: 0.001
  proj_err_tolerance: 0.01
  roma_model: "outdoors" # you can change this to "indoors" or "outdoors"
  add_SfM_init: False
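This is the config tree that app.py composes at runtime before overriding individual fields. A minimal sketch of that composition step (a script assumed to sit at the repository root, mirroring the paths app.py uses), handy for inspecting the merged `train` + `gs: base` configuration without launching the Gradio app:

```python
from hydra import initialize, compose
from omegaconf import OmegaConf

# Compose configs/train.yaml together with its `gs: base` default, as the Space does.
with initialize(config_path="./configs", version_base="1.1"):
    cfg = compose(config_name="train")

# Override a couple of fields the same way run_training_pipeline does.
cfg.wandb.mode = "disabled"
cfg.init_wC.num_refs = 16

print(OmegaConf.to_yaml(cfg.init_wC))
```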
requirements.txt
ADDED
@@ -0,0 +1,32 @@
--extra-index-url https://download.pytorch.org/whl/cu124
torch
torchvision
torchaudio


# Required libraries from pip
Pillow
huggingface_hub
einops
safetensors
sympy==1.13.1
wandb
hydra-core
tqdm
torchmetrics
lpips
matplotlib
rich
plyfile
imageio
imageio-ffmpeg
numpy==1.26.4 # Match conda-installed version
opencv-python
pycolmap
moviepy
plotly
scikit-learn
ffmpeg

https://huggingface.co/spaces/magistrkoljan/test/resolve/main/wheels/diff_gaussian_rasterization-0.0.0-cp310-cp310-linux_x86_64.whl?download=true
https://huggingface.co/spaces/magistrkoljan/test/resolve/main/wheels/simple_knn-0.0.0-cp310-cp310-linux_x86_64.whl?download=true
source/EDGS.code-workspace
ADDED
@@ -0,0 +1,11 @@
{
    "folders": [
        {
            "path": ".."
        },
        {
            "path": "../../../../.."
        }
    ],
    "settings": {}
}
source/__init__.py
ADDED
File without changes
source/corr_init.py
ADDED
@@ -0,0 +1,682 @@
import sys
sys.path.append('../')
sys.path.append("../submodules")
sys.path.append('../submodules/RoMa')

from matplotlib import pyplot as plt
from PIL import Image
import torch
import numpy as np

#from tqdm import tqdm_notebook as tqdm
from tqdm import tqdm
from scipy.cluster.vq import kmeans, vq
from scipy.spatial.distance import cdist

import torch.nn.functional as F
from romatch import roma_outdoor, roma_indoor
from utils.sh_utils import RGB2SH
from romatch.utils import get_tuple_transform_ops


def pairwise_distances(matrix):
    """
    Computes the pairwise Euclidean distances between all vectors in the input matrix.

    Args:
        matrix (torch.Tensor): Input matrix of shape [N, D], where N is the number of vectors and D is the dimensionality.

    Returns:
        torch.Tensor: Pairwise distance matrix of shape [N, N].
    """
    # Compute pairwise distances
    squared_diff = torch.cdist(matrix, matrix, p=2)
    return squared_diff


def k_closest_vectors(matrix, k):
    """
    Finds the k-closest vectors for each vector in the input matrix based on Euclidean distance.

    Args:
        matrix (torch.Tensor): Input matrix of shape [N, D], where N is the number of vectors and D is the dimensionality.
        k (int): Number of closest vectors to return for each vector.

    Returns:
        torch.Tensor: Indices of the k-closest vectors for each vector, excluding the vector itself.
    """
    # Compute pairwise distances
    distances = pairwise_distances(matrix)

    # For each vector, sort distances and get the indices of the k-closest vectors (excluding itself)
    # Set diagonal distances to infinity to exclude the vector itself from the nearest neighbors
    distances.fill_diagonal_(float('inf'))

    # Get the indices of the k smallest distances (k-closest vectors)
    _, indices = torch.topk(distances, k, largest=False, dim=1)

    return indices


def select_cameras_kmeans(cameras, K):
    """
    Selects K cameras from a set using K-means clustering.

    Args:
        cameras: NumPy array of shape (N, 16), representing N cameras with their 4x4 homogeneous matrices flattened.
        K: Number of clusters (cameras to select).

    Returns:
        selected_indices: List of indices of the cameras closest to the cluster centers.
    """
    # Ensure input is a NumPy array
    if not isinstance(cameras, np.ndarray):
        cameras = np.asarray(cameras)

    if cameras.shape[1] != 16:
        raise ValueError("Each camera must have 16 values corresponding to a flattened 4x4 matrix.")

    # Perform K-means clustering
    cluster_centers, _ = kmeans(cameras, K)

    # Assign each camera to a cluster and find distances to cluster centers
    cluster_assignments, _ = vq(cameras, cluster_centers)

    # Find the camera nearest to each cluster center
    selected_indices = []
    for k in range(K):
        cluster_members = cameras[cluster_assignments == k]
        distances = cdist([cluster_centers[k]], cluster_members)[0]
        nearest_camera_idx = np.where(cluster_assignments == k)[0][np.argmin(distances)]
        selected_indices.append(nearest_camera_idx)

    return selected_indices


def compute_warp_and_confidence(viewpoint_cam1, viewpoint_cam2, roma_model, device="cuda", verbose=False, output_dict={}):
    """
    Computes the warp and confidence between two viewpoint cameras using the roma_model.

    Args:
        viewpoint_cam1: Source viewpoint camera.
        viewpoint_cam2: Target viewpoint camera.
        roma_model: Pre-trained RoMa model for correspondence matching.
        device: Device to run the computation on.
        verbose: If True, displays the images.

    Returns:
        certainty: Confidence tensor.
        warp: Warp tensor.
        imB: Processed image B as a NumPy array.
    """
    # Prepare images
    imA = viewpoint_cam1.original_image.detach().cpu().numpy().transpose(1, 2, 0)
    imB = viewpoint_cam2.original_image.detach().cpu().numpy().transpose(1, 2, 0)
    imA = Image.fromarray(np.clip(imA * 255, 0, 255).astype(np.uint8))
    imB = Image.fromarray(np.clip(imB * 255, 0, 255).astype(np.uint8))

    if verbose:
        fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(16, 8))
        cax1 = ax[0].imshow(imA)
        ax[0].set_title("Image 1")
        cax2 = ax[1].imshow(imB)
        ax[1].set_title("Image 2")
        fig.colorbar(cax1, ax=ax[0])
        fig.colorbar(cax2, ax=ax[1])

        for axis in ax:
            axis.axis('off')
        # Save the figure into the dictionary
        output_dict[f'image_pair'] = fig

    # Transform images
    ws, hs = roma_model.w_resized, roma_model.h_resized
    test_transform = get_tuple_transform_ops(resize=(hs, ws), normalize=True)
    im_A, im_B = test_transform((imA, imB))
    batch = {"im_A": im_A[None].to(device), "im_B": im_B[None].to(device)}

    # Forward pass through RoMa model
    corresps = roma_model.forward(batch) if not roma_model.symmetric else roma_model.forward_symmetric(batch)
    finest_scale = 1
    hs, ws = roma_model.upsample_res if roma_model.upsample_preds else (hs, ws)

    # Process certainty and warp
    certainty = corresps[finest_scale]["certainty"]
    im_A_to_im_B = corresps[finest_scale]["flow"]
    if roma_model.attenuate_cert:
        low_res_certainty = F.interpolate(
            corresps[16]["certainty"], size=(hs, ws), align_corners=False, mode="bilinear"
        )
        certainty -= 0.5 * low_res_certainty * (low_res_certainty < 0)

    # Upsample predictions if needed
    if roma_model.upsample_preds:
        im_A_to_im_B = F.interpolate(
            im_A_to_im_B, size=(hs, ws), align_corners=False, mode="bilinear"
        )
        certainty = F.interpolate(
            certainty, size=(hs, ws), align_corners=False, mode="bilinear"
        )

    # Convert predictions to final format
    im_A_to_im_B = im_A_to_im_B.permute(0, 2, 3, 1)
    im_A_coords = torch.stack(torch.meshgrid(
        torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device=device),
        torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device=device),
        indexing='ij'
    ), dim=0).permute(1, 2, 0).unsqueeze(0).expand(im_A_to_im_B.size(0), -1, -1, -1)

    warp = torch.cat((im_A_coords, im_A_to_im_B), dim=-1)
    certainty = certainty.sigmoid()

    return certainty[0, 0], warp[0], np.array(imB)


def resize_batch(tensors_3d, tensors_4d, target_shape):
    """
    Resizes a batch of tensors with shapes [B, H, W] and [B, H, W, 4] to the target spatial dimensions.

    Args:
        tensors_3d: Tensor of shape [B, H, W].
        tensors_4d: Tensor of shape [B, H, W, 4].
        target_shape: Tuple (target_H, target_W) specifying the target spatial dimensions.

    Returns:
        resized_tensors_3d: Tensor of shape [B, target_H, target_W].
        resized_tensors_4d: Tensor of shape [B, target_H, target_W, 4].
    """
    target_H, target_W = target_shape

    # Resize [B, H, W] tensor
    resized_tensors_3d = F.interpolate(
        tensors_3d.unsqueeze(1), size=(target_H, target_W), mode="bilinear", align_corners=False
    ).squeeze(1)

    # Resize [B, H, W, 4] tensor
    B, _, _, C = tensors_4d.shape
    resized_tensors_4d = F.interpolate(
        tensors_4d.permute(0, 3, 1, 2), size=(target_H, target_W), mode="bilinear", align_corners=False
    ).permute(0, 2, 3, 1)

    return resized_tensors_3d, resized_tensors_4d


def aggregate_confidences_and_warps(viewpoint_stack, closest_indices, roma_model, source_idx, verbose=False, output_dict={}):
    """
    Aggregates confidences and warps by iterating over the nearest neighbors of the source viewpoint.

    Args:
        viewpoint_stack: Stack of viewpoint cameras.
        closest_indices: Indices of the nearest neighbors for each viewpoint.
        roma_model: Pre-trained RoMa model.
        source_idx: Index of the source viewpoint.
        verbose: If True, displays intermediate results.

    Returns:
        certainties_max: Aggregated maximum confidences.
        warps_max: Aggregated warps corresponding to the maximum confidences.
        certainties_max_idcs: Pixel-wise index of the image from which the best match was taken.
        imB_compound: List of the neighboring images.
    """
    certainties_all, warps_all, imB_compound = [], [], []

    for nn in tqdm(closest_indices[source_idx]):

        viewpoint_cam1 = viewpoint_stack[source_idx]
        viewpoint_cam2 = viewpoint_stack[nn]

        certainty, warp, imB = compute_warp_and_confidence(viewpoint_cam1, viewpoint_cam2, roma_model, verbose=verbose, output_dict=output_dict)
        certainties_all.append(certainty)
        warps_all.append(warp)
        imB_compound.append(imB)

    certainties_all = torch.stack(certainties_all, dim=0)
    target_shape = imB_compound[0].shape[:2]
    if verbose:
        print("certainties_all.shape:", certainties_all.shape)
        print("torch.stack(warps_all, dim=0).shape:", torch.stack(warps_all, dim=0).shape)
        print("target_shape:", target_shape)

    certainties_all_resized, warps_all_resized = resize_batch(certainties_all,
                                                              torch.stack(warps_all, dim=0),
                                                              target_shape
                                                              )

    if verbose:
        print("warps_all_resized.shape:", warps_all_resized.shape)
        for n, cert in enumerate(certainties_all):
            fig, ax = plt.subplots()
            cax = ax.imshow(cert.cpu().numpy(), cmap='viridis')
            fig.colorbar(cax, ax=ax)
            ax.set_title("Pixel-wise Confidence")
            output_dict[f'certainty_{n}'] = fig

        for n, warp in enumerate(warps_all):
            fig, ax = plt.subplots()
            cax = ax.imshow(warp.cpu().numpy()[:, :, :3], cmap='viridis')
            fig.colorbar(cax, ax=ax)
            ax.set_title("Pixel-wise warp")
            output_dict[f'warp_resized_{n}'] = fig

        for n, cert in enumerate(certainties_all_resized):
            fig, ax = plt.subplots()
            cax = ax.imshow(cert.cpu().numpy(), cmap='viridis')
            fig.colorbar(cax, ax=ax)
            ax.set_title("Pixel-wise Confidence resized")
            output_dict[f'certainty_resized_{n}'] = fig

        for n, warp in enumerate(warps_all_resized):
            fig, ax = plt.subplots()
            cax = ax.imshow(warp.cpu().numpy()[:, :, :3], cmap='viridis')
            fig.colorbar(cax, ax=ax)
            ax.set_title("Pixel-wise warp resized")
            output_dict[f'warp_resized_{n}'] = fig

    certainties_max, certainties_max_idcs = torch.max(certainties_all_resized, dim=0)
    H, W = certainties_max.shape

    warps_max = warps_all_resized[certainties_max_idcs, torch.arange(H).unsqueeze(1), torch.arange(W)]

    imA = viewpoint_cam1.original_image.detach().cpu().numpy().transpose(1, 2, 0)
    imA = np.clip(imA * 255, 0, 255).astype(np.uint8)

    return certainties_max, warps_max, certainties_max_idcs, imA, imB_compound, certainties_all_resized, warps_all_resized


def extract_keypoints_and_colors(imA, imB_compound, certainties_max, certainties_max_idcs, matches, roma_model,
                                 verbose=False, output_dict={}):
    """
    Extracts keypoints and corresponding colors from the source image (imA) and multiple target images (imB_compound).

    Args:
        imA: Source image as a NumPy array (H_A, W_A, C).
        imB_compound: List of target images as NumPy arrays [(H_B, W_B, C), ...].
        certainties_max: Tensor of pixel-wise maximum confidences.
        certainties_max_idcs: Tensor of pixel-wise indices for the best matches.
        matches: Matches in normalized coordinates.
        roma_model: RoMa model instance for keypoint operations.
        verbose: Whether to show intermediate outputs and visualize results.

    Returns:
        kptsA_np: Keypoints in imA in normalized coordinates.
        kptsB_np: Keypoints in imB in normalized coordinates.
        kptsA_color: Colors of keypoints in imA.
        kptsB_color: Colors of keypoints in imB based on certainties_max_idcs.
    """
    H_A, W_A, _ = imA.shape
    H, W = certainties_max.shape

    # Convert matches to pixel coordinates
    kptsA, kptsB = roma_model.to_pixel_coordinates(
        matches, W_A, H_A, H, W  # W, H
    )

    kptsA_np = kptsA.detach().cpu().numpy()
    kptsB_np = kptsB.detach().cpu().numpy()
    kptsA_np = kptsA_np[:, [1, 0]]

    if verbose:
        fig, ax = plt.subplots(figsize=(12, 6))
        cax = ax.imshow(imA)
        ax.set_title("Reference image, imA")
        output_dict[f'reference_image'] = fig

        fig, ax = plt.subplots(figsize=(12, 6))
        cax = ax.imshow(imB_compound[0])
        ax.set_title("Image to compare to, imB_compound")
        output_dict[f'imB_compound'] = fig

        fig, ax = plt.subplots(figsize=(12, 6))
        cax = ax.imshow(np.flipud(imA))
        cax = ax.scatter(kptsA_np[:, 0], H_A - kptsA_np[:, 1], s=.03)
        ax.set_title("Keypoints in imA")
        ax.set_xlim(0, W_A)
        ax.set_ylim(0, H_A)
        output_dict[f'kptsA'] = fig

        fig, ax = plt.subplots(figsize=(12, 6))
        cax = ax.imshow(np.flipud(imB_compound[0]))
        cax = ax.scatter(kptsB_np[:, 0], H_A - kptsB_np[:, 1], s=.03)
        ax.set_title("Keypoints in imB")
        ax.set_xlim(0, W_A)
        ax.set_ylim(0, H_A)
        output_dict[f'kptsB'] = fig

    # Keypoints are in (row, column) format, so the first value is always in range [0; height] and the second in range [0; width]

    kptsA_np = kptsA.detach().cpu().numpy()
    kptsB_np = kptsB.detach().cpu().numpy()

    # Extract colors for keypoints in imA (vectorized)
    # New experimental version
    kptsA_x = np.round(kptsA_np[:, 0] / 1.).astype(int)
    kptsA_y = np.round(kptsA_np[:, 1] / 1.).astype(int)
    kptsA_color = imA[np.clip(kptsA_x, 0, H - 1), np.clip(kptsA_y, 0, W - 1)]

    # Create a composite image from imB_compound
    imB_compound_np = np.stack(imB_compound, axis=0)
    H_B, W_B, _ = imB_compound[0].shape

    # Extract colors for keypoints in imB using certainties_max_idcs
    imB_np = imB_compound_np[
        certainties_max_idcs.detach().cpu().numpy(),
        np.arange(H).reshape(-1, 1),
        np.arange(W)
    ]

    if verbose:
        print("imB_np.shape:", imB_np.shape)
        print("imB_np:", imB_np)
        fig, ax = plt.subplots(figsize=(12, 6))
        cax = ax.imshow(np.flipud(imB_np))
        cax = ax.scatter(kptsB_np[:, 0], H_A - kptsB_np[:, 1], s=.03)
        ax.set_title("np.flipud(imB_np[0]")
        ax.set_xlim(0, W_A)
        ax.set_ylim(0, H_A)
        output_dict[f'np.flipud(imB_np[0]'] = fig


    kptsB_x = np.round(kptsB_np[:, 0]).astype(int)
    kptsB_y = np.round(kptsB_np[:, 1]).astype(int)

    certainties_max_idcs_np = certainties_max_idcs.detach().cpu().numpy()
    kptsB_proj_matrices_idx = certainties_max_idcs_np[np.clip(kptsA_x, 0, H - 1), np.clip(kptsA_y, 0, W - 1)]
    kptsB_color = imB_compound_np[kptsB_proj_matrices_idx, np.clip(kptsB_y, 0, H - 1), np.clip(kptsB_x, 0, W - 1)]

    # Normalize keypoints in both images
    kptsA_np[:, 0] = kptsA_np[:, 0] / H * 2.0 - 1.0
    kptsA_np[:, 1] = kptsA_np[:, 1] / W * 2.0 - 1.0
    kptsB_np[:, 0] = kptsB_np[:, 0] / W_B * 2.0 - 1.0
    kptsB_np[:, 1] = kptsB_np[:, 1] / H_B * 2.0 - 1.0

    return kptsA_np[:, [1, 0]], kptsB_np, kptsB_proj_matrices_idx, kptsA_color, kptsB_color

def prepare_tensor(input_array, device):
    """
    Converts an input array to a torch tensor, clones it, and detaches it for safe computation.
    Args:
        input_array (array-like): The input array to convert.
        device (str or torch.device): The device to move the tensor to.
    Returns:
        torch.Tensor: A detached tensor clone of the input array on the specified device.
    """
    if not isinstance(input_array, torch.Tensor):
        return torch.tensor(input_array, dtype=torch.float32).to(device).clone().detach()
    return input_array.clone().detach().to(device).to(torch.float32)

def triangulate_points(P1, P2, k1_x, k1_y, k2_x, k2_y, device="cuda"):
    """
    Solves for a batch of 3D points given batches of projection matrices and corresponding image points.

    Parameters:
    - P1, P2: Tensors of projection matrices of size (batch_size, 4, 4) or (4, 4)
    - k1_x, k1_y: Tensors of shape (batch_size,)
    - k2_x, k2_y: Tensors of shape (batch_size,)

    Returns:
    - X: A tensor containing the 3D homogeneous coordinates, shape (batch_size, 4)
    """
    EPS = 1e-4
    # Ensure inputs are tensors

    P1 = prepare_tensor(P1, device)
    P2 = prepare_tensor(P2, device)
    k1_x = prepare_tensor(k1_x, device)
    k1_y = prepare_tensor(k1_y, device)
    k2_x = prepare_tensor(k2_x, device)
    k2_y = prepare_tensor(k2_y, device)
    batch_size = k1_x.shape[0]

    # Expand P1 and P2 if they are not batched
    if P1.ndim == 2:
        P1 = P1.unsqueeze(0).expand(batch_size, -1, -1)
+
Finds the k-closest vectors for each vector in the input matrix based on Euclidean distance.
|
40 |
+
|
41 |
+
Args:
|
42 |
+
matrix (torch.Tensor): Input matrix of shape [N, D], where N is the number of vectors and D is the dimensionality.
|
43 |
+
k (int): Number of closest vectors to return for each vector.
|
44 |
+
|
45 |
+
Returns:
|
46 |
+
torch.Tensor: Indices of the k-closest vectors for each vector, excluding the vector itself.
|
47 |
+
"""
|
48 |
+
# Compute pairwise distances
|
49 |
+
distances = pairwise_distances(matrix)
|
50 |
+
|
51 |
+
# For each vector, sort distances and get the indices of the k-closest vectors (excluding itself)
|
52 |
+
# Set diagonal distances to infinity to exclude the vector itself from the nearest neighbors
|
53 |
+
distances.fill_diagonal_(float('inf'))
|
54 |
+
|
55 |
+
# Get the indices of the k smallest distances (k-closest vectors)
|
56 |
+
_, indices = torch.topk(distances, k, largest=False, dim=1)
|
57 |
+
|
58 |
+
return indices
|
59 |
+
|
60 |
+
|
61 |
+
def select_cameras_kmeans(cameras, K):
|
62 |
+
"""
|
63 |
+
Selects K cameras from a set using K-means clustering.
|
64 |
+
|
65 |
+
Args:
|
66 |
+
cameras: NumPy array of shape (N, 16), representing N cameras with their 4x4 homogeneous matrices flattened.
|
67 |
+
K: Number of clusters (cameras to select).
|
68 |
+
|
69 |
+
Returns:
|
70 |
+
selected_indices: List of indices of the cameras closest to the cluster centers.
|
71 |
+
"""
|
72 |
+
# Ensure input is a NumPy array
|
73 |
+
if not isinstance(cameras, np.ndarray):
|
74 |
+
cameras = np.asarray(cameras)
|
75 |
+
|
76 |
+
if cameras.shape[1] != 16:
|
77 |
+
raise ValueError("Each camera must have 16 values corresponding to a flattened 4x4 matrix.")
|
78 |
+
|
79 |
+
# Perform K-means clustering
|
80 |
+
cluster_centers, _ = kmeans(cameras, K)
|
81 |
+
|
82 |
+
# Assign each camera to a cluster and find distances to cluster centers
|
83 |
+
cluster_assignments, _ = vq(cameras, cluster_centers)
|
84 |
+
|
85 |
+
# Find the camera nearest to each cluster center
|
86 |
+
selected_indices = []
|
87 |
+
for k in range(K):
|
88 |
+
cluster_members = cameras[cluster_assignments == k]
|
89 |
+
distances = cdist([cluster_centers[k]], cluster_members)[0]
|
90 |
+
nearest_camera_idx = np.where(cluster_assignments == k)[0][np.argmin(distances)]
|
91 |
+
selected_indices.append(nearest_camera_idx)
|
92 |
+
|
93 |
+
return selected_indices
|
94 |
+
|
95 |
+
|
96 |
+
def compute_warp_and_confidence(viewpoint_cam1, viewpoint_cam2, roma_model, device="cuda", verbose=False, output_dict={}):
|
97 |
+
"""
|
98 |
+
Computes the warp and confidence between two viewpoint cameras using the roma_model.
|
99 |
+
|
100 |
+
Args:
|
101 |
+
viewpoint_cam1: Source viewpoint camera.
|
102 |
+
viewpoint_cam2: Target viewpoint camera.
|
103 |
+
roma_model: Pre-trained Roma model for correspondence matching.
|
104 |
+
device: Device to run the computation on.
|
105 |
+
verbose: If True, displays the images.
|
106 |
+
|
107 |
+
Returns:
|
108 |
+
certainty: Confidence tensor.
|
109 |
+
warp: Warp tensor.
|
110 |
+
imB: Processed image B as numpy array.
|
111 |
+
"""
|
112 |
+
# Prepare images
|
113 |
+
imA = viewpoint_cam1.original_image.detach().cpu().numpy().transpose(1, 2, 0)
|
114 |
+
imB = viewpoint_cam2.original_image.detach().cpu().numpy().transpose(1, 2, 0)
|
115 |
+
imA = Image.fromarray(np.clip(imA * 255, 0, 255).astype(np.uint8))
|
116 |
+
imB = Image.fromarray(np.clip(imB * 255, 0, 255).astype(np.uint8))
|
117 |
+
|
118 |
+
if verbose:
|
119 |
+
fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(16, 8))
|
120 |
+
cax1 = ax[0].imshow(imA)
|
121 |
+
ax[0].set_title("Image 1")
|
122 |
+
cax2 = ax[1].imshow(imB)
|
123 |
+
ax[1].set_title("Image 2")
|
124 |
+
fig.colorbar(cax1, ax=ax[0])
|
125 |
+
fig.colorbar(cax2, ax=ax[1])
|
126 |
+
|
127 |
+
for axis in ax:
|
128 |
+
axis.axis('off')
|
129 |
+
# Save the figure into the dictionary
|
130 |
+
output_dict[f'image_pair'] = fig
|
131 |
+
|
132 |
+
# Transform images
|
133 |
+
ws, hs = roma_model.w_resized, roma_model.h_resized
|
134 |
+
test_transform = get_tuple_transform_ops(resize=(hs, ws), normalize=True)
|
135 |
+
im_A, im_B = test_transform((imA, imB))
|
136 |
+
batch = {"im_A": im_A[None].to(device), "im_B": im_B[None].to(device)}
|
137 |
+
|
138 |
+
# Forward pass through Roma model
|
139 |
+
corresps = roma_model.forward(batch) if not roma_model.symmetric else roma_model.forward_symmetric(batch)
|
140 |
+
finest_scale = 1
|
141 |
+
hs, ws = roma_model.upsample_res if roma_model.upsample_preds else (hs, ws)
|
142 |
+
|
143 |
+
# Process certainty and warp
|
144 |
+
certainty = corresps[finest_scale]["certainty"]
|
145 |
+
im_A_to_im_B = corresps[finest_scale]["flow"]
|
146 |
+
if roma_model.attenuate_cert:
|
147 |
+
low_res_certainty = F.interpolate(
|
148 |
+
corresps[16]["certainty"], size=(hs, ws), align_corners=False, mode="bilinear"
|
149 |
+
)
|
150 |
+
certainty -= 0.5 * low_res_certainty * (low_res_certainty < 0)
|
151 |
+
|
152 |
+
# Upsample predictions if needed
|
153 |
+
if roma_model.upsample_preds:
|
154 |
+
im_A_to_im_B = F.interpolate(
|
155 |
+
im_A_to_im_B, size=(hs, ws), align_corners=False, mode="bilinear"
|
156 |
+
)
|
157 |
+
certainty = F.interpolate(
|
158 |
+
certainty, size=(hs, ws), align_corners=False, mode="bilinear"
|
159 |
+
)
|
160 |
+
|
161 |
+
# Convert predictions to final format
|
162 |
+
im_A_to_im_B = im_A_to_im_B.permute(0, 2, 3, 1)
|
163 |
+
im_A_coords = torch.stack(torch.meshgrid(
|
164 |
+
torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device=device),
|
165 |
+
torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device=device),
|
166 |
+
indexing='ij'
|
167 |
+
), dim=0).permute(1, 2, 0).unsqueeze(0).expand(im_A_to_im_B.size(0), -1, -1, -1)
|
168 |
+
|
169 |
+
warp = torch.cat((im_A_coords, im_A_to_im_B), dim=-1)
|
170 |
+
certainty = certainty.sigmoid()
|
171 |
+
|
172 |
+
return certainty[0, 0], warp[0], np.array(imB)
|
173 |
+
|
174 |
+
|
175 |
+
def resize_batch(tensors_3d, tensors_4d, target_shape):
|
176 |
+
"""
|
177 |
+
Resizes a batch of tensors with shapes [B, H, W] and [B, H, W, 4] to the target spatial dimensions.
|
178 |
+
|
179 |
+
Args:
|
180 |
+
tensors_3d: Tensor of shape [B, H, W].
|
181 |
+
tensors_4d: Tensor of shape [B, H, W, 4].
|
182 |
+
target_shape: Tuple (target_H, target_W) specifying the target spatial dimensions.
|
183 |
+
|
184 |
+
Returns:
|
185 |
+
resized_tensors_3d: Tensor of shape [B, target_H, target_W].
|
186 |
+
resized_tensors_4d: Tensor of shape [B, target_H, target_W, 4].
|
187 |
+
"""
|
188 |
+
target_H, target_W = target_shape
|
189 |
+
|
190 |
+
# Resize [B, H, W] tensor
|
191 |
+
resized_tensors_3d = F.interpolate(
|
192 |
+
tensors_3d.unsqueeze(1), size=(target_H, target_W), mode="bilinear", align_corners=False
|
193 |
+
).squeeze(1)
|
194 |
+
|
195 |
+
# Resize [B, H, W, 4] tensor
|
196 |
+
B, _, _, C = tensors_4d.shape
|
197 |
+
resized_tensors_4d = F.interpolate(
|
198 |
+
tensors_4d.permute(0, 3, 1, 2), size=(target_H, target_W), mode="bilinear", align_corners=False
|
199 |
+
).permute(0, 2, 3, 1)
|
200 |
+
|
201 |
+
return resized_tensors_3d, resized_tensors_4d
|
202 |
+
|
203 |
+
|
204 |
+
def aggregate_confidences_and_warps(viewpoint_stack, closest_indices, roma_model, source_idx, verbose=False, output_dict={}):
|
205 |
+
"""
|
206 |
+
Aggregates confidences and warps by iterating over the nearest neighbors of the source viewpoint.
|
207 |
+
|
208 |
+
Args:
|
209 |
+
viewpoint_stack: Stack of viewpoint cameras.
|
210 |
+
closest_indices: Indices of the nearest neighbors for each viewpoint.
|
211 |
+
roma_model: Pre-trained Roma model.
|
212 |
+
source_idx: Index of the source viewpoint.
|
213 |
+
verbose: If True, displays intermediate results.
|
214 |
+
|
215 |
+
Returns:
|
216 |
+
certainties_max: Aggregated maximum confidences.
|
217 |
+
warps_max: Aggregated warps corresponding to maximum confidences.
|
218 |
+
certainties_max_idcs: Pixel-wise index of the image from which we taken the best matching.
|
219 |
+
imB_compound: List of the neighboring images.
|
220 |
+
"""
|
221 |
+
certainties_all, warps_all, imB_compound = [], [], []
|
222 |
+
|
223 |
+
for nn in tqdm(closest_indices[source_idx]):
|
224 |
+
|
225 |
+
viewpoint_cam1 = viewpoint_stack[source_idx]
|
226 |
+
viewpoint_cam2 = viewpoint_stack[nn]
|
227 |
+
|
228 |
+
certainty, warp, imB = compute_warp_and_confidence(viewpoint_cam1, viewpoint_cam2, roma_model, verbose=verbose, output_dict=output_dict)
|
229 |
+
certainties_all.append(certainty)
|
230 |
+
warps_all.append(warp)
|
231 |
+
imB_compound.append(imB)
|
232 |
+
|
233 |
+
certainties_all = torch.stack(certainties_all, dim=0)
|
234 |
+
target_shape = imB_compound[0].shape[:2]
|
235 |
+
if verbose:
|
236 |
+
print("certainties_all.shape:", certainties_all.shape)
|
237 |
+
print("torch.stack(warps_all, dim=0).shape:", torch.stack(warps_all, dim=0).shape)
|
238 |
+
print("target_shape:", target_shape)
|
239 |
+
|
240 |
+
certainties_all_resized, warps_all_resized = resize_batch(certainties_all,
|
241 |
+
torch.stack(warps_all, dim=0),
|
242 |
+
target_shape
|
243 |
+
)
|
244 |
+
|
245 |
+
if verbose:
|
246 |
+
print("warps_all_resized.shape:", warps_all_resized.shape)
|
247 |
+
for n, cert in enumerate(certainties_all):
|
248 |
+
fig, ax = plt.subplots()
|
249 |
+
cax = ax.imshow(cert.cpu().numpy(), cmap='viridis')
|
250 |
+
fig.colorbar(cax, ax=ax)
|
251 |
+
ax.set_title("Pixel-wise Confidence")
|
252 |
+
output_dict[f'certainty_{n}'] = fig
|
253 |
+
|
254 |
+
for n, warp in enumerate(warps_all):
|
255 |
+
fig, ax = plt.subplots()
|
256 |
+
cax = ax.imshow(warp.cpu().numpy()[:, :, :3], cmap='viridis')
|
257 |
+
fig.colorbar(cax, ax=ax)
|
258 |
+
ax.set_title("Pixel-wise warp")
|
259 |
+
output_dict[f'warp_resized_{n}'] = fig
|
260 |
+
|
261 |
+
for n, cert in enumerate(certainties_all_resized):
|
262 |
+
fig, ax = plt.subplots()
|
263 |
+
cax = ax.imshow(cert.cpu().numpy(), cmap='viridis')
|
264 |
+
fig.colorbar(cax, ax=ax)
|
265 |
+
ax.set_title("Pixel-wise Confidence resized")
|
266 |
+
output_dict[f'certainty_resized_{n}'] = fig
|
267 |
+
|
268 |
+
for n, warp in enumerate(warps_all_resized):
|
269 |
+
fig, ax = plt.subplots()
|
270 |
+
cax = ax.imshow(warp.cpu().numpy()[:, :, :3], cmap='viridis')
|
271 |
+
fig.colorbar(cax, ax=ax)
|
272 |
+
ax.set_title("Pixel-wise warp resized")
|
273 |
+
output_dict[f'warp_resized_{n}'] = fig
|
274 |
+
|
275 |
+
certainties_max, certainties_max_idcs = torch.max(certainties_all_resized, dim=0)
|
276 |
+
H, W = certainties_max.shape
|
277 |
+
|
278 |
+
warps_max = warps_all_resized[certainties_max_idcs, torch.arange(H).unsqueeze(1), torch.arange(W)]
|
279 |
+
|
280 |
+
imA = viewpoint_cam1.original_image.detach().cpu().numpy().transpose(1, 2, 0)
|
281 |
+
imA = np.clip(imA * 255, 0, 255).astype(np.uint8)
|
282 |
+
|
283 |
+
return certainties_max, warps_max, certainties_max_idcs, imA, imB_compound, certainties_all_resized, warps_all_resized
|
284 |
+
|
285 |
+
|
286 |
+
|
287 |
+
def extract_keypoints_and_colors(imA, imB_compound, certainties_max, certainties_max_idcs, matches, roma_model,
|
288 |
+
verbose=False, output_dict={}):
|
289 |
+
"""
|
290 |
+
Extracts keypoints and corresponding colors from the source image (imA) and multiple target images (imB_compound).
|
291 |
+
|
292 |
+
Args:
|
293 |
+
imA: Source image as a NumPy array (H_A, W_A, C).
|
294 |
+
imB_compound: List of target images as NumPy arrays [(H_B, W_B, C), ...].
|
295 |
+
certainties_max: Tensor of pixel-wise maximum confidences.
|
296 |
+
certainties_max_idcs: Tensor of pixel-wise indices for the best matches.
|
297 |
+
matches: Matches in normalized coordinates.
|
298 |
+
roma_model: Roma model instance for keypoint operations.
|
299 |
+
verbose: if to show intermediate outputs and visualize results
|
300 |
+
|
301 |
+
Returns:
|
302 |
+
kptsA_np: Keypoints in imA in normalized coordinates.
|
303 |
+
kptsB_np: Keypoints in imB in normalized coordinates.
|
304 |
+
kptsA_color: Colors of keypoints in imA.
|
305 |
+
kptsB_color: Colors of keypoints in imB based on certainties_max_idcs.
|
306 |
+
"""
|
307 |
+
H_A, W_A, _ = imA.shape
|
308 |
+
H, W = certainties_max.shape
|
309 |
+
|
310 |
+
# Convert matches to pixel coordinates
|
311 |
+
kptsA, kptsB = roma_model.to_pixel_coordinates(
|
312 |
+
matches, W_A, H_A, H, W # W, H
|
313 |
+
)
|
314 |
+
|
315 |
+
kptsA_np = kptsA.detach().cpu().numpy()
|
316 |
+
kptsB_np = kptsB.detach().cpu().numpy()
|
317 |
+
kptsA_np = kptsA_np[:, [1, 0]]
|
318 |
+
|
319 |
+
if verbose:
|
320 |
+
fig, ax = plt.subplots(figsize=(12, 6))
|
321 |
+
cax = ax.imshow(imA)
|
322 |
+
ax.set_title("Reference image, imA")
|
323 |
+
output_dict[f'reference_image'] = fig
|
324 |
+
|
325 |
+
fig, ax = plt.subplots(figsize=(12, 6))
|
326 |
+
cax = ax.imshow(imB_compound[0])
|
327 |
+
ax.set_title("Image to compare to image, imB_compound")
|
328 |
+
output_dict[f'imB_compound'] = fig
|
329 |
+
|
330 |
+
fig, ax = plt.subplots(figsize=(12, 6))
|
331 |
+
cax = ax.imshow(np.flipud(imA))
|
332 |
+
cax = ax.scatter(kptsA_np[:, 0], H_A - kptsA_np[:, 1], s=.03)
|
333 |
+
ax.set_title("Keypoints in imA")
|
334 |
+
ax.set_xlim(0, W_A)
|
335 |
+
ax.set_ylim(0, H_A)
|
336 |
+
output_dict[f'kptsA'] = fig
|
337 |
+
|
338 |
+
fig, ax = plt.subplots(figsize=(12, 6))
|
339 |
+
cax = ax.imshow(np.flipud(imB_compound[0]))
|
340 |
+
cax = ax.scatter(kptsB_np[:, 0], H_A - kptsB_np[:, 1], s=.03)
|
341 |
+
ax.set_title("Keypoints in imB")
|
342 |
+
ax.set_xlim(0, W_A)
|
343 |
+
ax.set_ylim(0, H_A)
|
344 |
+
output_dict[f'kptsB'] = fig
|
345 |
+
|
346 |
+
# Keypoints are in format (row, column) so the first value is alwain in range [0;height] and second is in range[0;width]
|
347 |
+
|
348 |
+
kptsA_np = kptsA.detach().cpu().numpy()
|
349 |
+
kptsB_np = kptsB.detach().cpu().numpy()
|
350 |
+
|
351 |
+
# Extract colors for keypoints in imA (vectorized)
|
352 |
+
# New experimental version
|
353 |
+
kptsA_x = np.round(kptsA_np[:, 0] / 1.).astype(int)
|
354 |
+
kptsA_y = np.round(kptsA_np[:, 1] / 1.).astype(int)
|
355 |
+
kptsA_color = imA[np.clip(kptsA_x, 0, H - 1), np.clip(kptsA_y, 0, W - 1)]
|
356 |
+
|
357 |
+
# Create a composite image from imB_compound
|
358 |
+
imB_compound_np = np.stack(imB_compound, axis=0)
|
359 |
+
H_B, W_B, _ = imB_compound[0].shape
|
360 |
+
|
361 |
+
# Extract colors for keypoints in imB using certainties_max_idcs
|
362 |
+
imB_np = imB_compound_np[
|
363 |
+
certainties_max_idcs.detach().cpu().numpy(),
|
364 |
+
np.arange(H).reshape(-1, 1),
|
365 |
+
np.arange(W)
|
366 |
+
]
|
367 |
+
|
368 |
+
if verbose:
|
369 |
+
print("imB_np.shape:", imB_np.shape)
|
370 |
+
print("imB_np:", imB_np)
|
371 |
+
fig, ax = plt.subplots(figsize=(12, 6))
|
372 |
+
cax = ax.imshow(np.flipud(imB_np))
|
373 |
+
cax = ax.scatter(kptsB_np[:, 0], H_A - kptsB_np[:, 1], s=.03)
|
374 |
+
ax.set_title("np.flipud(imB_np[0]")
|
375 |
+
ax.set_xlim(0, W_A)
|
376 |
+
ax.set_ylim(0, H_A)
|
377 |
+
output_dict[f'np.flipud(imB_np[0]'] = fig
|
378 |
+
|
379 |
+
|
380 |
+
kptsB_x = np.round(kptsB_np[:, 0]).astype(int)
|
381 |
+
kptsB_y = np.round(kptsB_np[:, 1]).astype(int)
|
382 |
+
|
383 |
+
certainties_max_idcs_np = certainties_max_idcs.detach().cpu().numpy()
|
384 |
+
kptsB_proj_matrices_idx = certainties_max_idcs_np[np.clip(kptsA_x, 0, H - 1), np.clip(kptsA_y, 0, W - 1)]
|
385 |
+
kptsB_color = imB_compound_np[kptsB_proj_matrices_idx, np.clip(kptsB_y, 0, H - 1), np.clip(kptsB_x, 0, W - 1)]
|
386 |
+
|
387 |
+
# Normalize keypoints in both images
|
388 |
+
kptsA_np[:, 0] = kptsA_np[:, 0] / H * 2.0 - 1.0
|
389 |
+
kptsA_np[:, 1] = kptsA_np[:, 1] / W * 2.0 - 1.0
|
390 |
+
kptsB_np[:, 0] = kptsB_np[:, 0] / W_B * 2.0 - 1.0
|
391 |
+
kptsB_np[:, 1] = kptsB_np[:, 1] / H_B * 2.0 - 1.0
|
392 |
+
|
393 |
+
return kptsA_np[:, [1, 0]], kptsB_np, kptsB_proj_matrices_idx, kptsA_color, kptsB_color
|
394 |
+
|
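# --- Illustrative sketch (not part of the original file): the normalization used above maps
# pixel coordinates into roughly [-1, 1] via x_norm = x / size * 2 - 1 (sizes differ per image).
import numpy as np

H, W = 480, 640                                                      # placeholder image size
kpts_pix = np.array([[0.0, 0.0], [320.0, 240.0], [639.0, 479.0]])    # (x, y) in pixels
kpts_norm = np.empty_like(kpts_pix)
kpts_norm[:, 0] = kpts_pix[:, 0] / W * 2.0 - 1.0
kpts_norm[:, 1] = kpts_pix[:, 1] / H * 2.0 - 1.0
# Inverse mapping, e.g. to draw the keypoints back onto the image:
kpts_back = np.empty_like(kpts_norm)
kpts_back[:, 0] = (kpts_norm[:, 0] + 1.0) / 2.0 * W
kpts_back[:, 1] = (kpts_norm[:, 1] + 1.0) / 2.0 * H
assert np.allclose(kpts_back, kpts_pix)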
395 |
+
def prepare_tensor(input_array, device):
|
396 |
+
"""
|
397 |
+
Converts an input array to a torch tensor, clones it, and detaches it for safe computation.
|
398 |
+
Args:
|
399 |
+
input_array (array-like): The input array to convert.
|
400 |
+
device (str or torch.device): The device to move the tensor to.
|
401 |
+
Returns:
|
402 |
+
torch.Tensor: A detached tensor clone of the input array on the specified device.
|
403 |
+
"""
|
404 |
+
if not isinstance(input_array, torch.Tensor):
|
405 |
+
return torch.tensor(input_array, dtype=torch.float32).to(device).clone().detach()
|
406 |
+
return input_array.clone().detach().to(device).to(torch.float32)
|
407 |
+
|
408 |
+
def triangulate_points(P1, P2, k1_x, k1_y, k2_x, k2_y, device="cuda"):
|
409 |
+
"""
|
410 |
+
Solves for a batch of 3D points given batches of projection matrices and corresponding image points.
|
411 |
+
|
412 |
+
Parameters:
|
413 |
+
- P1, P2: Tensors of projection matrices of size (batch_size, 4, 4) or (4, 4)
|
414 |
+
- k1_x, k1_y: Tensors of shape (batch_size,)
|
415 |
+
- k2_x, k2_y: Tensors of shape (batch_size,)
|
416 |
+
|
417 |
+
Returns:
|
418 |
+
- X: A tensor containing the 3D homogeneous coordinates, shape (batch_size, 4)
|
419 |
+
"""
|
420 |
+
EPS = 1e-4
|
421 |
+
# Ensure inputs are tensors
|
422 |
+
|
423 |
+
P1 = prepare_tensor(P1, device)
|
424 |
+
P2 = prepare_tensor(P2, device)
|
425 |
+
k1_x = prepare_tensor(k1_x, device)
|
426 |
+
k1_y = prepare_tensor(k1_y, device)
|
427 |
+
k2_x = prepare_tensor(k2_x, device)
|
428 |
+
k2_y = prepare_tensor(k2_y, device)
|
429 |
+
batch_size = k1_x.shape[0]
|
430 |
+
|
431 |
+
# Expand P1 and P2 if they are not batched
|
432 |
+
if P1.ndim == 2:
|
433 |
+
P1 = P1.unsqueeze(0).expand(batch_size, -1, -1)
|
434 |
+
if P2.ndim == 2:
|
435 |
+
P2 = P2.unsqueeze(0).expand(batch_size, -1, -1)
|
436 |
+
|
437 |
+
# Extract columns from P1 and P2
|
438 |
+
P1_0 = P1[:, :, 0] # Shape: (batch_size, 4)
|
439 |
+
P1_1 = P1[:, :, 1]
|
440 |
+
P1_2 = P1[:, :, 2]
|
441 |
+
|
442 |
+
P2_0 = P2[:, :, 0]
|
443 |
+
P2_1 = P2[:, :, 1]
|
444 |
+
P2_2 = P2[:, :, 2]
|
445 |
+
|
446 |
+
# Reshape kx and ky to (batch_size, 1)
|
447 |
+
k1_x = k1_x.view(-1, 1)
|
448 |
+
k1_y = k1_y.view(-1, 1)
|
449 |
+
k2_x = k2_x.view(-1, 1)
|
450 |
+
k2_y = k2_y.view(-1, 1)
|
451 |
+
|
452 |
+
# Construct the equations for each batch
|
453 |
+
# For camera 1
|
454 |
+
A1 = P1_0 - k1_x * P1_2 # Shape: (batch_size, 4)
|
455 |
+
A2 = P1_1 - k1_y * P1_2
|
456 |
+
# For camera 2
|
457 |
+
A3 = P2_0 - k2_x * P2_2
|
458 |
+
A4 = P2_1 - k2_y * P2_2
|
459 |
+
|
460 |
+
# Stack the equations
|
461 |
+
A = torch.stack([A1, A2, A3, A4], dim=1) # Shape: (batch_size, 4, 4)
|
462 |
+
|
463 |
+
# Right-hand side (constants)
|
464 |
+
b = -A[:, :, 3] # Shape: (batch_size, 4)
|
465 |
+
A_reduced = A[:, :, :3] # Coefficients of x, y, z
|
466 |
+
|
467 |
+
# Solve using torch.linalg.lstsq (supports batching)
|
468 |
+
X_xyz = torch.linalg.lstsq(A_reduced, b.unsqueeze(2)).solution.squeeze(2) # Shape: (batch_size, 3)
|
469 |
+
|
470 |
+
# Append 1 to get homogeneous coordinates
|
471 |
+
ones = torch.ones((batch_size, 1), dtype=torch.float32, device=X_xyz.device)
|
472 |
+
X = torch.cat([X_xyz, ones], dim=1) # Shape: (batch_size, 4)
|
473 |
+
|
474 |
+
# Now compute the errors of projections.
|
475 |
+
seeked_splats_proj1 = (X.unsqueeze(1) @ P1).squeeze(1)
|
476 |
+
seeked_splats_proj1 = seeked_splats_proj1 / (EPS + seeked_splats_proj1[:, [3]])
|
477 |
+
seeked_splats_proj2 = (X.unsqueeze(1) @ P2).squeeze(1)
|
478 |
+
seeked_splats_proj2 = seeked_splats_proj2 / (EPS + seeked_splats_proj2[:, [3]])
|
479 |
+
proj1_target = torch.concat([k1_x, k1_y], dim=1)
|
480 |
+
proj2_target = torch.concat([k2_x, k2_y], dim=1)
|
481 |
+
errors_proj1 = torch.abs(seeked_splats_proj1[:, :2] - proj1_target).sum(1).detach().cpu().numpy()
|
482 |
+
errors_proj2 = torch.abs(seeked_splats_proj2[:, :2] - proj2_target).sum(1).detach().cpu().numpy()
|
483 |
+
|
484 |
+
return X, errors_proj1, errors_proj2
|
485 |
+
|
486 |
+
|
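# --- Illustrative sketch (not part of the original file): the same least-squares triangulation
# idea in the more common column-vector convention with 3x4 projection matrices (x ~ P X).
# The function above instead uses the row-vector 4x4 full_proj_transform convention of this repo.
import torch

def triangulate_dlt(P1, P2, x1, x2):
    # P1, P2: (3, 4) projection matrices; x1, x2: (2,) image points.
    A = torch.stack([
        x1[0] * P1[2] - P1[0],
        x1[1] * P1[2] - P1[1],
        x2[0] * P2[2] - P2[0],
        x2[1] * P2[2] - P2[1],
    ])                                                # (4, 4), A @ X_hom = 0
    _, _, Vh = torch.linalg.svd(A)                    # null vector = smallest right singular vector
    X_hom = Vh[-1]
    return X_hom / X_hom[3]

# Tiny self-check: two cameras one unit apart observing the point (1, 2, 5).
P1 = torch.eye(3, 4)
P2 = torch.cat([torch.eye(3), torch.tensor([[-1.0], [0.0], [0.0]])], dim=1)
X_true = torch.tensor([1.0, 2.0, 5.0, 1.0])
x1 = (P1 @ X_true)[:2] / (P1 @ X_true)[2]
x2 = (P2 @ X_true)[:2] / (P2 @ X_true)[2]
print(triangulate_dlt(P1, P2, x1, x2))                # ≈ tensor([1., 2., 5., 1.])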
487 |
+
|
488 |
+
def select_best_keypoints(
|
489 |
+
NNs_triangulated_points, NNs_errors_proj1, NNs_errors_proj2, device="cuda"):
|
490 |
+
"""
|
491 |
+
For every keypoint, selects among the candidate triangulations (one per neighboring view) the one whose larger projection error over the two views is smallest.
|
492 |
+
|
493 |
+
Args:
|
494 |
+
NNs_triangulated_points: torch tensor with keypoints coordinates (num_nns, num_points, dim). dim can be arbitrary,
|
495 |
+
usually 3 or 4(for homogeneous representation).
|
496 |
+
NNs_errors_proj1: numpy array with projection error of the estimated keypoint on the reference frame (num_nns, num_points).
|
497 |
+
NNs_errors_proj2: numpy array with projection error of the estimated keypoint on the neighbor frame (num_nns, num_points).
|
498 |
+
Returns:
|
499 |
+
selected_keypoints: keypoints with the best score.
|
500 |
+
"""
|
501 |
+
|
502 |
+
NNs_errors_proj = np.maximum(NNs_errors_proj1, NNs_errors_proj2)
|
503 |
+
|
504 |
+
# Convert indices to PyTorch tensor
|
505 |
+
indices = torch.from_numpy(np.argmin(NNs_errors_proj, axis=0)).long().to(device)
|
506 |
+
|
507 |
+
# Create index tensor for the second dimension
|
508 |
+
n_indices = torch.arange(NNs_triangulated_points.shape[1]).long().to(device)
|
509 |
+
|
510 |
+
# Use advanced indexing to select elements
|
511 |
+
NNs_triangulated_points_selected = NNs_triangulated_points[indices, n_indices, :] # Shape: [N, k]
|
512 |
+
|
513 |
+
return NNs_triangulated_points_selected, np.min(NNs_errors_proj, axis=0)
|
514 |
+
|
515 |
+
|
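# --- Illustrative sketch (not part of the original file): the per-point selection above scores
# every candidate by its worse projection error and keeps the best neighbor for each keypoint.
import numpy as np
import torch

num_nns, num_points = 3, 6
candidates = torch.rand(num_nns, num_points, 4)        # one triangulation per neighbor and point
err_ref = np.random.rand(num_nns, num_points)          # reprojection error on the reference view
err_nn = np.random.rand(num_nns, num_points)           # reprojection error on the neighbor view

worst = np.maximum(err_ref, err_nn)                    # score = worse of the two views
best_nn = torch.from_numpy(np.argmin(worst, axis=0))   # (num_points,) winning neighbor per point
selected = candidates[best_nn, torch.arange(num_points), :]   # (num_points, 4)
assert selected.shape == (num_points, 4)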
516 |
+
|
517 |
+
def init_gaussians_with_corr(gaussians, scene, cfg, device, verbose = False, roma_model=None):
|
518 |
+
"""
|
519 |
+
For the given gaussians object and scene, instantiates a RoMa model (switched to the indoor weights if requested) and processes the scene's
|
520 |
+
training frames to extract correspondences. Those are used to initialize new gaussians.
|
521 |
+
Args:
|
522 |
+
gaussians: object gaussians of the class GaussianModel that we need to enrich with gaussians.
|
523 |
+
scene: object of the Scene class.
|
524 |
+
cfg: configuration; the init_wC settings are used.
|
525 |
+
Returns:
|
526 |
+
gaussians: the input GaussianModel object is enriched in place; the function returns (viewpoint_stack, closest_indices_selected, visualizations).
|
527 |
+
|
528 |
+
"""
|
529 |
+
if roma_model is None:
|
530 |
+
if cfg.roma_model == "indoors":
|
531 |
+
roma_model = roma_indoor(device=device)
|
532 |
+
else:
|
533 |
+
roma_model = roma_outdoor(device=device)
|
534 |
+
roma_model.upsample_preds = False
|
535 |
+
roma_model.symmetric = False
|
536 |
+
M = cfg.matches_per_ref
|
537 |
+
upper_thresh = roma_model.sample_thresh
|
538 |
+
scaling_factor = cfg.scaling_factor
|
539 |
+
expansion_factor = 1
|
540 |
+
keypoint_fit_error_tolerance = cfg.proj_err_tolerance
|
541 |
+
visualizations = {}
|
542 |
+
viewpoint_stack = scene.getTrainCameras().copy()
|
543 |
+
NUM_REFERENCE_FRAMES = min(cfg.num_refs, len(viewpoint_stack))
|
544 |
+
NUM_NNS_PER_REFERENCE = min(cfg.nns_per_ref, len(viewpoint_stack))
|
545 |
+
# Select cameras using K-means
|
546 |
+
viewpoint_cam_all = torch.stack([x.world_view_transform.flatten() for x in viewpoint_stack], axis=0)
|
547 |
+
|
548 |
+
selected_indices = select_cameras_kmeans(cameras=viewpoint_cam_all.detach().cpu().numpy(), K=NUM_REFERENCE_FRAMES)
|
549 |
+
selected_indices = sorted(selected_indices)
|
550 |
+
|
551 |
+
|
552 |
+
# Find the k-closest vectors for each vector
|
553 |
+
viewpoint_cam_all = torch.stack([x.world_view_transform.flatten() for x in viewpoint_stack], axis=0)
|
554 |
+
closest_indices = k_closest_vectors(viewpoint_cam_all, NUM_NNS_PER_REFERENCE)
|
555 |
+
if verbose: print("Indices of k-closest vectors for each vector:\n", closest_indices)
|
556 |
+
|
557 |
+
closest_indices_selected = closest_indices[:, :].detach().cpu().numpy()
|
558 |
+
|
559 |
+
all_new_xyz = []
|
560 |
+
all_new_features_dc = []
|
561 |
+
all_new_features_rest = []
|
562 |
+
all_new_opacities = []
|
563 |
+
all_new_scaling = []
|
564 |
+
all_new_rotation = []
|
565 |
+
|
566 |
+
# Run roma_model.match once to warm up / initialize the model
|
567 |
+
with torch.no_grad():
|
568 |
+
viewpoint_cam1 = viewpoint_stack[0]
|
569 |
+
viewpoint_cam2 = viewpoint_stack[1]
|
570 |
+
imA = viewpoint_cam1.original_image.detach().cpu().numpy().transpose(1, 2, 0)
|
571 |
+
imB = viewpoint_cam2.original_image.detach().cpu().numpy().transpose(1, 2, 0)
|
572 |
+
imA = Image.fromarray(np.clip(imA * 255, 0, 255).astype(np.uint8))
|
573 |
+
imB = Image.fromarray(np.clip(imB * 255, 0, 255).astype(np.uint8))
|
574 |
+
warp, certainty_warp = roma_model.match(imA, imB, device=device)
|
575 |
+
print("Once run full roma_model.match warp.shape:", warp.shape)
|
576 |
+
print("Once run full roma_model.match certainty_warp.shape:", certainty_warp.shape)
|
577 |
+
del warp, certainty_warp
|
578 |
+
torch.cuda.empty_cache()
|
579 |
+
|
580 |
+
for source_idx in tqdm(sorted(selected_indices)):
|
581 |
+
# 1. Compute keypoints and warping for all the neighboring views
|
582 |
+
with torch.no_grad():
|
583 |
+
# Call the aggregation function to get imA and imB_compound
|
584 |
+
certainties_max, warps_max, certainties_max_idcs, imA, imB_compound, certainties_all, warps_all = aggregate_confidences_and_warps(
|
585 |
+
viewpoint_stack=viewpoint_stack,
|
586 |
+
closest_indices=closest_indices_selected,
|
587 |
+
roma_model=roma_model,
|
588 |
+
source_idx=source_idx,
|
589 |
+
verbose=verbose, output_dict=visualizations
|
590 |
+
)
|
591 |
+
|
592 |
+
|
593 |
+
# Triangulate keypoints
|
594 |
+
with torch.no_grad():
|
595 |
+
matches = warps_max
|
596 |
+
certainty = certainties_max
|
597 |
+
certainty = certainty.clone()
|
598 |
+
certainty[certainty > upper_thresh] = 1
|
599 |
+
matches, certainty = (
|
600 |
+
matches.reshape(-1, 4),
|
601 |
+
certainty.reshape(-1),
|
602 |
+
)
|
603 |
+
|
604 |
+
# Sample matches proportionally to their certainty, keeping only confident elements;
|
605 |
+
# these samples later become kptsA_np.
|
606 |
+
good_samples = torch.multinomial(certainty,
|
607 |
+
num_samples=min(expansion_factor * M, len(certainty)),
|
608 |
+
replacement=False)
|
609 |
+
|
610 |
+
|
611 |
+
reference_image_dict = {
|
612 |
+
"ref_image": imA,
|
613 |
+
"NNs_images": imB_compound,
|
614 |
+
"certainties_all": certainties_all,
|
615 |
+
"warps_all": warps_all,
|
616 |
+
"triangulated_points": [],
|
617 |
+
"triangulated_points_errors_proj1": [],
|
618 |
+
"triangulated_points_errors_proj2": []
|
619 |
+
|
620 |
+
}
|
621 |
+
with torch.no_grad():
|
622 |
+
for NN_idx in tqdm(range(len(warps_all))):
|
623 |
+
matches_NN = warps_all[NN_idx].reshape(-1, 4)[good_samples]
|
624 |
+
|
625 |
+
# Extract keypoints and colors
|
626 |
+
kptsA_np, kptsB_np, kptsB_proj_matrices_idcs, kptsA_color, kptsB_color = extract_keypoints_and_colors(
|
627 |
+
imA, imB_compound, certainties_max, certainties_max_idcs, matches_NN, roma_model
|
628 |
+
)
|
629 |
+
|
630 |
+
proj_matrices_A = viewpoint_stack[source_idx].full_proj_transform
|
631 |
+
proj_matrices_B = viewpoint_stack[closest_indices_selected[source_idx, NN_idx]].full_proj_transform
|
632 |
+
triangulated_points, triangulated_points_errors_proj1, triangulated_points_errors_proj2 = triangulate_points(
|
633 |
+
P1=torch.stack([proj_matrices_A] * M, axis=0),
|
634 |
+
P2=torch.stack([proj_matrices_B] * M, axis=0),
|
635 |
+
k1_x=kptsA_np[:M, 0], k1_y=kptsA_np[:M, 1],
|
636 |
+
k2_x=kptsB_np[:M, 0], k2_y=kptsB_np[:M, 1])
|
637 |
+
|
638 |
+
reference_image_dict["triangulated_points"].append(triangulated_points)
|
639 |
+
reference_image_dict["triangulated_points_errors_proj1"].append(triangulated_points_errors_proj1)
|
640 |
+
reference_image_dict["triangulated_points_errors_proj2"].append(triangulated_points_errors_proj2)
|
641 |
+
|
642 |
+
with torch.no_grad():
|
643 |
+
NNs_triangulated_points_selected, NNs_triangulated_points_selected_proj_errors = select_best_keypoints(
|
644 |
+
NNs_triangulated_points=torch.stack(reference_image_dict["triangulated_points"], dim=0),
|
645 |
+
NNs_errors_proj1=np.stack(reference_image_dict["triangulated_points_errors_proj1"], axis=0),
|
646 |
+
NNs_errors_proj2=np.stack(reference_image_dict["triangulated_points_errors_proj2"], axis=0))
|
647 |
+
|
648 |
+
# 4. Save as gaussians
|
649 |
+
viewpoint_cam1 = viewpoint_stack[source_idx]
|
650 |
+
N = len(NNs_triangulated_points_selected)
|
651 |
+
with torch.no_grad():
|
652 |
+
new_xyz = NNs_triangulated_points_selected[:, :-1]
|
653 |
+
all_new_xyz.append(new_xyz) # seeked_splats
|
654 |
+
all_new_features_dc.append(RGB2SH(torch.tensor(kptsA_color.astype(np.float32) / 255.)).unsqueeze(1))
|
655 |
+
all_new_features_rest.append(torch.stack([gaussians._features_rest[-1].clone().detach() * 0.] * N, dim=0))
|
656 |
+
# New version that makes points with a large projection error invisible.
|
657 |
+
# TODO: remove those points instead; keeping them does not affect performance.
|
658 |
+
mask_bad_points = torch.tensor(
|
659 |
+
NNs_triangulated_points_selected_proj_errors > keypoint_fit_error_tolerance,
|
660 |
+
dtype=torch.float32).unsqueeze(1).to(device)
|
661 |
+
all_new_opacities.append(torch.stack([gaussians._opacity[-1].clone().detach()] * N, dim=0) * 0. - mask_bad_points * (1e1))
|
662 |
+
|
663 |
+
dist_points_to_cam1 = torch.linalg.norm(viewpoint_cam1.camera_center.clone().detach() - new_xyz,
|
664 |
+
dim=1, ord=2)
|
665 |
+
#all_new_scaling.append(torch.log(((dist_points_to_cam1) / 1. * scaling_factor).unsqueeze(1).repeat(1, 3)))
|
666 |
+
all_new_scaling.append(gaussians.scaling_inverse_activation((dist_points_to_cam1 * scaling_factor).unsqueeze(1).repeat(1, 3)))
|
667 |
+
all_new_rotation.append(torch.stack([gaussians._rotation[-1].clone().detach()] * N, dim=0))
|
668 |
+
|
669 |
+
all_new_xyz = torch.cat(all_new_xyz, dim=0)
|
670 |
+
all_new_features_dc = torch.cat(all_new_features_dc, dim=0)
|
671 |
+
new_tmp_radii = torch.zeros(all_new_xyz.shape[0])
|
672 |
+
prune_mask = torch.ones(all_new_xyz.shape[0], dtype=torch.bool)
|
673 |
+
|
674 |
+
gaussians.densification_postfix(all_new_xyz[prune_mask].to(device),
|
675 |
+
all_new_features_dc[prune_mask].to(device),
|
676 |
+
torch.cat(all_new_features_rest, dim=0)[prune_mask].to(device),
|
677 |
+
torch.cat(all_new_opacities, dim=0)[prune_mask].to(device),
|
678 |
+
torch.cat(all_new_scaling, dim=0)[prune_mask].to(device),
|
679 |
+
torch.cat(all_new_rotation, dim=0)[prune_mask].to(device),
|
680 |
+
new_tmp_radii[prune_mask].to(device))
|
681 |
+
|
682 |
+
return viewpoint_stack, closest_indices_selected, visualizations
|
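# --- Illustrative sketch (not part of the original file): a hypothetical call site. The cfg
# fields mirror those read inside init_gaussians_with_corr (roma_model, matches_per_ref,
# num_refs, nns_per_ref, scaling_factor, proj_err_tolerance); the values below are placeholders,
# not the repo defaults, and gaussians/scene come from the surrounding 3DGS training setup.
from types import SimpleNamespace

cfg_init = SimpleNamespace(
    roma_model="outdoors",        # or "indoors" to use the roma_indoor weights
    matches_per_ref=15_000,       # M: correspondences sampled per reference frame
    num_refs=16,                  # reference frames selected by K-means over camera poses
    nns_per_ref=3,                # neighbors matched against each reference frame
    scaling_factor=0.001,         # scale assigned to the newly added splats
    proj_err_tolerance=0.01,      # reprojection error above which a point is made invisible
)

# viewpoint_stack, closest_idcs, vis = init_gaussians_with_corr(
#     gaussians, scene, cfg_init, device="cuda", verbose=False)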
source/corr_init_new.py
ADDED
@@ -0,0 +1,904 @@
1 |
+
import sys
|
2 |
+
sys.path.append('../')
|
3 |
+
sys.path.append("../submodules")
|
4 |
+
sys.path.append('../submodules/RoMa')
|
5 |
+
|
6 |
+
from matplotlib import pyplot as plt
|
7 |
+
from PIL import Image
|
8 |
+
import torch
|
9 |
+
import numpy as np
|
10 |
+
|
11 |
+
#from tqdm import tqdm_notebook as tqdm
|
12 |
+
from tqdm import tqdm
|
13 |
+
from scipy.cluster.vq import kmeans, vq
|
14 |
+
from scipy.spatial.distance import cdist
|
15 |
+
|
16 |
+
import torch.nn.functional as F
|
17 |
+
from romatch import roma_outdoor, roma_indoor
|
18 |
+
from utils.sh_utils import RGB2SH
|
19 |
+
from romatch.utils import get_tuple_transform_ops
|
20 |
+
|
21 |
+
|
22 |
+
def pairwise_distances(matrix):
|
23 |
+
"""
|
24 |
+
Computes the pairwise Euclidean distances between all vectors in the input matrix.
|
25 |
+
|
26 |
+
Args:
|
27 |
+
matrix (torch.Tensor): Input matrix of shape [N, D], where N is the number of vectors and D is the dimensionality.
|
28 |
+
|
29 |
+
Returns:
|
30 |
+
torch.Tensor: Pairwise distance matrix of shape [N, N].
|
31 |
+
"""
|
32 |
+
# Compute squared pairwise distances
|
33 |
+
squared_diff = torch.cdist(matrix, matrix, p=2)
|
34 |
+
return squared_diff
|
35 |
+
|
36 |
+
|
37 |
+
def k_closest_vectors(matrix, k):
|
38 |
+
"""
|
39 |
+
Finds the k-closest vectors for each vector in the input matrix based on Euclidean distance.
|
40 |
+
|
41 |
+
Args:
|
42 |
+
matrix (torch.Tensor): Input matrix of shape [N, D], where N is the number of vectors and D is the dimensionality.
|
43 |
+
k (int): Number of closest vectors to return for each vector.
|
44 |
+
|
45 |
+
Returns:
|
46 |
+
torch.Tensor: Indices of the k-closest vectors for each vector, excluding the vector itself.
|
47 |
+
"""
|
48 |
+
# Compute pairwise distances
|
49 |
+
distances = pairwise_distances(matrix)
|
50 |
+
|
51 |
+
# For each vector, sort distances and get the indices of the k-closest vectors (excluding itself)
|
52 |
+
# Set diagonal distances to infinity to exclude the vector itself from the nearest neighbors
|
53 |
+
distances.fill_diagonal_(float('inf'))
|
54 |
+
|
55 |
+
# Get the indices of the k smallest distances (k-closest vectors)
|
56 |
+
_, indices = torch.topk(distances, k, largest=False, dim=1)
|
57 |
+
|
58 |
+
return indices
|
59 |
+
|
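# --- Illustrative sketch (not part of the original file): how k_closest_vectors is typically
# fed in this file — one flattened 4x4 world-to-view matrix per camera (random placeholders here).
import torch

cam_feats = torch.rand(10, 16)                 # stand-in for stacked world_view_transform.flatten()
nn_idx = k_closest_vectors(cam_feats, k=3)     # (10, 3) indices of the 3 nearest cameras
print(nn_idx.shape)                            # torch.Size([10, 3])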
60 |
+
|
61 |
+
def select_cameras_kmeans(cameras, K):
|
62 |
+
"""
|
63 |
+
Selects K cameras from a set using K-means clustering.
|
64 |
+
|
65 |
+
Args:
|
66 |
+
cameras: NumPy array of shape (N, 16), representing N cameras with their 4x4 homogeneous matrices flattened.
|
67 |
+
K: Number of clusters (cameras to select).
|
68 |
+
|
69 |
+
Returns:
|
70 |
+
selected_indices: List of indices of the cameras closest to the cluster centers.
|
71 |
+
"""
|
72 |
+
# Ensure input is a NumPy array
|
73 |
+
if not isinstance(cameras, np.ndarray):
|
74 |
+
cameras = np.asarray(cameras)
|
75 |
+
|
76 |
+
if cameras.shape[1] != 16:
|
77 |
+
raise ValueError("Each camera must have 16 values corresponding to a flattened 4x4 matrix.")
|
78 |
+
|
79 |
+
# Perform K-means clustering
|
80 |
+
cluster_centers, _ = kmeans(cameras, K)
|
81 |
+
|
82 |
+
# Assign each camera to a cluster and find distances to cluster centers
|
83 |
+
cluster_assignments, _ = vq(cameras, cluster_centers)
|
84 |
+
|
85 |
+
# Find the camera nearest to each cluster center
|
86 |
+
selected_indices = []
|
87 |
+
for k in range(K):
|
88 |
+
cluster_members = cameras[cluster_assignments == k]
|
89 |
+
distances = cdist([cluster_centers[k]], cluster_members)[0]
|
90 |
+
nearest_camera_idx = np.where(cluster_assignments == k)[0][np.argmin(distances)]
|
91 |
+
selected_indices.append(nearest_camera_idx)
|
92 |
+
|
93 |
+
return selected_indices
|
94 |
+
|
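# --- Illustrative sketch (not part of the original file): selecting K representative cameras
# with the helper above. Rows are flattened 4x4 camera matrices; values are random placeholders.
import numpy as np

cams = np.random.rand(50, 16)                  # 50 cameras
ref_idx = select_cameras_kmeans(cams, K=8)     # indices of the 8 cameras closest to cluster centres
print(sorted(ref_idx))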
95 |
+
|
96 |
+
def compute_warp_and_confidence(viewpoint_cam1, viewpoint_cam2, roma_model, device="cuda", verbose=False, output_dict={}):
|
97 |
+
"""
|
98 |
+
Computes the warp and confidence between two viewpoint cameras using the roma_model.
|
99 |
+
|
100 |
+
Args:
|
101 |
+
viewpoint_cam1: Source viewpoint camera.
|
102 |
+
viewpoint_cam2: Target viewpoint camera.
|
103 |
+
roma_model: Pre-trained Roma model for correspondence matching.
|
104 |
+
device: Device to run the computation on.
|
105 |
+
verbose: If True, displays the images.
|
106 |
+
|
107 |
+
Returns:
|
108 |
+
certainty: Confidence tensor.
|
109 |
+
warp: Warp tensor.
|
110 |
+
imB: Processed image B as numpy array.
|
111 |
+
"""
|
112 |
+
# Prepare images
|
113 |
+
imA = viewpoint_cam1.original_image.detach().cpu().numpy().transpose(1, 2, 0)
|
114 |
+
imB = viewpoint_cam2.original_image.detach().cpu().numpy().transpose(1, 2, 0)
|
115 |
+
imA = Image.fromarray(np.clip(imA * 255, 0, 255).astype(np.uint8))
|
116 |
+
imB = Image.fromarray(np.clip(imB * 255, 0, 255).astype(np.uint8))
|
117 |
+
|
118 |
+
if verbose:
|
119 |
+
fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(16, 8))
|
120 |
+
cax1 = ax[0].imshow(imA)
|
121 |
+
ax[0].set_title("Image 1")
|
122 |
+
cax2 = ax[1].imshow(imB)
|
123 |
+
ax[1].set_title("Image 2")
|
124 |
+
fig.colorbar(cax1, ax=ax[0])
|
125 |
+
fig.colorbar(cax2, ax=ax[1])
|
126 |
+
|
127 |
+
for axis in ax:
|
128 |
+
axis.axis('off')
|
129 |
+
# Save the figure into the dictionary
|
130 |
+
output_dict[f'image_pair'] = fig
|
131 |
+
|
132 |
+
# Transform images
|
133 |
+
ws, hs = roma_model.w_resized, roma_model.h_resized
|
134 |
+
test_transform = get_tuple_transform_ops(resize=(hs, ws), normalize=True)
|
135 |
+
im_A, im_B = test_transform((imA, imB))
|
136 |
+
batch = {"im_A": im_A[None].to(device), "im_B": im_B[None].to(device)}
|
137 |
+
|
138 |
+
# Forward pass through Roma model
|
139 |
+
corresps = roma_model.forward(batch) if not roma_model.symmetric else roma_model.forward_symmetric(batch)
|
140 |
+
finest_scale = 1
|
141 |
+
hs, ws = roma_model.upsample_res if roma_model.upsample_preds else (hs, ws)
|
142 |
+
|
143 |
+
# Process certainty and warp
|
144 |
+
certainty = corresps[finest_scale]["certainty"]
|
145 |
+
im_A_to_im_B = corresps[finest_scale]["flow"]
|
146 |
+
if roma_model.attenuate_cert:
|
147 |
+
low_res_certainty = F.interpolate(
|
148 |
+
corresps[16]["certainty"], size=(hs, ws), align_corners=False, mode="bilinear"
|
149 |
+
)
|
150 |
+
certainty -= 0.5 * low_res_certainty * (low_res_certainty < 0)
|
151 |
+
|
152 |
+
# Upsample predictions if needed
|
153 |
+
if roma_model.upsample_preds:
|
154 |
+
im_A_to_im_B = F.interpolate(
|
155 |
+
im_A_to_im_B, size=(hs, ws), align_corners=False, mode="bilinear"
|
156 |
+
)
|
157 |
+
certainty = F.interpolate(
|
158 |
+
certainty, size=(hs, ws), align_corners=False, mode="bilinear"
|
159 |
+
)
|
160 |
+
|
161 |
+
# Convert predictions to final format
|
162 |
+
im_A_to_im_B = im_A_to_im_B.permute(0, 2, 3, 1)
|
163 |
+
im_A_coords = torch.stack(torch.meshgrid(
|
164 |
+
torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device=device),
|
165 |
+
torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device=device),
|
166 |
+
indexing='ij'
|
167 |
+
), dim=0).permute(1, 2, 0).unsqueeze(0).expand(im_A_to_im_B.size(0), -1, -1, -1)
|
168 |
+
|
169 |
+
warp = torch.cat((im_A_coords, im_A_to_im_B), dim=-1)
|
170 |
+
certainty = certainty.sigmoid()
|
171 |
+
|
172 |
+
return certainty[0, 0], warp[0], np.array(imB)
|
173 |
+
|
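# --- Illustrative sketch (not part of the original file): layout of the returned tensors.
# warp has shape (H, W, 4): channels 0:2 are the im_A pixel grid in normalized [-1, 1]
# coordinates, channels 2:4 are the matched positions in im_B (also [-1, 1]); certainty is (H, W).
# A typical way to consume them is to threshold by certainty and split the channels:
import torch

H, W = 560, 560                                 # RoMa's working resolution (assumed placeholder)
warp = torch.rand(H, W, 4) * 2 - 1              # stand-in for the returned warp
certainty = torch.rand(H, W)                    # stand-in for the returned certainty

keep = certainty > 0.5                          # keep reasonably confident pixels
coords_A = warp[keep][:, :2]                    # normalized coordinates in im_A
coords_B = warp[keep][:, 2:]                    # matched normalized coordinates in im_B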
174 |
+
|
175 |
+
def resize_batch(tensors_3d, tensors_4d, target_shape):
|
176 |
+
"""
|
177 |
+
Resizes a batch of tensors with shapes [B, H, W] and [B, H, W, 4] to the target spatial dimensions.
|
178 |
+
|
179 |
+
Args:
|
180 |
+
tensors_3d: Tensor of shape [B, H, W].
|
181 |
+
tensors_4d: Tensor of shape [B, H, W, 4].
|
182 |
+
target_shape: Tuple (target_H, target_W) specifying the target spatial dimensions.
|
183 |
+
|
184 |
+
Returns:
|
185 |
+
resized_tensors_3d: Tensor of shape [B, target_H, target_W].
|
186 |
+
resized_tensors_4d: Tensor of shape [B, target_H, target_W, 4].
|
187 |
+
"""
|
188 |
+
target_H, target_W = target_shape
|
189 |
+
|
190 |
+
# Resize [B, H, W] tensor
|
191 |
+
resized_tensors_3d = F.interpolate(
|
192 |
+
tensors_3d.unsqueeze(1), size=(target_H, target_W), mode="bilinear", align_corners=False
|
193 |
+
).squeeze(1)
|
194 |
+
|
195 |
+
# Resize [B, H, W, 4] tensor
|
196 |
+
B, _, _, C = tensors_4d.shape
|
197 |
+
resized_tensors_4d = F.interpolate(
|
198 |
+
tensors_4d.permute(0, 3, 1, 2), size=(target_H, target_W), mode="bilinear", align_corners=False
|
199 |
+
).permute(0, 2, 3, 1)
|
200 |
+
|
201 |
+
return resized_tensors_3d, resized_tensors_4d
|
202 |
+
|
203 |
+
|
204 |
+
def aggregate_confidences_and_warps(viewpoint_stack, closest_indices, roma_model, source_idx, verbose=False, output_dict={}):
|
205 |
+
"""
|
206 |
+
Aggregates confidences and warps by iterating over the nearest neighbors of the source viewpoint.
|
207 |
+
|
208 |
+
Args:
|
209 |
+
viewpoint_stack: Stack of viewpoint cameras.
|
210 |
+
closest_indices: Indices of the nearest neighbors for each viewpoint.
|
211 |
+
roma_model: Pre-trained Roma model.
|
212 |
+
source_idx: Index of the source viewpoint.
|
213 |
+
verbose: If True, displays intermediate results.
|
214 |
+
|
215 |
+
Returns:
|
216 |
+
certainties_max: Aggregated maximum confidences.
|
217 |
+
warps_max: Aggregated warps corresponding to maximum confidences.
|
218 |
+
certainties_max_idcs: Pixel-wise index of the image from which the best match was taken.
|
219 |
+
imB_compound: List of the neighboring images.
|
220 |
+
"""
|
221 |
+
certainties_all, warps_all, imB_compound = [], [], []
|
222 |
+
|
223 |
+
for nn in tqdm(closest_indices[source_idx]):
|
224 |
+
|
225 |
+
viewpoint_cam1 = viewpoint_stack[source_idx]
|
226 |
+
viewpoint_cam2 = viewpoint_stack[nn]
|
227 |
+
|
228 |
+
certainty, warp, imB = compute_warp_and_confidence(viewpoint_cam1, viewpoint_cam2, roma_model, verbose=verbose, output_dict=output_dict)
|
229 |
+
certainties_all.append(certainty)
|
230 |
+
warps_all.append(warp)
|
231 |
+
imB_compound.append(imB)
|
232 |
+
|
233 |
+
certainties_all = torch.stack(certainties_all, dim=0)
|
234 |
+
target_shape = imB_compound[0].shape[:2]
|
235 |
+
if verbose:
|
236 |
+
print("certainties_all.shape:", certainties_all.shape)
|
237 |
+
print("torch.stack(warps_all, dim=0).shape:", torch.stack(warps_all, dim=0).shape)
|
238 |
+
print("target_shape:", target_shape)
|
239 |
+
|
240 |
+
certainties_all_resized, warps_all_resized = resize_batch(certainties_all,
|
241 |
+
torch.stack(warps_all, dim=0),
|
242 |
+
target_shape
|
243 |
+
)
|
244 |
+
|
245 |
+
if verbose:
|
246 |
+
print("warps_all_resized.shape:", warps_all_resized.shape)
|
247 |
+
for n, cert in enumerate(certainties_all):
|
248 |
+
fig, ax = plt.subplots()
|
249 |
+
cax = ax.imshow(cert.cpu().numpy(), cmap='viridis')
|
250 |
+
fig.colorbar(cax, ax=ax)
|
251 |
+
ax.set_title("Pixel-wise Confidence")
|
252 |
+
output_dict[f'certainty_{n}'] = fig
|
253 |
+
|
254 |
+
for n, warp in enumerate(warps_all):
|
255 |
+
fig, ax = plt.subplots()
|
256 |
+
cax = ax.imshow(warp.cpu().numpy()[:, :, :3], cmap='viridis')
|
257 |
+
fig.colorbar(cax, ax=ax)
|
258 |
+
ax.set_title("Pixel-wise warp")
|
259 |
+
output_dict[f'warp_resized_{n}'] = fig
|
260 |
+
|
261 |
+
for n, cert in enumerate(certainties_all_resized):
|
262 |
+
fig, ax = plt.subplots()
|
263 |
+
cax = ax.imshow(cert.cpu().numpy(), cmap='viridis')
|
264 |
+
fig.colorbar(cax, ax=ax)
|
265 |
+
ax.set_title("Pixel-wise Confidence resized")
|
266 |
+
output_dict[f'certainty_resized_{n}'] = fig
|
267 |
+
|
268 |
+
for n, warp in enumerate(warps_all_resized):
|
269 |
+
fig, ax = plt.subplots()
|
270 |
+
cax = ax.imshow(warp.cpu().numpy()[:, :, :3], cmap='viridis')
|
271 |
+
fig.colorbar(cax, ax=ax)
|
272 |
+
ax.set_title("Pixel-wise warp resized")
|
273 |
+
output_dict[f'warp_resized_{n}'] = fig
|
274 |
+
|
275 |
+
certainties_max, certainties_max_idcs = torch.max(certainties_all_resized, dim=0)
|
276 |
+
H, W = certainties_max.shape
|
277 |
+
|
278 |
+
warps_max = warps_all_resized[certainties_max_idcs, torch.arange(H).unsqueeze(1), torch.arange(W)]
|
279 |
+
|
280 |
+
|
281 |
+
return certainties_max, warps_max, certainties_max_idcs, imB_compound, certainties_all_resized, warps_all_resized
|
282 |
+
|
283 |
+
|
284 |
+
|
285 |
+
def extract_keypoints_and_colors(imA, imB_compound, certainties_max, certainties_max_idcs, matches, roma_model,
|
286 |
+
verbose=False, output_dict={}):
|
287 |
+
"""
|
288 |
+
Extracts keypoints and corresponding colors from the source image (imA) and multiple target images (imB_compound).
|
289 |
+
|
290 |
+
Args:
|
291 |
+
imA: Source image as a NumPy array (H_A, W_A, C).
|
292 |
+
imB_compound: List of target images as NumPy arrays [(H_B, W_B, C), ...].
|
293 |
+
certainties_max: Tensor of pixel-wise maximum confidences.
|
294 |
+
certainties_max_idcs: Tensor of pixel-wise indices for the best matches.
|
295 |
+
matches: Matches in normalized coordinates.
|
296 |
+
roma_model: Roma model instance for keypoint operations.
|
297 |
+
verbose: if True, show intermediate outputs and visualize the results.
|
298 |
+
|
299 |
+
Returns:
|
300 |
+
kptsA_np: Keypoints in imA in normalized coordinates.
|
301 |
+
kptsB_np: Keypoints in imB in normalized coordinates.
kptsB_proj_matrices_idx: Per-keypoint index of the neighboring image the best match was taken from.
|
302 |
+
kptsA_color: Colors of keypoints in imA.
|
303 |
+
kptsB_color: Colors of keypoints in imB based on certainties_max_idcs.
|
304 |
+
"""
|
305 |
+
H_A, W_A, _ = imA.shape
|
306 |
+
H, W = certainties_max.shape
|
307 |
+
|
308 |
+
# Convert matches to pixel coordinates
|
309 |
+
kptsA, kptsB = roma_model.to_pixel_coordinates(
|
310 |
+
matches, W_A, H_A, H, W # W, H
|
311 |
+
)
|
312 |
+
|
313 |
+
kptsA_np = kptsA.detach().cpu().numpy()
|
314 |
+
kptsB_np = kptsB.detach().cpu().numpy()
|
315 |
+
kptsA_np = kptsA_np[:, [1, 0]]
|
316 |
+
|
317 |
+
if verbose:
|
318 |
+
fig, ax = plt.subplots(figsize=(12, 6))
|
319 |
+
cax = ax.imshow(imA)
|
320 |
+
ax.set_title("Reference image, imA")
|
321 |
+
output_dict[f'reference_image'] = fig
|
322 |
+
|
323 |
+
fig, ax = plt.subplots(figsize=(12, 6))
|
324 |
+
cax = ax.imshow(imB_compound[0])
|
325 |
+
ax.set_title("Image to compare to image, imB_compound")
|
326 |
+
output_dict[f'imB_compound'] = fig
|
327 |
+
|
328 |
+
fig, ax = plt.subplots(figsize=(12, 6))
|
329 |
+
cax = ax.imshow(np.flipud(imA))
|
330 |
+
cax = ax.scatter(kptsA_np[:, 0], H_A - kptsA_np[:, 1], s=.03)
|
331 |
+
ax.set_title("Keypoints in imA")
|
332 |
+
ax.set_xlim(0, W_A)
|
333 |
+
ax.set_ylim(0, H_A)
|
334 |
+
output_dict[f'kptsA'] = fig
|
335 |
+
|
336 |
+
fig, ax = plt.subplots(figsize=(12, 6))
|
337 |
+
cax = ax.imshow(np.flipud(imB_compound[0]))
|
338 |
+
cax = ax.scatter(kptsB_np[:, 0], H_A - kptsB_np[:, 1], s=.03)
|
339 |
+
ax.set_title("Keypoints in imB")
|
340 |
+
ax.set_xlim(0, W_A)
|
341 |
+
ax.set_ylim(0, H_A)
|
342 |
+
output_dict[f'kptsB'] = fig
|
343 |
+
|
344 |
+
# Keypoints are in (row, column) format, so the first value is always in range [0; height] and the second is in range [0; width]
|
345 |
+
|
346 |
+
kptsA_np = kptsA.detach().cpu().numpy()
|
347 |
+
kptsB_np = kptsB.detach().cpu().numpy()
|
348 |
+
|
349 |
+
# Extract colors for keypoints in imA (vectorized)
|
350 |
+
# New experimental version
|
351 |
+
kptsA_x = np.round(kptsA_np[:, 0] / 1.).astype(int)
|
352 |
+
kptsA_y = np.round(kptsA_np[:, 1] / 1.).astype(int)
|
353 |
+
kptsA_color = imA[np.clip(kptsA_x, 0, H - 1), np.clip(kptsA_y, 0, W - 1)]
|
354 |
+
|
355 |
+
# Create a composite image from imB_compound
|
356 |
+
imB_compound_np = np.stack(imB_compound, axis=0)
|
357 |
+
H_B, W_B, _ = imB_compound[0].shape
|
358 |
+
|
359 |
+
# Extract colors for keypoints in imB using certainties_max_idcs
|
360 |
+
imB_np = imB_compound_np[
|
361 |
+
certainties_max_idcs.detach().cpu().numpy(),
|
362 |
+
np.arange(H).reshape(-1, 1),
|
363 |
+
np.arange(W)
|
364 |
+
]
|
365 |
+
|
366 |
+
if verbose:
|
367 |
+
print("imB_np.shape:", imB_np.shape)
|
368 |
+
print("imB_np:", imB_np)
|
369 |
+
fig, ax = plt.subplots(figsize=(12, 6))
|
370 |
+
cax = ax.imshow(np.flipud(imB_np))
|
371 |
+
cax = ax.scatter(kptsB_np[:, 0], H_A - kptsB_np[:, 1], s=.03)
|
372 |
+
ax.set_title("np.flipud(imB_np[0]")
|
373 |
+
ax.set_xlim(0, W_A)
|
374 |
+
ax.set_ylim(0, H_A)
|
375 |
+
output_dict[f'np.flipud(imB_np[0]'] = fig
|
376 |
+
|
377 |
+
|
378 |
+
kptsB_x = np.round(kptsB_np[:, 0]).astype(int)
|
379 |
+
kptsB_y = np.round(kptsB_np[:, 1]).astype(int)
|
380 |
+
|
381 |
+
certainties_max_idcs_np = certainties_max_idcs.detach().cpu().numpy()
|
382 |
+
kptsB_proj_matrices_idx = certainties_max_idcs_np[np.clip(kptsA_x, 0, H - 1), np.clip(kptsA_y, 0, W - 1)]
|
383 |
+
kptsB_color = imB_compound_np[kptsB_proj_matrices_idx, np.clip(kptsB_y, 0, H - 1), np.clip(kptsB_x, 0, W - 1)]
|
384 |
+
|
385 |
+
# Normalize keypoints in both images
|
386 |
+
kptsA_np[:, 0] = kptsA_np[:, 0] / H * 2.0 - 1.0
|
387 |
+
kptsA_np[:, 1] = kptsA_np[:, 1] / W * 2.0 - 1.0
|
388 |
+
kptsB_np[:, 0] = kptsB_np[:, 0] / W_B * 2.0 - 1.0
|
389 |
+
kptsB_np[:, 1] = kptsB_np[:, 1] / H_B * 2.0 - 1.0
|
390 |
+
|
391 |
+
return kptsA_np[:, [1, 0]], kptsB_np, kptsB_proj_matrices_idx, kptsA_color, kptsB_color
|
392 |
+
|
393 |
+
def prepare_tensor(input_array, device):
|
394 |
+
"""
|
395 |
+
Converts an input array to a torch tensor, clones it, and detaches it for safe computation.
|
396 |
+
Args:
|
397 |
+
input_array (array-like): The input array to convert.
|
398 |
+
device (str or torch.device): The device to move the tensor to.
|
399 |
+
Returns:
|
400 |
+
torch.Tensor: A detached tensor clone of the input array on the specified device.
|
401 |
+
"""
|
402 |
+
if not isinstance(input_array, torch.Tensor):
|
403 |
+
return torch.tensor(input_array, dtype=torch.float32).to(device).clone().detach()
|
404 |
+
return input_array.clone().detach().to(device).to(torch.float32)
|
405 |
+
|
406 |
+
def triangulate_points(P1, P2, k1_x, k1_y, k2_x, k2_y, device="cuda"):
|
407 |
+
"""
|
408 |
+
Solves for a batch of 3D points given batches of projection matrices and corresponding image points.
|
409 |
+
|
410 |
+
Parameters:
|
411 |
+
- P1, P2: Tensors of projection matrices of size (batch_size, 4, 4) or (4, 4)
|
412 |
+
- k1_x, k1_y: Tensors of shape (batch_size,)
|
413 |
+
- k2_x, k2_y: Tensors of shape (batch_size,)
|
414 |
+
|
415 |
+
Returns:
|
416 |
+
- X: A tensor containing the 3D homogeneous coordinates, shape (batch_size, 4)
|
417 |
+
"""
|
418 |
+
EPS = 1e-4
|
419 |
+
# Ensure inputs are tensors
|
420 |
+
|
421 |
+
P1 = prepare_tensor(P1, device)
|
422 |
+
P2 = prepare_tensor(P2, device)
|
423 |
+
k1_x = prepare_tensor(k1_x, device)
|
424 |
+
k1_y = prepare_tensor(k1_y, device)
|
425 |
+
k2_x = prepare_tensor(k2_x, device)
|
426 |
+
k2_y = prepare_tensor(k2_y, device)
|
427 |
+
batch_size = k1_x.shape[0]
|
428 |
+
|
429 |
+
# Expand P1 and P2 if they are not batched
|
430 |
+
if P1.ndim == 2:
|
431 |
+
P1 = P1.unsqueeze(0).expand(batch_size, -1, -1)
|
432 |
+
if P2.ndim == 2:
|
433 |
+
P2 = P2.unsqueeze(0).expand(batch_size, -1, -1)
|
434 |
+
|
435 |
+
# Extract columns from P1 and P2
|
436 |
+
P1_0 = P1[:, :, 0] # Shape: (batch_size, 4)
|
437 |
+
P1_1 = P1[:, :, 1]
|
438 |
+
P1_2 = P1[:, :, 2]
|
439 |
+
|
440 |
+
P2_0 = P2[:, :, 0]
|
441 |
+
P2_1 = P2[:, :, 1]
|
442 |
+
P2_2 = P2[:, :, 2]
|
443 |
+
|
444 |
+
# Reshape kx and ky to (batch_size, 1)
|
445 |
+
k1_x = k1_x.view(-1, 1)
|
446 |
+
k1_y = k1_y.view(-1, 1)
|
447 |
+
k2_x = k2_x.view(-1, 1)
|
448 |
+
k2_y = k2_y.view(-1, 1)
|
449 |
+
|
450 |
+
# Construct the equations for each batch
|
451 |
+
# For camera 1
|
452 |
+
A1 = P1_0 - k1_x * P1_2 # Shape: (batch_size, 4)
|
453 |
+
A2 = P1_1 - k1_y * P1_2
|
454 |
+
# For camera 2
|
455 |
+
A3 = P2_0 - k2_x * P2_2
|
456 |
+
A4 = P2_1 - k2_y * P2_2
|
457 |
+
|
458 |
+
# Stack the equations
|
459 |
+
A = torch.stack([A1, A2, A3, A4], dim=1) # Shape: (batch_size, 4, 4)
|
460 |
+
|
461 |
+
# Right-hand side (constants)
|
462 |
+
b = -A[:, :, 3] # Shape: (batch_size, 4)
|
463 |
+
A_reduced = A[:, :, :3] # Coefficients of x, y, z
|
464 |
+
|
465 |
+
# Solve using torch.linalg.lstsq (supports batching)
|
466 |
+
X_xyz = torch.linalg.lstsq(A_reduced, b.unsqueeze(2)).solution.squeeze(2) # Shape: (batch_size, 3)
|
467 |
+
|
468 |
+
# Append 1 to get homogeneous coordinates
|
469 |
+
ones = torch.ones((batch_size, 1), dtype=torch.float32, device=X_xyz.device)
|
470 |
+
X = torch.cat([X_xyz, ones], dim=1) # Shape: (batch_size, 4)
|
471 |
+
|
472 |
+
# Now compute the errors of projections.
|
473 |
+
seeked_splats_proj1 = (X.unsqueeze(1) @ P1).squeeze(1)
|
474 |
+
seeked_splats_proj1 = seeked_splats_proj1 / (EPS + seeked_splats_proj1[:, [3]])
|
475 |
+
seeked_splats_proj2 = (X.unsqueeze(1) @ P2).squeeze(1)
|
476 |
+
seeked_splats_proj2 = seeked_splats_proj2 / (EPS + seeked_splats_proj2[:, [3]])
|
477 |
+
proj1_target = torch.concat([k1_x, k1_y], dim=1)
|
478 |
+
proj2_target = torch.concat([k2_x, k2_y], dim=1)
|
479 |
+
errors_proj1 = torch.abs(seeked_splats_proj1[:, :2] - proj1_target).sum(1).detach().cpu().numpy()
|
480 |
+
errors_proj2 = torch.abs(seeked_splats_proj2[:, :2] - proj2_target).sum(1).detach().cpu().numpy()
|
481 |
+
|
482 |
+
return X, errors_proj1, errors_proj2
|
483 |
+
|
484 |
+
|
485 |
+
|
486 |
+
def select_best_keypoints(
|
487 |
+
NNs_triangulated_points, NNs_errors_proj1, NNs_errors_proj2, device="cuda"):
|
488 |
+
"""
|
489 |
+
From all the points fitted to keypoints and corresponding colors from the source image (imA) and multiple target images (imB_compound).
|
490 |
+
|
491 |
+
Args:
|
492 |
+
NNs_triangulated_points: torch tensor with keypoints coordinates (num_nns, num_points, dim). dim can be arbitrary,
|
493 |
+
usually 3 or 4(for homogeneous representation).
|
494 |
+
NNs_errors_proj1: numpy array with projection error of the estimated keypoint on the reference frame (num_nns, num_points).
|
495 |
+
NNs_errors_proj2: numpy array with projection error of the estimated keypoint on the neighbor frame (num_nns, num_points).
|
496 |
+
Returns:
|
497 |
+
selected_keypoints: keypoints with the best score.
|
498 |
+
"""
|
499 |
+
|
500 |
+
NNs_errors_proj = np.maximum(NNs_errors_proj1, NNs_errors_proj2)
|
501 |
+
|
502 |
+
# Convert indices to PyTorch tensor
|
503 |
+
indices = torch.from_numpy(np.argmin(NNs_errors_proj, axis=0)).long().to(device)
|
504 |
+
|
505 |
+
# Create index tensor for the second dimension
|
506 |
+
n_indices = torch.arange(NNs_triangulated_points.shape[1]).long().to(device)
|
507 |
+
|
508 |
+
# Use advanced indexing to select elements
|
509 |
+
NNs_triangulated_points_selected = NNs_triangulated_points[indices, n_indices, :] # Shape: [N, k]
|
510 |
+
|
511 |
+
return NNs_triangulated_points_selected, np.min(NNs_errors_proj, axis=0)
|
512 |
+
|
513 |
+
|
514 |
+
|
515 |
+
import time
|
516 |
+
from collections import defaultdict
|
517 |
+
from tqdm import tqdm
|
518 |
+
|
519 |
+
# def init_gaussians_with_corr_profiled(gaussians, scene, cfg, device, verbose=False, roma_model=None):
|
520 |
+
# timings = defaultdict(list) # To accumulate timings
|
521 |
+
|
522 |
+
# if roma_model is None:
|
523 |
+
# if cfg.roma_model == "indoors":
|
524 |
+
# roma_model = roma_indoor(device=device)
|
525 |
+
# else:
|
526 |
+
# roma_model = roma_outdoor(device=device)
|
527 |
+
# roma_model.upsample_preds = False
|
528 |
+
# roma_model.symmetric = False
|
529 |
+
|
530 |
+
# M = cfg.matches_per_ref
|
531 |
+
# upper_thresh = roma_model.sample_thresh
|
532 |
+
# scaling_factor = cfg.scaling_factor
|
533 |
+
# expansion_factor = 1
|
534 |
+
# keypoint_fit_error_tolerance = cfg.proj_err_tolerance
|
535 |
+
# visualizations = {}
|
536 |
+
# viewpoint_stack = scene.getTrainCameras().copy()
|
537 |
+
# NUM_REFERENCE_FRAMES = min(cfg.num_refs, len(viewpoint_stack))
|
538 |
+
# NUM_NNS_PER_REFERENCE = min(cfg.nns_per_ref, len(viewpoint_stack))
|
539 |
+
|
540 |
+
# viewpoint_cam_all = torch.stack([x.world_view_transform.flatten() for x in viewpoint_stack], axis=0)
|
541 |
+
|
542 |
+
# selected_indices = select_cameras_kmeans(cameras=viewpoint_cam_all.detach().cpu().numpy(), K=NUM_REFERENCE_FRAMES)
|
543 |
+
# selected_indices = sorted(selected_indices)
|
544 |
+
|
545 |
+
# viewpoint_cam_all = torch.stack([x.world_view_transform.flatten() for x in viewpoint_stack], axis=0)
|
546 |
+
# closest_indices = k_closest_vectors(viewpoint_cam_all, NUM_NNS_PER_REFERENCE)
|
547 |
+
# closest_indices_selected = closest_indices[:, :].detach().cpu().numpy()
|
548 |
+
|
549 |
+
# all_new_xyz = []
|
550 |
+
# all_new_features_dc = []
|
551 |
+
# all_new_features_rest = []
|
552 |
+
# all_new_opacities = []
|
553 |
+
# all_new_scaling = []
|
554 |
+
# all_new_rotation = []
|
555 |
+
|
556 |
+
# # Dummy first pass to initialize model
|
557 |
+
# with torch.no_grad():
|
558 |
+
# viewpoint_cam1 = viewpoint_stack[0]
|
559 |
+
# viewpoint_cam2 = viewpoint_stack[1]
|
560 |
+
# imA = viewpoint_cam1.original_image.detach().cpu().numpy().transpose(1, 2, 0)
|
561 |
+
# imB = viewpoint_cam2.original_image.detach().cpu().numpy().transpose(1, 2, 0)
|
562 |
+
# imA = Image.fromarray(np.clip(imA * 255, 0, 255).astype(np.uint8))
|
563 |
+
# imB = Image.fromarray(np.clip(imB * 255, 0, 255).astype(np.uint8))
|
564 |
+
# warp, certainty_warp = roma_model.match(imA, imB, device=device)
|
565 |
+
# del warp, certainty_warp
|
566 |
+
# torch.cuda.empty_cache()
|
567 |
+
|
568 |
+
# # Main Loop over source_idx
|
569 |
+
# for source_idx in tqdm(sorted(selected_indices), desc="Profiling source frames"):
|
570 |
+
|
571 |
+
# # =================== Step 1: Aggregate Confidences and Warps ===================
|
572 |
+
# start = time.time()
|
573 |
+
# viewpoint_cam1 = viewpoint_stack[source_idx]
|
574 |
+
# viewpoint_cam2 = viewpoint_stack[closest_indices_selected[source_idx,0]]
|
575 |
+
# imA = viewpoint_cam1.original_image.detach().cpu().numpy().transpose(1, 2, 0)
|
576 |
+
# imB = viewpoint_cam2.original_image.detach().cpu().numpy().transpose(1, 2, 0)
|
577 |
+
# imA = Image.fromarray(np.clip(imA * 255, 0, 255).astype(np.uint8))
|
578 |
+
# imB = Image.fromarray(np.clip(imB * 255, 0, 255).astype(np.uint8))
|
579 |
+
# warp, certainty_warp = roma_model.match(imA, imB, device=device)
|
580 |
+
|
581 |
+
# certainties_max, warps_max, certainties_max_idcs, imB_compound, certainties_all, warps_all = aggregate_confidences_and_warps(
|
582 |
+
# viewpoint_stack=viewpoint_stack,
|
583 |
+
# closest_indices=closest_indices_selected,
|
584 |
+
# roma_model=roma_model,
|
585 |
+
# source_idx=source_idx,
|
586 |
+
# verbose=verbose,
|
587 |
+
# output_dict=visualizations
|
588 |
+
# )
|
589 |
+
|
590 |
+
# certainties_max = certainty_warp
|
591 |
+
# with torch.no_grad():
|
592 |
+
# warps_all = warps.unsqueeze(0)
|
593 |
+
|
594 |
+
# timings['aggregation_warp_certainty'].append(time.time() - start)
|
595 |
+
|
596 |
+
# # =================== Step 2: Good Samples Selection ===================
|
597 |
+
# start = time.time()
|
598 |
+
# certainty = certainties_max.reshape(-1).clone()
|
599 |
+
# certainty[certainty > upper_thresh] = 1
|
600 |
+
# good_samples = torch.multinomial(certainty, num_samples=min(expansion_factor * M, len(certainty)), replacement=False)
|
601 |
+
# timings['good_samples_selection'].append(time.time() - start)
|
602 |
+
|
603 |
+
# # =================== Step 3: Triangulate Keypoints for Each NN ===================
|
604 |
+
# reference_image_dict = {
|
605 |
+
# "triangulated_points": [],
|
606 |
+
# "triangulated_points_errors_proj1": [],
|
607 |
+
# "triangulated_points_errors_proj2": []
|
608 |
+
# }
|
609 |
+
|
610 |
+
# start = time.time()
|
611 |
+
# for NN_idx in range(len(warps_all)):
|
612 |
+
# matches_NN = warps_all[NN_idx].reshape(-1, 4)[good_samples]
|
613 |
+
|
614 |
+
# # Extract keypoints and colors
|
615 |
+
# kptsA_np, kptsB_np, kptsB_proj_matrices_idcs, kptsA_color, kptsB_color = extract_keypoints_and_colors(
|
616 |
+
# imA, imB_compound, certainties_max, certainties_max_idcs, matches_NN, roma_model
|
617 |
+
# )
|
618 |
+
|
619 |
+
# proj_matrices_A = viewpoint_stack[source_idx].full_proj_transform
|
620 |
+
# proj_matrices_B = viewpoint_stack[closest_indices_selected[source_idx, NN_idx]].full_proj_transform
|
621 |
+
# triangulated_points, triangulated_points_errors_proj1, triangulated_points_errors_proj2 = triangulate_points(
|
622 |
+
# P1=torch.stack([proj_matrices_A] * M, axis=0),
|
623 |
+
# P2=torch.stack([proj_matrices_B] * M, axis=0),
|
624 |
+
# k1_x=kptsA_np[:M, 0], k1_y=kptsA_np[:M, 1],
|
625 |
+
# k2_x=kptsB_np[:M, 0], k2_y=kptsB_np[:M, 1])
|
626 |
+
|
627 |
+
# reference_image_dict["triangulated_points"].append(triangulated_points)
|
628 |
+
# reference_image_dict["triangulated_points_errors_proj1"].append(triangulated_points_errors_proj1)
|
629 |
+
# reference_image_dict["triangulated_points_errors_proj2"].append(triangulated_points_errors_proj2)
|
630 |
+
# timings['triangulation_per_NN'].append(time.time() - start)
|
631 |
+
|
632 |
+
# # =================== Step 4: Select Best Triangulated Points ===================
|
633 |
+
# start = time.time()
|
634 |
+
# NNs_triangulated_points_selected, NNs_triangulated_points_selected_proj_errors = select_best_keypoints(
|
635 |
+
# NNs_triangulated_points=torch.stack(reference_image_dict["triangulated_points"], dim=0),
|
636 |
+
# NNs_errors_proj1=np.stack(reference_image_dict["triangulated_points_errors_proj1"], axis=0),
|
637 |
+
# NNs_errors_proj2=np.stack(reference_image_dict["triangulated_points_errors_proj2"], axis=0))
|
638 |
+
# timings['select_best_keypoints'].append(time.time() - start)
|
639 |
+
|
640 |
+
# # =================== Step 5: Create New Gaussians ===================
|
641 |
+
# start = time.time()
|
642 |
+
# viewpoint_cam1 = viewpoint_stack[source_idx]
|
643 |
+
# N = len(NNs_triangulated_points_selected)
|
644 |
+
# new_xyz = NNs_triangulated_points_selected[:, :-1]
|
645 |
+
# all_new_xyz.append(new_xyz)
|
646 |
+
# all_new_features_dc.append(RGB2SH(torch.tensor(kptsA_color.astype(np.float32) / 255.)).unsqueeze(1))
|
647 |
+
# all_new_features_rest.append(torch.stack([gaussians._features_rest[-1].clone().detach() * 0.] * N, dim=0))
|
648 |
+
|
649 |
+
# mask_bad_points = torch.tensor(
|
650 |
+
# NNs_triangulated_points_selected_proj_errors > keypoint_fit_error_tolerance,
|
651 |
+
# dtype=torch.float32).unsqueeze(1).to(device)
|
652 |
+
|
653 |
+
# all_new_opacities.append(torch.stack([gaussians._opacity[-1].clone().detach()] * N, dim=0) * 0. - mask_bad_points * (1e1))
|
654 |
+
|
655 |
+
# dist_points_to_cam1 = torch.linalg.norm(viewpoint_cam1.camera_center.clone().detach() - new_xyz, dim=1, ord=2)
|
656 |
+
# all_new_scaling.append(gaussians.scaling_inverse_activation((dist_points_to_cam1 * scaling_factor).unsqueeze(1).repeat(1, 3)))
|
657 |
+
# all_new_rotation.append(torch.stack([gaussians._rotation[-1].clone().detach()] * N, dim=0))
|
658 |
+
# timings['save_gaussians'].append(time.time() - start)
|
659 |
+
|
660 |
+
# # =================== Final Densification Postfix ===================
|
661 |
+
# start = time.time()
|
662 |
+
# all_new_xyz = torch.cat(all_new_xyz, dim=0)
|
663 |
+
# all_new_features_dc = torch.cat(all_new_features_dc, dim=0)
|
664 |
+
# new_tmp_radii = torch.zeros(all_new_xyz.shape[0])
|
665 |
+
# prune_mask = torch.ones(all_new_xyz.shape[0], dtype=torch.bool)
|
666 |
+
|
667 |
+
# gaussians.densification_postfix(
|
668 |
+
# all_new_xyz[prune_mask].to(device),
|
669 |
+
# all_new_features_dc[prune_mask].to(device),
|
670 |
+
# torch.cat(all_new_features_rest, dim=0)[prune_mask].to(device),
|
671 |
+
# torch.cat(all_new_opacities, dim=0)[prune_mask].to(device),
|
672 |
+
# torch.cat(all_new_scaling, dim=0)[prune_mask].to(device),
|
673 |
+
# torch.cat(all_new_rotation, dim=0)[prune_mask].to(device),
|
674 |
+
# new_tmp_radii[prune_mask].to(device)
|
675 |
+
# )
|
676 |
+
# timings['final_densification_postfix'].append(time.time() - start)
|
677 |
+
|
678 |
+
# # =================== Print Profiling Results ===================
|
679 |
+
# print("\n=== Profiling Summary (average per frame) ===")
|
680 |
+
# for key, times in timings.items():
|
681 |
+
# print(f"{key:35s}: {sum(times) / len(times):.4f} sec (total {sum(times):.2f} sec)")
|
682 |
+
|
683 |
+
# return viewpoint_stack, closest_indices_selected, visualizations
|
684 |
+
|
685 |
+
|
686 |
+
|
687 |
+
def extract_keypoints_and_colors_single(imA, imB, matches, roma_model, verbose=False, output_dict={}):
|
688 |
+
"""
|
689 |
+
Extracts keypoints and corresponding colors from a source image (imA) and a single target image (imB).
|
690 |
+
|
691 |
+
Args:
|
692 |
+
imA: Source image as a NumPy array (H_A, W_A, C).
|
693 |
+
imB: Target image as a NumPy array (H_B, W_B, C).
|
694 |
+
matches: Matches in normalized coordinates (torch.Tensor).
|
695 |
+
roma_model: Roma model instance for keypoint operations.
|
696 |
+
verbose: If True, outputs intermediate visualizations.
|
697 |
+
Returns:
|
698 |
+
kptsA_np: Keypoints in imA (normalized).
|
699 |
+
kptsB_np: Keypoints in imB (normalized).
|
700 |
+
kptsA_color: Colors of keypoints in imA.
|
701 |
+
kptsB_color: Colors of keypoints in imB.
|
702 |
+
"""
|
703 |
+
H_A, W_A, _ = imA.shape
|
704 |
+
H_B, W_B, _ = imB.shape
|
705 |
+
|
706 |
+
# Convert matches to pixel coordinates
|
707 |
+
# Matches format: (B, 4) = (x1_norm, y1_norm, x2_norm, y2_norm)
|
708 |
+
kptsA = matches[:, :2] # [N, 2]
|
709 |
+
kptsB = matches[:, 2:] # [N, 2]
|
710 |
+
|
711 |
+
# Scale normalized coordinates [-1,1] to pixel coordinates
|
712 |
+
kptsA_pix = torch.zeros_like(kptsA)
|
713 |
+
kptsB_pix = torch.zeros_like(kptsB)
|
714 |
+
|
715 |
+
# Important! [Normalized to pixel space]
|
716 |
+
kptsA_pix[:, 0] = (kptsA[:, 0] + 1) * (W_A - 1) / 2
|
717 |
+
kptsA_pix[:, 1] = (kptsA[:, 1] + 1) * (H_A - 1) / 2
|
718 |
+
|
719 |
+
kptsB_pix[:, 0] = (kptsB[:, 0] + 1) * (W_B - 1) / 2
|
720 |
+
kptsB_pix[:, 1] = (kptsB[:, 1] + 1) * (H_B - 1) / 2
|
721 |
+
|
722 |
+
kptsA_np = kptsA_pix.detach().cpu().numpy()
|
723 |
+
kptsB_np = kptsB_pix.detach().cpu().numpy()
|
724 |
+
|
725 |
+
# Extract colors
|
726 |
+
kptsA_x = np.round(kptsA_np[:, 0]).astype(int)
|
727 |
+
kptsA_y = np.round(kptsA_np[:, 1]).astype(int)
|
728 |
+
kptsB_x = np.round(kptsB_np[:, 0]).astype(int)
|
729 |
+
kptsB_y = np.round(kptsB_np[:, 1]).astype(int)
|
730 |
+
|
731 |
+
kptsA_color = imA[np.clip(kptsA_y, 0, H_A-1), np.clip(kptsA_x, 0, W_A-1)]
|
732 |
+
kptsB_color = imB[np.clip(kptsB_y, 0, H_B-1), np.clip(kptsB_x, 0, W_B-1)]
|
733 |
+
|
734 |
+
# Normalize keypoints into [-1, 1] for downstream triangulation
|
735 |
+
kptsA_np_norm = np.zeros_like(kptsA_np)
|
736 |
+
kptsB_np_norm = np.zeros_like(kptsB_np)
|
737 |
+
|
738 |
+
kptsA_np_norm[:, 0] = kptsA_np[:, 0] / (W_A - 1) * 2.0 - 1.0
|
739 |
+
kptsA_np_norm[:, 1] = kptsA_np[:, 1] / (H_A - 1) * 2.0 - 1.0
|
740 |
+
|
741 |
+
kptsB_np_norm[:, 0] = kptsB_np[:, 0] / (W_B - 1) * 2.0 - 1.0
|
742 |
+
kptsB_np_norm[:, 1] = kptsB_np[:, 1] / (H_B - 1) * 2.0 - 1.0
|
743 |
+
|
744 |
+
return kptsA_np_norm, kptsB_np_norm, kptsA_color, kptsB_color
|
745 |
+
|
746 |
+
|
747 |
+
|
748 |
+
def init_gaussians_with_corr_profiled(gaussians, scene, cfg, device, verbose=False, roma_model=None):
|
749 |
+
timings = defaultdict(list)
|
750 |
+
|
751 |
+
if roma_model is None:
|
752 |
+
if cfg.roma_model == "indoors":
|
753 |
+
roma_model = roma_indoor(device=device)
|
754 |
+
else:
|
755 |
+
roma_model = roma_outdoor(device=device)
|
756 |
+
roma_model.upsample_preds = False
|
757 |
+
roma_model.symmetric = False
|
758 |
+
|
759 |
+
M = cfg.matches_per_ref
|
760 |
+
upper_thresh = roma_model.sample_thresh
|
761 |
+
scaling_factor = cfg.scaling_factor
|
762 |
+
expansion_factor = 1
|
763 |
+
keypoint_fit_error_tolerance = cfg.proj_err_tolerance
|
764 |
+
visualizations = {}
|
765 |
+
viewpoint_stack = scene.getTrainCameras().copy()
|
766 |
+
NUM_REFERENCE_FRAMES = min(cfg.num_refs, len(viewpoint_stack))
|
767 |
+
NUM_NNS_PER_REFERENCE = 1 # Only ONE neighbor now!
|
768 |
+
|
769 |
+
viewpoint_cam_all = torch.stack([x.world_view_transform.flatten() for x in viewpoint_stack], axis=0)
|
770 |
+
|
771 |
+
selected_indices = select_cameras_kmeans(cameras=viewpoint_cam_all.detach().cpu().numpy(), K=NUM_REFERENCE_FRAMES)
|
772 |
+
selected_indices = sorted(selected_indices)
|
773 |
+
|
774 |
+
viewpoint_cam_all = torch.stack([x.world_view_transform.flatten() for x in viewpoint_stack], axis=0)
|
775 |
+
closest_indices = k_closest_vectors(viewpoint_cam_all, NUM_NNS_PER_REFERENCE)
|
776 |
+
closest_indices_selected = closest_indices[:, :].detach().cpu().numpy()
|
777 |
+
|
778 |
+
all_new_xyz = []
|
779 |
+
all_new_features_dc = []
|
780 |
+
all_new_features_rest = []
|
781 |
+
all_new_opacities = []
|
782 |
+
all_new_scaling = []
|
783 |
+
all_new_rotation = []
|
784 |
+
|
785 |
+
# Dummy first pass to initialize model
|
786 |
+
with torch.no_grad():
|
787 |
+
viewpoint_cam1 = viewpoint_stack[0]
|
788 |
+
viewpoint_cam2 = viewpoint_stack[1]
|
789 |
+
imA = viewpoint_cam1.original_image.detach().cpu().numpy().transpose(1, 2, 0)
|
790 |
+
imB = viewpoint_cam2.original_image.detach().cpu().numpy().transpose(1, 2, 0)
|
791 |
+
imA = Image.fromarray(np.clip(imA * 255, 0, 255).astype(np.uint8))
|
792 |
+
imB = Image.fromarray(np.clip(imB * 255, 0, 255).astype(np.uint8))
|
793 |
+
warp, certainty_warp = roma_model.match(imA, imB, device=device)
|
794 |
+
del warp, certainty_warp
|
795 |
+
torch.cuda.empty_cache()
|
796 |
+
|
797 |
+
# Main Loop over source_idx
|
798 |
+
for source_idx in tqdm(sorted(selected_indices), desc="Profiling source frames"):
|
799 |
+
|
800 |
+
# =================== Step 1: Compute Warp and Certainty ===================
|
801 |
+
start = time.time()
|
802 |
+
viewpoint_cam1 = viewpoint_stack[source_idx]
|
803 |
+
NNs=closest_indices_selected.shape[1]
|
804 |
+
viewpoint_cam2 = viewpoint_stack[closest_indices_selected[source_idx, np.random.randint(NNs)]]
|
805 |
+
imA = viewpoint_cam1.original_image.detach().cpu().numpy().transpose(1, 2, 0)
|
806 |
+
imB = viewpoint_cam2.original_image.detach().cpu().numpy().transpose(1, 2, 0)
|
807 |
+
imA = Image.fromarray(np.clip(imA * 255, 0, 255).astype(np.uint8))
|
808 |
+
imB = Image.fromarray(np.clip(imB * 255, 0, 255).astype(np.uint8))
|
809 |
+
warp, certainty_warp = roma_model.match(imA, imB, device=device)
|
810 |
+
|
811 |
+
certainties_max = certainty_warp # New manual sampling
|
812 |
+
timings['aggregation_warp_certainty'].append(time.time() - start)
|
813 |
+
|
814 |
+
# =================== Step 2: Good Samples Selection ===================
|
815 |
+
start = time.time()
|
816 |
+
certainty = certainties_max.reshape(-1).clone()
|
817 |
+
certainty[certainty > upper_thresh] = 1
|
818 |
+
good_samples = torch.multinomial(certainty, num_samples=min(expansion_factor * M, len(certainty)), replacement=False)
|
819 |
+
timings['good_samples_selection'].append(time.time() - start)
|
820 |
+
|
821 |
+
# =================== Step 3: Triangulate Keypoints ===================
|
822 |
+
reference_image_dict = {
|
823 |
+
"triangulated_points": [],
|
824 |
+
"triangulated_points_errors_proj1": [],
|
825 |
+
"triangulated_points_errors_proj2": []
|
826 |
+
}
|
827 |
+
|
828 |
+
start = time.time()
|
829 |
+
matches_NN = warp.reshape(-1, 4)[good_samples]
|
830 |
+
|
831 |
+
# Convert matches to pixel coordinates
|
832 |
+
kptsA_np, kptsB_np, kptsA_color, kptsB_color = extract_keypoints_and_colors_single(
|
833 |
+
np.array(imA).astype(np.uint8),
|
834 |
+
np.array(imB).astype(np.uint8),
|
835 |
+
matches_NN,
|
836 |
+
roma_model
|
837 |
+
)
|
838 |
+
|
839 |
+
proj_matrices_A = viewpoint_stack[source_idx].full_proj_transform
|
840 |
+
proj_matrices_B = viewpoint_stack[closest_indices_selected[source_idx, 0]].full_proj_transform
|
841 |
+
|
842 |
+
triangulated_points, triangulated_points_errors_proj1, triangulated_points_errors_proj2 = triangulate_points(
|
843 |
+
P1=torch.stack([proj_matrices_A] * M, axis=0),
|
844 |
+
P2=torch.stack([proj_matrices_B] * M, axis=0),
|
845 |
+
k1_x=kptsA_np[:M, 0], k1_y=kptsA_np[:M, 1],
|
846 |
+
k2_x=kptsB_np[:M, 0], k2_y=kptsB_np[:M, 1])
|
847 |
+
|
848 |
+
reference_image_dict["triangulated_points"].append(triangulated_points)
|
849 |
+
reference_image_dict["triangulated_points_errors_proj1"].append(triangulated_points_errors_proj1)
|
850 |
+
reference_image_dict["triangulated_points_errors_proj2"].append(triangulated_points_errors_proj2)
|
851 |
+
timings['triangulation_per_NN'].append(time.time() - start)
|
852 |
+
|
853 |
+
# =================== Step 4: Select Best Triangulated Points ===================
|
854 |
+
start = time.time()
|
855 |
+
NNs_triangulated_points_selected, NNs_triangulated_points_selected_proj_errors = select_best_keypoints(
|
856 |
+
NNs_triangulated_points=torch.stack(reference_image_dict["triangulated_points"], dim=0),
|
857 |
+
NNs_errors_proj1=np.stack(reference_image_dict["triangulated_points_errors_proj1"], axis=0),
|
858 |
+
NNs_errors_proj2=np.stack(reference_image_dict["triangulated_points_errors_proj2"], axis=0))
|
859 |
+
timings['select_best_keypoints'].append(time.time() - start)
|
860 |
+
|
861 |
+
# =================== Step 5: Create New Gaussians ===================
|
862 |
+
start = time.time()
|
863 |
+
viewpoint_cam1 = viewpoint_stack[source_idx]
|
864 |
+
N = len(NNs_triangulated_points_selected)
|
865 |
+
new_xyz = NNs_triangulated_points_selected[:, :-1]
|
866 |
+
all_new_xyz.append(new_xyz)
|
867 |
+
all_new_features_dc.append(RGB2SH(torch.tensor(kptsA_color.astype(np.float32) / 255.)).unsqueeze(1))
|
868 |
+
all_new_features_rest.append(torch.stack([gaussians._features_rest[-1].clone().detach() * 0.] * N, dim=0))
|
869 |
+
|
870 |
+
mask_bad_points = torch.tensor(
|
871 |
+
NNs_triangulated_points_selected_proj_errors > keypoint_fit_error_tolerance,
|
872 |
+
dtype=torch.float32).unsqueeze(1).to(device)
|
873 |
+
|
874 |
+
all_new_opacities.append(torch.stack([gaussians._opacity[-1].clone().detach()] * N, dim=0) * 0. - mask_bad_points * (1e1))
|
875 |
+
|
876 |
+
dist_points_to_cam1 = torch.linalg.norm(viewpoint_cam1.camera_center.clone().detach() - new_xyz, dim=1, ord=2)
|
877 |
+
all_new_scaling.append(gaussians.scaling_inverse_activation((dist_points_to_cam1 * scaling_factor).unsqueeze(1).repeat(1, 3)))
|
878 |
+
all_new_rotation.append(torch.stack([gaussians._rotation[-1].clone().detach()] * N, dim=0))
|
879 |
+
timings['save_gaussians'].append(time.time() - start)
|
880 |
+
|
881 |
+
# =================== Final Densification Postfix ===================
|
882 |
+
start = time.time()
|
883 |
+
all_new_xyz = torch.cat(all_new_xyz, dim=0)
|
884 |
+
all_new_features_dc = torch.cat(all_new_features_dc, dim=0)
|
885 |
+
new_tmp_radii = torch.zeros(all_new_xyz.shape[0])
|
886 |
+
prune_mask = torch.ones(all_new_xyz.shape[0], dtype=torch.bool)
|
887 |
+
|
888 |
+
gaussians.densification_postfix(
|
889 |
+
all_new_xyz[prune_mask].to(device),
|
890 |
+
all_new_features_dc[prune_mask].to(device),
|
891 |
+
torch.cat(all_new_features_rest, dim=0)[prune_mask].to(device),
|
892 |
+
torch.cat(all_new_opacities, dim=0)[prune_mask].to(device),
|
893 |
+
torch.cat(all_new_scaling, dim=0)[prune_mask].to(device),
|
894 |
+
torch.cat(all_new_rotation, dim=0)[prune_mask].to(device),
|
895 |
+
new_tmp_radii[prune_mask].to(device)
|
896 |
+
)
|
897 |
+
timings['final_densification_postfix'].append(time.time() - start)
|
898 |
+
|
899 |
+
# =================== Print Profiling Results ===================
|
900 |
+
print("\n=== Profiling Summary (average per frame) ===")
|
901 |
+
for key, times in timings.items():
|
902 |
+
print(f"{key:35s}: {sum(times) / len(times):.4f} sec (total {sum(times):.2f} sec)")
|
903 |
+
|
904 |
+
return viewpoint_stack, closest_indices_selected, visualizations
|
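Step 2 of the profiled initializer above clamps RoMa certainties that exceed sample_thresh to 1 and then samples matches without replacement, with probability proportional to certainty. A hedged, standalone sketch of that sampling step; the tensor sizes are illustrative only:

    import torch

    certainty = torch.rand(100_000)        # stand-in for the flattened certainty map
    upper_thresh = 0.8                     # plays the role of roma_model.sample_thresh
    M = 5_000                              # cfg.matches_per_ref

    weights = certainty.clone()
    weights[weights > upper_thresh] = 1.0  # saturate confident matches
    good_samples = torch.multinomial(weights,
                                     num_samples=min(M, len(weights)),
                                     replacement=False)
    # good_samples indexes rows of warp.reshape(-1, 4): (xA, yA, xB, yB) in normalized coords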
source/data_utils.py
ADDED
@@ -0,0 +1,28 @@
1 |
+
def scene_cameras_train_test_split(scene, verbose=False):
|
2 |
+
"""
|
3 |
+
Iterate over resolutions in the scene. For each resolution, check whether it has test_cameras;
|
4 |
+
if it does not, move every 8th camera from the train set to the test set. This follows the
|
5 |
+
evaluation protocol suggested by Kerbl et al. in the seminal work on 3DGS. All changes to the input
|
6 |
+
scene object are made in place.
|
7 |
+
:param scene: Scene Class object from the gaussian-splatting.scene module
|
8 |
+
:param verbose: Print initial and final stage of the function
|
9 |
+
:return: None
|
10 |
+
|
11 |
+
"""
|
12 |
+
if verbose: print("Preparing train and test sets split...")
|
13 |
+
for resolution in scene.train_cameras.keys():
|
14 |
+
if len(scene.test_cameras[resolution]) == 0:
|
15 |
+
if verbose:
|
16 |
+
print(f"Found no test_cameras for resolution {resolution}. Move every 8th camera out ouf total "+\
|
17 |
+
f"{len(scene.train_cameras[resolution])} train cameras to the test set now")
|
18 |
+
N = len(scene.train_cameras[resolution])
|
19 |
+
scene.test_cameras[resolution] = [scene.train_cameras[resolution][idx] for idx in range(0, N)
|
20 |
+
if idx % 8 == 0]
|
21 |
+
scene.train_cameras[resolution] = [scene.train_cameras[resolution][idx] for idx in range(0, N)
|
22 |
+
if idx % 8 != 0]
|
23 |
+
if verbose:
|
24 |
+
print(f"Done. Now train and test sets contain each {len(scene.train_cameras[resolution])} and " + \
|
25 |
+
f"{len(scene.test_cameras[resolution])} cameras respectively.")
|
26 |
+
|
27 |
+
|
28 |
+
return
|
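For reference, a small sketch of the every-8th-frame protocol this function applies, shown on a plain Python list instead of a Scene object (the list is only an illustration):

    frames = list(range(20))
    test = [f for i, f in enumerate(frames) if i % 8 == 0]   # [0, 8, 16]
    train = [f for i, f in enumerate(frames) if i % 8 != 0]  # the remaining 17 frames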
source/losses.py
ADDED
@@ -0,0 +1,100 @@
1 |
+
# Code is copied from the gaussian-splatting/utils/loss_utils.py
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn.functional as F
|
5 |
+
from torch.autograd import Variable
|
6 |
+
from math import exp
|
7 |
+
|
8 |
+
def l1_loss(network_output, gt, mean=True):
|
9 |
+
return torch.abs((network_output - gt)).mean() if mean else torch.abs((network_output - gt))
|
10 |
+
|
11 |
+
def l2_loss(network_output, gt):
|
12 |
+
return ((network_output - gt) ** 2).mean()
|
13 |
+
|
14 |
+
def gaussian(window_size, sigma):
|
15 |
+
gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])
|
16 |
+
return gauss / gauss.sum()
|
17 |
+
|
18 |
+
def create_window(window_size, channel):
|
19 |
+
_1D_window = gaussian(window_size, 1.5).unsqueeze(1)
|
20 |
+
_2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
|
21 |
+
window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
|
22 |
+
return window
|
23 |
+
|
24 |
+
def ssim(img1, img2, window_size=11, size_average=True, mask = None):
|
25 |
+
channel = img1.size(-3)
|
26 |
+
window = create_window(window_size, channel)
|
27 |
+
|
28 |
+
if img1.is_cuda:
|
29 |
+
window = window.cuda(img1.get_device())
|
30 |
+
window = window.type_as(img1)
|
31 |
+
|
32 |
+
return _ssim(img1, img2, window, window_size, channel, size_average, mask)
|
33 |
+
|
34 |
+
def _ssim(img1, img2, window, window_size, channel, size_average=True, mask = None):
|
35 |
+
mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
|
36 |
+
mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
|
37 |
+
|
38 |
+
mu1_sq = mu1.pow(2)
|
39 |
+
mu2_sq = mu2.pow(2)
|
40 |
+
mu1_mu2 = mu1 * mu2
|
41 |
+
|
42 |
+
sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
|
43 |
+
sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
|
44 |
+
sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2
|
45 |
+
|
46 |
+
C1 = 0.01 ** 2
|
47 |
+
C2 = 0.03 ** 2
|
48 |
+
|
49 |
+
ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
|
50 |
+
|
51 |
+
if mask is not None:
|
52 |
+
ssim_map = ssim_map * mask
|
53 |
+
|
54 |
+
if size_average:
|
55 |
+
return ssim_map.mean()
|
56 |
+
else:
|
57 |
+
return ssim_map.mean(1).mean(1).mean(1)
|
58 |
+
|
59 |
+
|
60 |
+
def mse(img1, img2):
|
61 |
+
return (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True)
|
62 |
+
|
63 |
+
def psnr(img1, img2):
|
64 |
+
"""
|
65 |
+
Computes the Peak Signal-to-Noise Ratio (PSNR) between two single images. NOT BATCHED!
|
66 |
+
Args:
|
67 |
+
img1 (torch.Tensor): The first image tensor, with pixel values scaled between 0 and 1.
|
68 |
+
Shape should be (channels, height, width).
|
69 |
+
img2 (torch.Tensor): The second image tensor with the same shape as img1, used for comparison.
|
70 |
+
|
71 |
+
Returns:
|
72 |
+
torch.Tensor: A scalar tensor containing the PSNR value in decibels (dB).
|
73 |
+
"""
|
74 |
+
mse = (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True)
|
75 |
+
return 20 * torch.log10(1.0 / torch.sqrt(mse))
|
76 |
+
|
77 |
+
|
78 |
+
def tv_loss(image):
|
79 |
+
"""
|
80 |
+
Computes the total variation (TV) loss for an image of shape [3, H, W].
|
81 |
+
|
82 |
+
Args:
|
83 |
+
image (torch.Tensor): Input image of shape [3, H, W]
|
84 |
+
|
85 |
+
Returns:
|
86 |
+
torch.Tensor: Scalar value representing the total variation loss.
|
87 |
+
"""
|
88 |
+
# Ensure the image has the correct dimensions
|
89 |
+
assert image.ndim == 3 and image.shape[0] == 3, "Input must be of shape [3, H, W]"
|
90 |
+
|
91 |
+
# Compute the difference between adjacent pixels in the x-direction (width)
|
92 |
+
diff_x = torch.abs(image[:, :, 1:] - image[:, :, :-1])
|
93 |
+
|
94 |
+
# Compute the difference between adjacent pixels in the y-direction (height)
|
95 |
+
diff_y = torch.abs(image[:, 1:, :] - image[:, :-1, :])
|
96 |
+
|
97 |
+
# Sum the total variation in both directions
|
98 |
+
tv_loss_value = torch.mean(diff_x) + torch.mean(diff_y)
|
99 |
+
|
100 |
+
return tv_loss_value
|
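A quick usage sketch of the losses above on random tensors, mirroring the shapes EDGSTrainer passes ([3, H, W] images in [0, 1], with psnr given a leading batch dimension); it assumes the functions are imported from source.losses:

    import torch
    from source.losses import l1_loss, ssim, psnr, tv_loss

    img = torch.rand(3, 64, 64)                              # fake render
    gt = (img + 0.05 * torch.randn_like(img)).clamp(0, 1)    # slightly perturbed target

    print("L1:  ", l1_loss(img, gt).item())
    print("SSIM:", ssim(img, gt).item())
    print("PSNR:", psnr(img.unsqueeze(0), gt.unsqueeze(0)).item())
    print("TV:  ", tv_loss(img).item())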
source/networks.py
ADDED
@@ -0,0 +1,52 @@
1 |
+
import torch
|
2 |
+
|
3 |
+
import sys
|
4 |
+
sys.path.append('./submodules/gaussian-splatting/')
|
5 |
+
|
6 |
+
from random import randint
|
7 |
+
from scene import Scene, GaussianModel
|
8 |
+
from gaussian_renderer import render
|
9 |
+
from source.data_utils import scene_cameras_train_test_split
|
10 |
+
|
11 |
+
class Warper3DGS(torch.nn.Module):
|
12 |
+
def __init__(self, sh_degree, opt, pipe, dataset, viewpoint_stack, verbose,
|
13 |
+
do_train_test_split=True):
|
14 |
+
super(Warper3DGS, self).__init__()
|
15 |
+
"""
|
16 |
+
Init Warper using all the objects necessary for rendering gaussian splats.
|
17 |
+
Here we merely link class attributes to the objects instantiated outside the class.
|
18 |
+
"""
|
19 |
+
print("ready!!!7")
|
20 |
+
self.gaussians = GaussianModel(sh_degree)
|
21 |
+
print("ready!!!8")
|
22 |
+
self.gaussians.tmp_radii = torch.zeros((self.gaussians.get_xyz.shape[0]), device="cuda")
|
23 |
+
self.render = render
|
24 |
+
self.gs_config_opt = opt
|
25 |
+
bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0]
|
26 |
+
self.bg = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
|
27 |
+
self.pipe = pipe
|
28 |
+
print("ready!!!")
|
29 |
+
self.scene = Scene(dataset, self.gaussians, shuffle=False)
|
30 |
+
print("ready2")
|
31 |
+
if do_train_test_split:
|
32 |
+
scene_cameras_train_test_split(self.scene, verbose=verbose)
|
33 |
+
|
34 |
+
self.gaussians.training_setup(opt)
|
35 |
+
self.viewpoint_stack = viewpoint_stack
|
36 |
+
if not self.viewpoint_stack:
|
37 |
+
self.viewpoint_stack = self.scene.getTrainCameras().copy()
|
38 |
+
|
39 |
+
def forward(self, viewpoint_cam=None):
|
40 |
+
"""
|
41 |
+
For a provided camera viewpoint_cam we render gaussians from this viewpoint.
|
42 |
+
If no camera provided then we use the self.viewpoint_stack (list of cameras).
|
43 |
+
If the latter is empty we reinitialize it using the self.scene object.
|
44 |
+
"""
|
45 |
+
if not viewpoint_cam:
|
46 |
+
if not self.viewpoint_stack:
|
47 |
+
self.viewpoint_stack = self.scene.getTrainCameras().copy()
|
48 |
+
viewpoint_cam = self.viewpoint_stack[randint(0, len(self.viewpoint_stack) - 1)]
|
49 |
+
|
50 |
+
render_pkg = self.render(viewpoint_cam, self.gaussians, self.pipe, self.bg)
|
51 |
+
return render_pkg
|
52 |
+
|
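Two small conventions from the wrapper above, shown in isolation: the background color follows dataset.white_background, and the forward pass refills and samples the viewpoint stack when no camera is passed. The string cameras below are stand-ins for real Camera objects:

    import torch
    from random import randint

    white_background = False
    bg = torch.tensor([1, 1, 1] if white_background else [0, 0, 0], dtype=torch.float32)

    viewpoint_stack = []                                   # imagine this came from scene.getTrainCameras().copy()
    if not viewpoint_stack:
        viewpoint_stack = [f"cam_{i}" for i in range(5)]   # refill from the scene
    viewpoint_cam = viewpoint_stack[randint(0, len(viewpoint_stack) - 1)]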
source/timer.py
ADDED
@@ -0,0 +1,24 @@
1 |
+
import time
|
2 |
+
class Timer:
|
3 |
+
def __init__(self):
|
4 |
+
self.start_time = None
|
5 |
+
self.elapsed = 0
|
6 |
+
self.paused = False
|
7 |
+
|
8 |
+
def start(self):
|
9 |
+
if self.start_time is None:
|
10 |
+
self.start_time = time.time()
|
11 |
+
elif self.paused:
|
12 |
+
self.start_time = time.time() - self.elapsed
|
13 |
+
self.paused = False
|
14 |
+
|
15 |
+
def pause(self):
|
16 |
+
if not self.paused:
|
17 |
+
self.elapsed = time.time() - self.start_time
|
18 |
+
self.paused = True
|
19 |
+
|
20 |
+
def get_elapsed_time(self):
|
21 |
+
if self.paused:
|
22 |
+
return self.elapsed
|
23 |
+
else:
|
24 |
+
return time.time() - self.start_time
|
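A short usage example for the Timer above; pause() excludes a block from the measured time, which is exactly how the trainer brackets its logging and evaluation:

    import time
    from source.timer import Timer

    t = Timer()
    t.start()
    time.sleep(0.2)     # measured work
    t.pause()
    time.sleep(0.5)     # logging / evaluation, not counted
    t.start()           # resume
    time.sleep(0.1)     # more measured work
    print(f"elapsed: {t.get_elapsed_time():.2f}s")   # roughly 0.3s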
source/trainer.py
ADDED
@@ -0,0 +1,262 @@
1 |
+
import torch
|
2 |
+
from random import randint
|
3 |
+
from tqdm.rich import trange
|
4 |
+
from tqdm import tqdm as tqdm
|
5 |
+
from source.networks import Warper3DGS
|
6 |
+
import wandb
|
7 |
+
import sys
|
8 |
+
|
9 |
+
sys.path.append('./submodules/gaussian-splatting/')
|
10 |
+
import lpips
|
11 |
+
from source.losses import ssim, l1_loss, psnr
|
12 |
+
from rich.console import Console
|
13 |
+
from rich.theme import Theme
|
14 |
+
|
15 |
+
custom_theme = Theme({
|
16 |
+
"info": "dim cyan",
|
17 |
+
"warning": "magenta",
|
18 |
+
"danger": "bold red"
|
19 |
+
})
|
20 |
+
|
21 |
+
#from source.corr_init import init_gaussians_with_corr
|
22 |
+
from source.corr_init_new import init_gaussians_with_corr_profiled as init_gaussians_with_corr
|
23 |
+
from source.utils_aux import log_samples
|
24 |
+
|
25 |
+
from source.timer import Timer
|
26 |
+
|
27 |
+
class EDGSTrainer:
|
28 |
+
def __init__(self,
|
29 |
+
GS: Warper3DGS,
|
30 |
+
training_config,
|
31 |
+
dataset_white_background=False,
|
32 |
+
device=torch.device('cuda'),
|
33 |
+
log_wandb=True,
|
34 |
+
):
|
35 |
+
self.GS = GS
|
36 |
+
self.scene = GS.scene
|
37 |
+
self.viewpoint_stack = GS.viewpoint_stack
|
38 |
+
self.gaussians = GS.gaussians
|
39 |
+
|
40 |
+
self.training_config = training_config
|
41 |
+
self.GS_optimizer = GS.gaussians.optimizer
|
42 |
+
self.dataset_white_background = dataset_white_background
|
43 |
+
|
44 |
+
self.training_step = 1
|
45 |
+
self.gs_step = 0
|
46 |
+
self.CONSOLE = Console(width=120, theme=custom_theme)
|
47 |
+
self.saving_iterations = training_config.save_iterations
|
48 |
+
self.evaluate_iterations = None
|
49 |
+
self.batch_size = training_config.batch_size
|
50 |
+
self.ema_loss_for_log = 0.0
|
51 |
+
|
52 |
+
# Logs in the format {step:{"loss1":loss1_value, "loss2":loss2_value}}
|
53 |
+
self.logs_losses = {}
|
54 |
+
self.lpips = lpips.LPIPS(net='vgg').to(device)
|
55 |
+
self.device = device
|
56 |
+
self.timer = Timer()
|
57 |
+
self.log_wandb = log_wandb
|
58 |
+
|
59 |
+
def load_checkpoints(self, load_cfg):
|
60 |
+
# Load 3DGS checkpoint
|
61 |
+
if load_cfg.gs:
|
62 |
+
self.GS.gaussians.restore(
|
63 |
+
torch.load(f"{load_cfg.gs}/chkpnt{load_cfg.gs_step}.pth")[0],
|
64 |
+
self.training_config)
|
65 |
+
self.GS_optimizer = self.GS.gaussians.optimizer
|
66 |
+
self.CONSOLE.print(f"3DGS loaded from checkpoint for iteration {load_cfg.gs_step}",
|
67 |
+
style="info")
|
68 |
+
self.training_step += load_cfg.gs_step
|
69 |
+
self.gs_step += load_cfg.gs_step
|
70 |
+
|
71 |
+
def train(self, train_cfg):
|
72 |
+
# 3DGS training
|
73 |
+
self.CONSOLE.print("Train 3DGS for {} iterations".format(train_cfg.gs_epochs), style="info")
|
74 |
+
with trange(self.training_step, self.training_step + train_cfg.gs_epochs, desc="[green]Train gaussians") as progress_bar:
|
75 |
+
for self.training_step in progress_bar:
|
76 |
+
radii = self.train_step_gs(max_lr=train_cfg.max_lr, no_densify=train_cfg.no_densify)
|
77 |
+
with torch.no_grad():
|
78 |
+
if train_cfg.no_densify:
|
79 |
+
self.prune(radii)
|
80 |
+
else:
|
81 |
+
self.densify_and_prune(radii)
|
82 |
+
if train_cfg.reduce_opacity:
|
83 |
+
# Slightly reduce opacity every few steps:
|
84 |
+
if self.gs_step < self.training_config.densify_until_iter and self.gs_step % 10 == 0:
|
85 |
+
opacities_new = torch.log(torch.exp(self.GS.gaussians._opacity.data) * 0.99)
|
86 |
+
self.GS.gaussians._opacity.data = opacities_new
|
87 |
+
self.timer.pause()
|
88 |
+
# Progress bar
|
89 |
+
if self.training_step % 10 == 0:
|
90 |
+
progress_bar.set_postfix({"[red]Loss": f"{self.ema_loss_for_log:.{7}f}"}, refresh=True)
|
91 |
+
# Log and save
|
92 |
+
if self.training_step in self.saving_iterations:
|
93 |
+
self.save_model()
|
94 |
+
if self.evaluate_iterations is not None:
|
95 |
+
if self.training_step in self.evaluate_iterations:
|
96 |
+
self.evaluate()
|
97 |
+
else:
|
98 |
+
if (self.training_step <= 3000 and self.training_step % 500 == 0) or \
|
99 |
+
(self.training_step > 3000 and self.training_step % 1000 == 228) :
|
100 |
+
self.evaluate()
|
101 |
+
|
102 |
+
self.timer.start()
|
103 |
+
|
104 |
+
|
105 |
+
def evaluate(self):
|
106 |
+
torch.cuda.empty_cache()
|
107 |
+
log_gen_images, log_real_images = [], []
|
108 |
+
validation_configs = ({'name': 'test', 'cameras': self.scene.getTestCameras(), 'cam_idx': self.training_config.TEST_CAM_IDX_TO_LOG},
|
109 |
+
{'name': 'train',
|
110 |
+
'cameras': [self.scene.getTrainCameras()[idx % len(self.scene.getTrainCameras())] for idx in
|
111 |
+
range(0, 150, 5)], 'cam_idx': 10})
|
112 |
+
if self.log_wandb:
|
113 |
+
wandb.log({f"Number of Gaussians": len(self.GS.gaussians._xyz)}, step=self.training_step)
|
114 |
+
for config in validation_configs:
|
115 |
+
if config['cameras'] and len(config['cameras']) > 0:
|
116 |
+
l1_test = 0.0
|
117 |
+
psnr_test = 0.0
|
118 |
+
ssim_test = 0.0
|
119 |
+
lpips_splat_test = 0.0
|
120 |
+
for idx, viewpoint in enumerate(config['cameras']):
|
121 |
+
image = torch.clamp(self.GS(viewpoint)["render"], 0.0, 1.0)
|
122 |
+
gt_image = torch.clamp(viewpoint.original_image.to(self.device), 0.0, 1.0)
|
123 |
+
l1_test += l1_loss(image, gt_image).double()
|
124 |
+
psnr_test += psnr(image.unsqueeze(0), gt_image.unsqueeze(0)).double()
|
125 |
+
ssim_test += ssim(image, gt_image).double()
|
126 |
+
lpips_splat_test += self.lpips(image, gt_image).detach().double()
|
127 |
+
if idx in [config['cam_idx']]:
|
128 |
+
log_gen_images.append(image)
|
129 |
+
log_real_images.append(gt_image)
|
130 |
+
psnr_test /= len(config['cameras'])
|
131 |
+
l1_test /= len(config['cameras'])
|
132 |
+
ssim_test /= len(config['cameras'])
|
133 |
+
lpips_splat_test /= len(config['cameras'])
|
134 |
+
if self.log_wandb:
|
135 |
+
wandb.log({f"{config['name']}/L1": l1_test.item(), f"{config['name']}/PSNR": psnr_test.item(), \
|
136 |
+
f"{config['name']}/SSIM": ssim_test.item(), f"{config['name']}/LPIPS_splat": lpips_splat_test.item()}, step = self.training_step)
|
137 |
+
self.CONSOLE.print("\n[ITER {}], #{} gaussians, Evaluating {}: L1={:.6f}, PSNR={:.6f}, SSIM={:.6f}, LPIPS_splat={:.6f} ".format(
|
138 |
+
self.training_step, len(self.GS.gaussians._xyz), config['name'], l1_test.item(), psnr_test.item(), ssim_test.item(), lpips_splat_test.item()), style="info")
|
139 |
+
if self.log_wandb:
|
140 |
+
with torch.no_grad():
|
141 |
+
log_samples(torch.stack((log_real_images[0],log_gen_images[0])) , [], self.training_step, caption="Real and Generated Samples")
|
142 |
+
wandb.log({"time": self.timer.get_elapsed_time()}, step=self.training_step)
|
143 |
+
torch.cuda.empty_cache()
|
144 |
+
|
145 |
+
def train_step_gs(self, max_lr = False, no_densify = False):
|
146 |
+
self.gs_step += 1
|
147 |
+
if max_lr:
|
148 |
+
self.GS.gaussians.update_learning_rate(max(self.gs_step, 8_000))
|
149 |
+
else:
|
150 |
+
self.GS.gaussians.update_learning_rate(self.gs_step)
|
151 |
+
# Every 1000 its we increase the levels of SH up to a maximum degree
|
152 |
+
if self.gs_step % 1000 == 0:
|
153 |
+
self.GS.gaussians.oneupSHdegree()
|
154 |
+
|
155 |
+
# Pick a random Camera
|
156 |
+
if not self.viewpoint_stack:
|
157 |
+
self.viewpoint_stack = self.scene.getTrainCameras().copy()
|
158 |
+
viewpoint_cam = self.viewpoint_stack.pop(randint(0, len(self.viewpoint_stack) - 1))
|
159 |
+
|
160 |
+
render_pkg = self.GS(viewpoint_cam=viewpoint_cam)
|
161 |
+
image = render_pkg["render"]
|
162 |
+
# Loss
|
163 |
+
gt_image = viewpoint_cam.original_image.to(self.device)
|
164 |
+
L1_loss = l1_loss(image, gt_image)
|
165 |
+
|
166 |
+
ssim_loss = (1.0 - ssim(image, gt_image))
|
167 |
+
loss = (1.0 - self.training_config.lambda_dssim) * L1_loss + \
|
168 |
+
self.training_config.lambda_dssim * ssim_loss
|
169 |
+
self.timer.pause()
|
170 |
+
self.logs_losses[self.training_step] = {"loss": loss.item(),
|
171 |
+
"L1_loss": L1_loss.item(),
|
172 |
+
"ssim_loss": ssim_loss.item()}
|
173 |
+
|
174 |
+
if self.log_wandb:
|
175 |
+
for k, v in self.logs_losses[self.training_step].items():
|
176 |
+
wandb.log({f"train/{k}": v}, step=self.training_step)
|
177 |
+
self.ema_loss_for_log = 0.4 * self.logs_losses[self.training_step]["loss"] + 0.6 * self.ema_loss_for_log
|
178 |
+
self.timer.start()
|
179 |
+
self.GS_optimizer.zero_grad(set_to_none=True)
|
180 |
+
loss.backward()
|
181 |
+
with torch.no_grad():
|
182 |
+
if self.gs_step < self.training_config.densify_until_iter and not no_densify:
|
183 |
+
self.GS.gaussians.max_radii2D[render_pkg["visibility_filter"]] = torch.max(
|
184 |
+
self.GS.gaussians.max_radii2D[render_pkg["visibility_filter"]],
|
185 |
+
render_pkg["radii"][render_pkg["visibility_filter"]])
|
186 |
+
self.GS.gaussians.add_densification_stats(render_pkg["viewspace_points"],
|
187 |
+
render_pkg["visibility_filter"])
|
188 |
+
|
189 |
+
# Optimizer step
|
190 |
+
self.GS_optimizer.step()
|
191 |
+
self.GS_optimizer.zero_grad(set_to_none=True)
|
192 |
+
return render_pkg["radii"]
|
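The progress-bar loss shown during training is an exponential moving average with weight 0.4 on the newest value, as computed at the end of train_step_gs above. A tiny numeric sketch of that update rule:

    ema = 0.0
    for loss in [1.0, 0.9, 0.8, 0.7]:
        ema = 0.4 * loss + 0.6 * ema
        print(f"loss={loss:.2f}  ema={ema:.4f}")
    # the EMA lags the raw loss and smooths out single-step spikes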
193 |
+
|
194 |
+
def densify_and_prune(self, radii = None):
|
195 |
+
# Densification or pruning
|
196 |
+
if self.gs_step < self.training_config.densify_until_iter:
|
197 |
+
if (self.gs_step > self.training_config.densify_from_iter) and \
|
198 |
+
(self.gs_step % self.training_config.densification_interval == 0):
|
199 |
+
size_threshold = 20 if self.gs_step > self.training_config.opacity_reset_interval else None
|
200 |
+
self.GS.gaussians.densify_and_prune(self.training_config.densify_grad_threshold,
|
201 |
+
0.005,
|
202 |
+
self.GS.scene.cameras_extent,
|
203 |
+
size_threshold, radii)
|
204 |
+
if self.gs_step % self.training_config.opacity_reset_interval == 0 or (
|
205 |
+
self.dataset_white_background and self.gs_step == self.training_config.densify_from_iter):
|
206 |
+
self.GS.gaussians.reset_opacity()
|
207 |
+
|
208 |
+
|
209 |
+
|
210 |
+
def save_model(self):
|
211 |
+
print("\n[ITER {}] Saving Gaussians".format(self.gs_step))
|
212 |
+
self.scene.save(self.gs_step)
|
213 |
+
print("\n[ITER {}] Saving Checkpoint".format(self.gs_step))
|
214 |
+
torch.save((self.GS.gaussians.capture(), self.gs_step),
|
215 |
+
self.scene.model_path + "/chkpnt" + str(self.gs_step) + ".pth")
|
216 |
+
|
217 |
+
|
218 |
+
def init_with_corr(self, cfg, verbose=False, roma_model=None):
|
219 |
+
"""
|
220 |
+
Initializes Gaussians with correspondence matches. Also removes the SfM init points.
|
221 |
+
Args:
|
222 |
+
cfg: the configuration section named init_wC; see train.yaml.
|
223 |
+
verbose: whether to print intermediate results. Useful for debugging.
|
224 |
+
roma_model: optionally, pass a pre-initialized RoMa model here to avoid re-initializing
|
225 |
+
it every time.
|
226 |
+
"""
|
227 |
+
if not cfg.use:
|
228 |
+
return None
|
229 |
+
N_splats_at_init = len(self.GS.gaussians._xyz)
|
230 |
+
print("N_splats_at_init:", N_splats_at_init)
|
231 |
+
camera_set, selected_indices, visualization_dict = init_gaussians_with_corr(
|
232 |
+
self.GS.gaussians,
|
233 |
+
self.scene,
|
234 |
+
cfg,
|
235 |
+
self.device,
|
236 |
+
verbose=verbose,
|
237 |
+
roma_model=roma_model)
|
238 |
+
|
239 |
+
# Remove SfM points and leave only matchings inits
|
240 |
+
if not cfg.add_SfM_init:
|
241 |
+
with torch.no_grad():
|
242 |
+
N_splats_after_init = len(self.GS.gaussians._xyz)
|
243 |
+
print("N_splats_after_init:", N_splats_after_init)
|
244 |
+
self.gaussians.tmp_radii = torch.zeros(self.gaussians._xyz.shape[0]).to(self.device)
|
245 |
+
mask = torch.concat([torch.ones(N_splats_at_init, dtype=torch.bool),
|
246 |
+
torch.zeros(N_splats_after_init-N_splats_at_init, dtype=torch.bool)],
|
247 |
+
axis=0)
|
248 |
+
self.GS.gaussians.prune_points(mask)
|
249 |
+
with torch.no_grad():
|
250 |
+
gaussians = self.gaussians
|
251 |
+
gaussians._scaling = gaussians.scaling_inverse_activation(gaussians.scaling_activation(gaussians._scaling)*0.5)
|
252 |
+
return visualization_dict
|
253 |
+
|
254 |
+
|
255 |
+
def prune(self, radii, min_opacity=0.005):
|
256 |
+
self.GS.gaussians.tmp_radii = radii
|
257 |
+
if self.gs_step < self.training_config.densify_until_iter:
|
258 |
+
prune_mask = (self.GS.gaussians.get_opacity < min_opacity).squeeze()
|
259 |
+
self.GS.gaussians.prune_points(prune_mask)
|
260 |
+
torch.cuda.empty_cache()
|
261 |
+
self.GS.gaussians.tmp_radii = None
|
262 |
+
|
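When add_SfM_init is disabled, init_with_corr above drops every splat that existed before the correspondence-based insertion by building a boolean mask over the concatenated point list. The masking logic in isolation, assuming prune_points removes the entries where the mask is True:

    import torch

    N_splats_at_init = 4       # SfM points present before init
    N_splats_after_init = 10   # SfM points plus newly inserted matches

    mask = torch.cat([torch.ones(N_splats_at_init, dtype=torch.bool),
                      torch.zeros(N_splats_after_init - N_splats_at_init, dtype=torch.bool)])
    points = torch.arange(N_splats_after_init)
    kept = points[~mask]       # only the newly added correspondence points remain
    print(kept)                # tensor([4, 5, 6, 7, 8, 9])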
source/utils_aux.py
ADDED
@@ -0,0 +1,92 @@
1 |
+
# Perlin noise code taken from https://gist.github.com/adefossez/0646dbe9ed4005480a2407c62aac8869
|
2 |
+
from types import SimpleNamespace
|
3 |
+
import random
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
import torchvision
|
7 |
+
import wandb
|
8 |
+
import random
|
9 |
+
import torchvision.transforms as T
|
10 |
+
import torchvision.transforms.functional as F
|
11 |
+
import torch
|
12 |
+
from PIL import Image
|
13 |
+
|
14 |
+
def parse_dict_to_namespace(dict_nested):
|
15 |
+
"""Turns nested dictionary into nested namespaces"""
|
16 |
+
if type(dict_nested) != dict and type(dict_nested) != list: return dict_nested
|
17 |
+
x = SimpleNamespace()
|
18 |
+
for key, val in dict_nested.items():
|
19 |
+
if type(val) == dict:
|
20 |
+
setattr(x, key, parse_dict_to_namespace(val))
|
21 |
+
elif type(val) == list:
|
22 |
+
setattr(x, key, [parse_dict_to_namespace(v) for v in val])
|
23 |
+
else:
|
24 |
+
setattr(x, key, val)
|
25 |
+
return x
|
26 |
+
|
27 |
+
def set_seed(seed=42, cuda=True):
|
28 |
+
random.seed(seed)
|
29 |
+
np.random.seed(seed)
|
30 |
+
torch.manual_seed(seed)
|
31 |
+
if cuda:
|
32 |
+
torch.cuda.manual_seed_all(seed)
|
33 |
+
|
34 |
+
|
35 |
+
|
36 |
+
def log_samples(samples, scores, iteration, caption="Real Samples"):
|
37 |
+
# Create a grid of images
|
38 |
+
grid = torchvision.utils.make_grid(samples)
|
39 |
+
|
40 |
+
# Log the images and scores to wandb
|
41 |
+
wandb.log({
|
42 |
+
f"{caption}_images": [wandb.Image(grid, caption=f"{caption}: {scores}")],
|
43 |
+
}, step = iteration)
|
44 |
+
|
45 |
+
|
46 |
+
|
47 |
+
def pairwise_distances(matrix):
|
48 |
+
"""
|
49 |
+
Computes the pairwise Euclidean distances between all vectors in the input matrix.
|
50 |
+
|
51 |
+
Args:
|
52 |
+
matrix (torch.Tensor): Input matrix of shape [N, D], where N is the number of vectors and D is the dimensionality.
|
53 |
+
|
54 |
+
Returns:
|
55 |
+
torch.Tensor: Pairwise distance matrix of shape [N, N].
|
56 |
+
"""
|
57 |
+
# Compute squared pairwise distances
|
58 |
+
squared_diff = torch.cdist(matrix, matrix, p=2)
|
59 |
+
return squared_diff
|
60 |
+
|
61 |
+
def k_closest_vectors(matrix, k):
|
62 |
+
"""
|
63 |
+
Finds the k-closest vectors for each vector in the input matrix based on Euclidean distance.
|
64 |
+
|
65 |
+
Args:
|
66 |
+
matrix (torch.Tensor): Input matrix of shape [N, D], where N is the number of vectors and D is the dimensionality.
|
67 |
+
k (int): Number of closest vectors to return for each vector.
|
68 |
+
|
69 |
+
Returns:
|
70 |
+
torch.Tensor: Indices of the k-closest vectors for each vector, excluding the vector itself.
|
71 |
+
"""
|
72 |
+
# Compute pairwise distances
|
73 |
+
distances = pairwise_distances(matrix)
|
74 |
+
|
75 |
+
# For each vector, sort distances and get the indices of the k-closest vectors (excluding itself)
|
76 |
+
# Set diagonal distances to infinity to exclude the vector itself from the nearest neighbors
|
77 |
+
distances.fill_diagonal_(float('inf'))
|
78 |
+
|
79 |
+
# Get the indices of the k smallest distances (k-closest vectors)
|
80 |
+
_, indices = torch.topk(distances, k, largest=False, dim=1)
|
81 |
+
|
82 |
+
return indices
|
83 |
+
|
84 |
+
def process_image(image_tensor):
|
85 |
+
image_np = image_tensor.detach().cpu().numpy().transpose(1, 2, 0)
|
86 |
+
return Image.fromarray(np.clip(image_np * 255, 0, 255).astype(np.uint8))
|
87 |
+
|
88 |
+
|
89 |
+
def normalize_keypoints(kpts_np, width, height):
|
90 |
+
kpts_np[:, 0] = kpts_np[:, 0] / width * 2. - 1.
|
91 |
+
kpts_np[:, 1] = kpts_np[:, 1] / height * 2. - 1.
|
92 |
+
return kpts_np
|
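A quick usage example for pairwise_distances and k_closest_vectors above; the correspondence initialization applies the same pattern to flattened world_view_transform matrices to find each reference camera's nearest neighbors:

    import torch
    from source.utils_aux import k_closest_vectors

    cams = torch.tensor([[0.0, 0.0],
                         [1.0, 0.0],
                         [0.0, 1.0],
                         [5.0, 5.0]])
    nn_idx = k_closest_vectors(cams, k=2)
    print(nn_idx)   # row i lists the two closest other rows, e.g. camera 0 -> cameras 1 and 2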
source/utils_preprocess.py
ADDED
@@ -0,0 +1,334 @@
1 |
+
# This file contains functions for video or image collection preprocessing.
|
2 |
+
# For a video, we preprocess it and select the k sharpest frames.
|
3 |
+
# Afterwards, the scene is constructed from the selected frames.
|
4 |
+
import cv2
|
5 |
+
import numpy as np
|
6 |
+
from tqdm import tqdm
|
7 |
+
import pycolmap
|
8 |
+
import os
|
9 |
+
import time
|
10 |
+
import tempfile
|
11 |
+
from moviepy import VideoFileClip
|
12 |
+
from matplotlib import pyplot as plt
|
13 |
+
from PIL import Image
|
14 |
+
import cv2
|
15 |
+
from tqdm import tqdm
|
16 |
+
|
17 |
+
WORKDIR = "../outputs/"
|
18 |
+
|
19 |
+
|
20 |
+
def get_rotation_moviepy(video_path):
|
21 |
+
clip = VideoFileClip(video_path)
|
22 |
+
rotation = 0
|
23 |
+
|
24 |
+
try:
|
25 |
+
displaymatrix = clip.reader.infos['inputs'][0]['streams'][2]['metadata'].get('displaymatrix', '')
|
26 |
+
if 'rotation of' in displaymatrix:
|
27 |
+
angle = float(displaymatrix.strip().split('rotation of')[-1].split('degrees')[0])
|
28 |
+
rotation = int(angle) % 360
|
29 |
+
|
30 |
+
except Exception as e:
|
31 |
+
print(f"No displaymatrix rotation found: {e}")
|
32 |
+
|
33 |
+
clip.reader.close()
|
34 |
+
#if clip.audio:
|
35 |
+
# clip.audio.reader.close_proc()
|
36 |
+
|
37 |
+
return rotation
|
38 |
+
|
39 |
+
def resize_max_side(frame, max_size):
|
40 |
+
h, w = frame.shape[:2]
|
41 |
+
scale = max_size / max(h, w)
|
42 |
+
if scale < 1:
|
43 |
+
frame = cv2.resize(frame, (int(w * scale), int(h * scale)))
|
44 |
+
return frame
|
45 |
+
|
46 |
+
def read_video_frames(video_input, k=1, max_size=1024):
|
47 |
+
"""
|
48 |
+
Extracts every k-th frame from a video or list of images, resizes to max size, and returns frames as list.
|
49 |
+
|
50 |
+
Parameters:
|
51 |
+
video_input (str, file-like, or list): Path to video file, file-like object, or list of image files.
|
52 |
+
k (int): Interval for frame extraction (every k-th frame).
|
53 |
+
max_size (int): Maximum size for width or height after resizing.
|
54 |
+
|
55 |
+
Returns:
|
56 |
+
frames (list): List of resized frames (numpy arrays).
|
57 |
+
"""
|
58 |
+
# Handle list of image files (not single video in a list)
|
59 |
+
if isinstance(video_input, list):
|
60 |
+
# If it's a single video in a list, treat it as video
|
61 |
+
if len(video_input) == 1 and video_input[0].name.endswith(('.mp4', '.avi', '.mov')):
|
62 |
+
video_input = video_input[0] # unwrap single video file
|
63 |
+
else:
|
64 |
+
# Treat as list of images
|
65 |
+
frames = []
|
66 |
+
for img_file in video_input:
|
67 |
+
img = Image.open(img_file.name).convert("RGB")
|
68 |
+
img.thumbnail((max_size, max_size))
|
69 |
+
frames.append(np.array(img)[...,::-1])
|
70 |
+
return frames
|
71 |
+
|
72 |
+
# Handle file-like or path
|
73 |
+
if hasattr(video_input, 'name'):
|
74 |
+
video_path = video_input.name
|
75 |
+
elif isinstance(video_input, (str, os.PathLike)):
|
76 |
+
video_path = str(video_input)
|
77 |
+
else:
|
78 |
+
raise ValueError("Unsupported video input type. Must be a filepath, file-like object, or list of images.")
|
79 |
+
|
80 |
+
|
81 |
+
cap = cv2.VideoCapture(video_path)
|
82 |
+
if not cap.isOpened():
|
83 |
+
raise ValueError(f"Error: Could not open video {video_path}.")
|
84 |
+
|
85 |
+
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
86 |
+
frame_count = 0
|
87 |
+
frames = []
|
88 |
+
|
89 |
+
with tqdm(total=total_frames // k, desc="Processing Video", unit="frame") as pbar:
|
90 |
+
while True:
|
91 |
+
ret, frame = cap.read()
|
92 |
+
if not ret:
|
93 |
+
break
|
94 |
+
if frame_count % k == 0:
|
95 |
+
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
96 |
+
h, w = frame.shape[:2]
|
97 |
+
scale = max(h, w) / max_size
|
98 |
+
if scale > 1:
|
99 |
+
frame = cv2.resize(frame, (int(w / scale), int(h / scale)))
|
100 |
+
frames.append(frame[...,[2,1,0]])
|
101 |
+
pbar.update(1)
|
102 |
+
frame_count += 1
|
103 |
+
|
104 |
+
cap.release()
|
105 |
+
return frames
|
106 |
+
|
107 |
+
def resize_max_side(frame, max_size):
|
108 |
+
"""
|
109 |
+
Resizes the frame so that its largest side equals max_size, maintaining aspect ratio.
|
110 |
+
"""
|
111 |
+
height, width = frame.shape[:2]
|
112 |
+
max_dim = max(height, width)
|
113 |
+
|
114 |
+
if max_dim <= max_size:
|
115 |
+
return frame # No need to resize
|
116 |
+
|
117 |
+
scale = max_size / max_dim
|
118 |
+
new_width = int(width * scale)
|
119 |
+
new_height = int(height * scale)
|
120 |
+
|
121 |
+
resized_frame = cv2.resize(frame, (new_width, new_height), interpolation=cv2.INTER_AREA)
|
122 |
+
return resized_frame
|
123 |
+
|
124 |
+
|
125 |
+
|
126 |
+
def variance_of_laplacian(image):
|
127 |
+
# compute the Laplacian of the image and then return the focus
|
128 |
+
# measure, which is simply the variance of the Laplacian
|
129 |
+
return cv2.Laplacian(image, cv2.CV_64F).var()
|
130 |
+
|
131 |
+
def process_all_frames(IMG_FOLDER = '/scratch/datasets/hq_data/night2_all_frames',
|
132 |
+
to_visualize=False,
|
133 |
+
save_images=True):
|
134 |
+
dict_scores = {}
|
135 |
+
for idx, img_name in tqdm(enumerate(sorted([x for x in os.listdir(IMG_FOLDER) if '.png' in x]))):
|
136 |
+
|
137 |
+
img = cv2.imread(os.path.join(IMG_FOLDER, img_name))#[250:, 100:]
|
138 |
+
gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
|
139 |
+
fm = variance_of_laplacian(gray) + \
|
140 |
+
variance_of_laplacian(cv2.resize(gray, (0,0), fx=0.75, fy=0.75)) + \
|
141 |
+
variance_of_laplacian(cv2.resize(gray, (0,0), fx=0.5, fy=0.5)) + \
|
142 |
+
variance_of_laplacian(cv2.resize(gray, (0,0), fx=0.25, fy=0.25))
|
143 |
+
if to_visualize:
|
144 |
+
plt.figure()
|
145 |
+
plt.title(f"Laplacian score: {fm:.2f}")
|
146 |
+
plt.imshow(img[..., [2,1,0]])
|
147 |
+
plt.show()
|
148 |
+
dict_scores[idx] = {"idx" : idx,
|
149 |
+
"img_name" : img_name,
|
150 |
+
"score" : fm}
|
151 |
+
if save_images:
|
152 |
+
dict_scores[idx]["img"] = img
|
153 |
+
|
154 |
+
return dict_scores
|
155 |
+
|
156 |
+
def select_optimal_frames(scores, k):
|
157 |
+
"""
|
158 |
+
Selects a minimal subset of frames while ensuring no gaps exceed k.
|
159 |
+
|
160 |
+
Args:
|
161 |
+
scores (list of float): List of scores where index represents frame number.
|
162 |
+
k (int): Maximum allowed gap between selected frames.
|
163 |
+
|
164 |
+
Returns:
|
165 |
+
list of int: Indices of selected frames.
|
166 |
+
"""
|
167 |
+
n = len(scores)
|
168 |
+
selected = [0, n-1]
|
169 |
+
i = 0 # Start at the first frame
|
170 |
+
|
171 |
+
while i < n:
|
172 |
+
# Find the best frame to select within the next k frames
|
173 |
+
best_idx = max(range(i, min(i + k + 1, n)), key=lambda x: scores[x], default=None)
|
174 |
+
|
175 |
+
if best_idx is None:
|
176 |
+
break # No more frames left
|
177 |
+
|
178 |
+
selected.append(best_idx)
|
179 |
+
i = best_idx + k + 1 # Move forward, ensuring gaps stay within k
|
180 |
+
|
181 |
+
return sorted(selected)
|
182 |
+
|
183 |
+
|
184 |
+
def variance_of_laplacian(image):
|
185 |
+
"""
|
186 |
+
Compute the variance of Laplacian as a focus measure.
|
187 |
+
"""
|
188 |
+
return cv2.Laplacian(image, cv2.CV_64F).var()
|
189 |
+
|
190 |
+
def preprocess_frames(frames, verbose=False):
|
191 |
+
"""
|
192 |
+
Compute sharpness scores for a list of frames using multi-scale Laplacian variance.
|
193 |
+
|
194 |
+
Args:
|
195 |
+
frames (list of np.ndarray): List of frames (BGR images).
|
196 |
+
verbose (bool): If True, print scores.
|
197 |
+
|
198 |
+
Returns:
|
199 |
+
list of float: Sharpness scores for each frame.
|
200 |
+
"""
|
201 |
+
scores = []
|
202 |
+
|
203 |
+
for idx, frame in enumerate(tqdm(frames, desc="Scoring frames")):
|
204 |
+
gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
|
205 |
+
|
206 |
+
fm = (
|
207 |
+
variance_of_laplacian(gray) +
|
208 |
+
variance_of_laplacian(cv2.resize(gray, (0, 0), fx=0.75, fy=0.75)) +
|
209 |
+
variance_of_laplacian(cv2.resize(gray, (0, 0), fx=0.5, fy=0.5)) +
|
210 |
+
variance_of_laplacian(cv2.resize(gray, (0, 0), fx=0.25, fy=0.25))
|
211 |
+
)
|
212 |
+
|
213 |
+
if verbose:
|
214 |
+
print(f"Frame {idx}: Sharpness Score = {fm:.2f}")
|
215 |
+
|
216 |
+
scores.append(fm)
|
217 |
+
|
218 |
+
return scores
|
219 |
+
|
220 |
+
def select_optimal_frames(scores, k):
|
221 |
+
"""
|
222 |
+
Selects k frames by splitting into k segments and picking the sharpest frame from each.
|
223 |
+
|
224 |
+
Args:
|
225 |
+
scores (list of float): List of sharpness scores.
|
226 |
+
k (int): Number of frames to select.
|
227 |
+
|
228 |
+
Returns:
|
229 |
+
list of int: Indices of selected frames.
|
230 |
+
"""
|
231 |
+
n = len(scores)
|
232 |
+
selected_indices = []
|
233 |
+
segment_size = n // k
|
234 |
+
|
235 |
+
for i in range(k):
|
236 |
+
start = i * segment_size
|
237 |
+
end = (i + 1) * segment_size if i < k - 1 else n # Last chunk may be larger
|
238 |
+
segment_scores = scores[start:end]
|
239 |
+
|
240 |
+
if len(segment_scores) == 0:
|
241 |
+
continue # Safety check if some segment is empty
|
242 |
+
|
243 |
+
best_in_segment = start + np.argmax(segment_scores)
|
244 |
+
selected_indices.append(best_in_segment)
|
245 |
+
|
246 |
+
return sorted(selected_indices)
|
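A compact sketch of how the sharpness scoring and segment-based selection above fit together, run on synthetic frames that stand in for real video frames:

    import numpy as np

    frames = [np.random.randint(0, 255, (120, 160, 3), dtype=np.uint8) for _ in range(30)]
    scores = preprocess_frames(frames)           # multi-scale variance of Laplacian per frame
    keep = select_optimal_frames(scores, k=6)    # sharpest frame from each of 6 segments
    selected_frames = [frames[i] for i in keep]
    print(keep)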
247 |
+
|
248 |
+
def save_frames_to_scene_dir(frames, scene_dir):
|
249 |
+
"""
|
250 |
+
Saves a list of frames into the target scene directory under 'images/' subfolder.
|
251 |
+
|
252 |
+
Args:
|
253 |
+
frames (list of np.ndarray): List of frames (BGR images) to save.
|
254 |
+
scene_dir (str): Target path where 'images/' subfolder will be created.
|
255 |
+
"""
|
256 |
+
images_dir = os.path.join(scene_dir, "images")
|
257 |
+
os.makedirs(images_dir, exist_ok=True)
|
258 |
+
|
259 |
+
for idx, frame in enumerate(frames):
|
260 |
+
filename = os.path.join(images_dir, f"{idx:08d}.png") # 00000000.png, 00000001.png, etc.
|
261 |
+
cv2.imwrite(filename, frame)
|
262 |
+
|
263 |
+
print(f"Saved {len(frames)} frames to {images_dir}")
|
264 |
+
|
265 |
+
|
266 |
+
def run_colmap_on_scene(scene_dir):
|
267 |
+
"""
|
268 |
+
Runs feature extraction, matching, and mapping on all images inside scene_dir/images using pycolmap.
|
269 |
+
|
270 |
+
Args:
|
271 |
+
scene_dir (str): Path to scene directory containing 'images' folder.
|
272 |
+
|
273 |
+
TODO: if the function hasn't managed to match all the frames either increase image size,
|
274 |
+
increase number of features or just remove those frames from the folder scene_dir/images
|
275 |
+
"""
|
276 |
+
start_time = time.time()
|
277 |
+
print(f"Running COLMAP pipeline on all images inside {scene_dir}")
|
278 |
+
|
279 |
+
# Setup paths
|
280 |
+
database_path = os.path.join(scene_dir, "database.db")
|
281 |
+
sparse_path = os.path.join(scene_dir, "sparse")
|
282 |
+
image_dir = os.path.join(scene_dir, "images")
|
283 |
+
|
284 |
+
# Make sure output directories exist
|
285 |
+
os.makedirs(sparse_path, exist_ok=True)
|
286 |
+
|
287 |
+
# Step 1: Feature Extraction
|
288 |
+
pycolmap.extract_features(
|
289 |
+
database_path,
|
290 |
+
image_dir,
|
291 |
+
sift_options={
|
292 |
+
"max_num_features": 512 * 2,
|
293 |
+
"max_image_size": 512 * 1,
|
294 |
+
}
|
295 |
+
)
|
296 |
+
print(f"Finished feature extraction in {(time.time() - start_time):.2f}s.")
|
297 |
+
|
298 |
+
# Step 2: Feature Matching
|
299 |
+
pycolmap.match_exhaustive(database_path)
|
300 |
+
print(f"Finished feature matching in {(time.time() - start_time):.2f}s.")
|
301 |
+
|
302 |
+
# Step 3: Mapping
|
303 |
+
pipeline_options = pycolmap.IncrementalPipelineOptions()
|
304 |
+
pipeline_options.min_num_matches = 15
|
305 |
+
pipeline_options.multiple_models = True
|
306 |
+
pipeline_options.max_num_models = 50
|
307 |
+
pipeline_options.max_model_overlap = 20
|
308 |
+
pipeline_options.min_model_size = 10
|
309 |
+
pipeline_options.extract_colors = True
|
310 |
+
pipeline_options.num_threads = 8
|
311 |
+
pipeline_options.mapper.init_min_num_inliers = 30
|
312 |
+
pipeline_options.mapper.init_max_error = 8.0
|
313 |
+
pipeline_options.mapper.init_min_tri_angle = 5.0
|
314 |
+
|
315 |
+
reconstruction = pycolmap.incremental_mapping(
|
316 |
+
database_path=database_path,
|
317 |
+
image_path=image_dir,
|
318 |
+
output_path=sparse_path,
|
319 |
+
options=pipeline_options,
|
320 |
+
)
|
321 |
+
print(f"Finished incremental mapping in {(time.time() - start_time):.2f}s.")
|
322 |
+
|
323 |
+
# Step 4: Post-process Cameras to SIMPLE_PINHOLE
|
324 |
+
recon_path = os.path.join(sparse_path, "0")
|
325 |
+
reconstruction = pycolmap.Reconstruction(recon_path)
|
326 |
+
|
327 |
+
for cam in reconstruction.cameras.values():
|
328 |
+
cam.model = 'SIMPLE_PINHOLE'
|
329 |
+
cam.params = cam.params[:3] # Keep only [f, cx, cy]
|
330 |
+
|
331 |
+
reconstruction.write(recon_path)
|
332 |
+
|
333 |
+
print(f"Total pipeline time: {(time.time() - start_time):.2f}s.")
|
334 |
+
|
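Putting the pieces of this file together, a hedged end-to-end sketch; the paths are placeholders and run_colmap_on_scene needs pycolmap plus enough well-matched frames to actually succeed:

    frames = read_video_frames("input.mp4", k=3, max_size=1024)   # every 3rd frame, resized
    scores = preprocess_frames(frames)
    keep = select_optimal_frames(scores, k=16)
    scene_dir = "../outputs/my_scene"
    save_frames_to_scene_dir([frames[i] for i in keep], scene_dir)
    run_colmap_on_scene(scene_dir)   # writes database.db and sparse/0 inside scene_dir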
source/vggt_to_colmap.py
ADDED
@@ -0,0 +1,598 @@
1 |
+
# Reuse code taken from the implementation of atakan-topaloglu:
|
2 |
+
# https://github.com/atakan-topaloglu/vggt/blob/main/vggt_to_colmap.py
|
3 |
+
|
4 |
+
import os
|
5 |
+
import argparse
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
import glob
|
9 |
+
import struct
|
10 |
+
from scipy.spatial.transform import Rotation
|
11 |
+
import sys
|
12 |
+
from PIL import Image
|
13 |
+
import cv2
|
14 |
+
import requests
|
15 |
+
import tempfile
|
16 |
+
|
17 |
+
sys.path.append("submodules/vggt/")
|
18 |
+
|
19 |
+
from vggt.models.vggt import VGGT
|
20 |
+
from vggt.utils.load_fn import load_and_preprocess_images
|
21 |
+
from vggt.utils.pose_enc import pose_encoding_to_extri_intri
|
22 |
+
from vggt.utils.geometry import unproject_depth_map_to_point_map
|
23 |
+
|
24 |
+
def load_model(device=None):
|
25 |
+
"""Load and initialize the VGGT model."""
|
26 |
+
if device is None:
|
27 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
28 |
+
print(f"Using device: {device}")
|
29 |
+
|
30 |
+
model = VGGT.from_pretrained("facebook/VGGT-1B")
|
31 |
+
|
32 |
+
# model = VGGT()
|
33 |
+
# _URL = "https://huggingface.co/facebook/VGGT-1B/resolve/main/model.pt"
|
34 |
+
# model.load_state_dict(torch.hub.load_state_dict_from_url(_URL))
|
35 |
+
|
36 |
+
model.eval()
|
37 |
+
model = model.to(device)
|
38 |
+
return model, device
|
39 |
+
|
40 |
+
def process_images(image_dir, model, device):
|
41 |
+
"""Process images with VGGT and return predictions."""
|
42 |
+
image_names = glob.glob(os.path.join(image_dir, "*"))
|
43 |
+
image_names = sorted([f for f in image_names if f.lower().endswith(('.png', '.jpg', '.jpeg'))])
|
44 |
+
print(f"Found {len(image_names)} images")
|
45 |
+
|
46 |
+
if len(image_names) == 0:
|
47 |
+
raise ValueError(f"No images found in {image_dir}")
|
48 |
+
|
49 |
+
original_images = []
|
50 |
+
for img_path in image_names:
|
51 |
+
img = Image.open(img_path).convert('RGB')
|
52 |
+
original_images.append(np.array(img))
|
53 |
+
|
54 |
+
images = load_and_preprocess_images(image_names).to(device)
|
55 |
+
print(f"Preprocessed images shape: {images.shape}")
|
56 |
+
|
57 |
+
print("Running inference...")
|
58 |
+
dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16
|
59 |
+
|
60 |
+
with torch.no_grad():
|
61 |
+
with torch.cuda.amp.autocast(dtype=dtype):
|
62 |
+
predictions = model(images)
|
63 |
+
|
64 |
+
print("Converting pose encoding to camera parameters...")
|
65 |
+
extrinsic, intrinsic = pose_encoding_to_extri_intri(predictions["pose_enc"], images.shape[-2:])
|
66 |
+
predictions["extrinsic"] = extrinsic
|
67 |
+
predictions["intrinsic"] = intrinsic
|
68 |
+
|
69 |
+
for key in predictions.keys():
|
70 |
+
if isinstance(predictions[key], torch.Tensor):
|
71 |
+
predictions[key] = predictions[key].cpu().numpy().squeeze(0) # remove batch dimension
|
72 |
+
|
73 |
+
print("Computing 3D points from depth maps...")
|
74 |
+
depth_map = predictions["depth"] # (S, H, W, 1)
|
75 |
+
world_points = unproject_depth_map_to_point_map(depth_map, predictions["extrinsic"], predictions["intrinsic"])
|
76 |
+
predictions["world_points_from_depth"] = world_points
|
77 |
+
|
78 |
+
predictions["original_images"] = original_images
|
79 |
+
|
80 |
+
S, H, W = world_points.shape[:3]
|
81 |
+
normalized_images = np.zeros((S, H, W, 3), dtype=np.float32)
|
82 |
+
|
83 |
+
for i, img in enumerate(original_images):
|
84 |
+
resized_img = cv2.resize(img, (W, H))
|
85 |
+
normalized_images[i] = resized_img / 255.0
|
86 |
+
|
87 |
+
predictions["images"] = normalized_images
|
88 |
+
|
89 |
+
return predictions, image_names
|
90 |
+
|
91 |
+
def extrinsic_to_colmap_format(extrinsics):
|
92 |
+
"""Convert extrinsic matrices to COLMAP format (quaternion + translation)."""
|
93 |
+
num_cameras = extrinsics.shape[0]
|
94 |
+
quaternions = []
|
95 |
+
translations = []
|
96 |
+
|
97 |
+
for i in range(num_cameras):
|
98 |
+
# VGGT's extrinsic is camera-to-world (R|t) format
|
99 |
+
R = extrinsics[i, :3, :3]
|
100 |
+
t = extrinsics[i, :3, 3]
|
101 |
+
|
102 |
+
# Convert rotation matrix to quaternion
|
103 |
+
# COLMAP quaternion format is [qw, qx, qy, qz]
|
104 |
+
rot = Rotation.from_matrix(R)
|
105 |
+
quat = rot.as_quat() # scipy returns [x, y, z, w]
|
106 |
+
quat = np.array([quat[3], quat[0], quat[1], quat[2]]) # Convert to [w, x, y, z]
|
107 |
+
|
108 |
+
quaternions.append(quat)
|
109 |
+
translations.append(t)
|
110 |
+
|
111 |
+
return np.array(quaternions), np.array(translations)
|
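The conversion above reorders scipy's [x, y, z, w] quaternion into COLMAP's [w, x, y, z]. A small sanity check of that reordering on an identity pose, not part of the original script:

    import numpy as np
    from scipy.spatial.transform import Rotation

    R = np.eye(3)                                          # identity rotation
    quat_xyzw = Rotation.from_matrix(R).as_quat()          # [0, 0, 0, 1]
    quat_wxyz = np.array([quat_xyzw[3], *quat_xyzw[:3]])   # [1, 0, 0, 0], as COLMAP expects
    print(quat_wxyz)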
112 |
+
|
113 |
+
def download_file_from_url(url, filename):
|
114 |
+
"""Downloads a file from a URL, handling redirects."""
|
115 |
+
try:
|
116 |
+
response = requests.get(url, allow_redirects=False)
|
117 |
+
response.raise_for_status()
|
118 |
+
|
119 |
+
if response.status_code == 302:
|
120 |
+
redirect_url = response.headers["Location"]
|
121 |
+
response = requests.get(redirect_url, stream=True)
|
122 |
+
response.raise_for_status()
|
123 |
+
else:
|
124 |
+
response = requests.get(url, stream=True)
|
125 |
+
response.raise_for_status()
|
126 |
+
|
127 |
+
with open(filename, "wb") as f:
|
128 |
+
for chunk in response.iter_content(chunk_size=8192):
|
129 |
+
f.write(chunk)
|
130 |
+
print(f"Downloaded {filename} successfully.")
|
131 |
+
return True
|
132 |
+
|
133 |
+
except requests.exceptions.RequestException as e:
|
134 |
+
print(f"Error downloading file: {e}")
|
135 |
+
return False
|
136 |
+
|
137 |
+
def segment_sky(image_path, onnx_session, mask_filename=None):
|
138 |
+
"""
|
139 |
+
Segments sky from an image using an ONNX model.
|
140 |
+
"""
|
141 |
+
image = cv2.imread(image_path)
|
142 |
+
|
143 |
+
result_map = run_skyseg(onnx_session, [320, 320], image)
|
144 |
+
result_map_original = cv2.resize(result_map, (image.shape[1], image.shape[0]))
|
145 |
+
|
146 |
+
# Fix: Invert the mask so that 255 = non-sky, 0 = sky
|
147 |
+
# The model outputs low values for sky, high values for non-sky
|
148 |
+
output_mask = np.zeros_like(result_map_original)
|
149 |
+
output_mask[result_map_original < 32] = 255 # Use threshold of 32
|
150 |
+
|
151 |
+
if mask_filename is not None:
|
152 |
+
os.makedirs(os.path.dirname(mask_filename), exist_ok=True)
|
153 |
+
cv2.imwrite(mask_filename, output_mask)
|
154 |
+
|
155 |
+
return output_mask
|
156 |
+
|
157 |
+
def run_skyseg(onnx_session, input_size, image):
|
158 |
+
"""
|
159 |
+
Runs sky segmentation inference using ONNX model.
|
160 |
+
"""
|
161 |
+
import copy
|
162 |
+
|
163 |
+
temp_image = copy.deepcopy(image)
|
164 |
+
resize_image = cv2.resize(temp_image, dsize=(input_size[0], input_size[1]))
|
165 |
+
x = cv2.cvtColor(resize_image, cv2.COLOR_BGR2RGB)
|
166 |
+
x = np.array(x, dtype=np.float32)
|
167 |
+
mean = [0.485, 0.456, 0.406]
|
168 |
+
std = [0.229, 0.224, 0.225]
|
169 |
+
x = (x / 255 - mean) / std
|
170 |
+
x = x.transpose(2, 0, 1)
|
171 |
+
x = x.reshape(-1, 3, input_size[0], input_size[1]).astype("float32")
|
172 |
+
|
173 |
+
input_name = onnx_session.get_inputs()[0].name
|
174 |
+
output_name = onnx_session.get_outputs()[0].name
|
175 |
+
onnx_result = onnx_session.run([output_name], {input_name: x})
|
176 |
+
|
177 |
+
onnx_result = np.array(onnx_result).squeeze()
|
178 |
+
min_value = np.min(onnx_result)
|
179 |
+
max_value = np.max(onnx_result)
|
180 |
+
onnx_result = (onnx_result - min_value) / (max_value - min_value)
|
181 |
+
onnx_result *= 255
|
182 |
+
onnx_result = onnx_result.astype("uint8")
|
183 |
+
|
184 |
+
return onnx_result
|
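The preprocessing in run_skyseg is standard ImageNet normalization followed by an HWC→NCHW reshape; a standalone sketch of just that step, with a random array standing in for a real photo:

    import numpy as np

    rgb = np.random.randint(0, 256, (320, 320, 3), dtype=np.uint8)   # stand-in RGB image
    x = rgb.astype(np.float32) / 255.0
    x = (x - [0.485, 0.456, 0.406]) / [0.229, 0.224, 0.225]          # ImageNet mean / std
    x = x.transpose(2, 0, 1)[None].astype(np.float32)                # shape (1, 3, 320, 320)
    print(x.shape)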
185 |
+
|
186 |
+
def filter_and_prepare_points(predictions, conf_threshold, mask_sky=False, mask_black_bg=False,
|
187 |
+
mask_white_bg=False, stride=1, prediction_mode="Depthmap and Camera Branch"):
|
188 |
+
"""
|
189 |
+
Filter points based on confidence and prepare for COLMAP format.
|
190 |
+
Implementation matches the conventions in the original VGGT code.
|
191 |
+
"""
|
192 |
+
|
193 |
+
if "Pointmap" in prediction_mode:
|
194 |
+
print("Using Pointmap Branch")
|
195 |
+
if "world_points" in predictions:
|
196 |
+
pred_world_points = predictions["world_points"]
|
197 |
+
pred_world_points_conf = predictions.get("world_points_conf", np.ones_like(pred_world_points[..., 0]))
|
198 |
+
else:
|
199 |
+
print("Warning: world_points not found in predictions, falling back to depth-based points")
|
200 |
+
pred_world_points = predictions["world_points_from_depth"]
|
201 |
+
pred_world_points_conf = predictions.get("depth_conf", np.ones_like(pred_world_points[..., 0]))
|
202 |
+
else:
|
203 |
+
print("Using Depthmap and Camera Branch")
|
204 |
+
pred_world_points = predictions["world_points_from_depth"]
|
205 |
+
pred_world_points_conf = predictions.get("depth_conf", np.ones_like(pred_world_points[..., 0]))
|
206 |
+
|
207 |
+
colors_rgb = predictions["images"]
|
208 |
+
|
209 |
+
S, H, W = pred_world_points.shape[:3]
|
210 |
+
if colors_rgb.shape[:3] != (S, H, W):
|
211 |
+
print(f"Reshaping colors_rgb from {colors_rgb.shape} to match {(S, H, W, 3)}")
|
212 |
+
reshaped_colors = np.zeros((S, H, W, 3), dtype=np.float32)
|
213 |
+
for i in range(S):
|
214 |
+
if i < len(colors_rgb):
|
215 |
+
reshaped_colors[i] = cv2.resize(colors_rgb[i], (W, H))
|
216 |
+
colors_rgb = reshaped_colors
|
217 |
+
|
218 |
+
colors_rgb = (colors_rgb * 255).astype(np.uint8)
|
219 |
+
|
220 |
+
if mask_sky:
|
221 |
+
print("Applying sky segmentation mask")
|
222 |
+
try:
|
223 |
+
import onnxruntime
|
224 |
+
|
225 |
+
with tempfile.TemporaryDirectory() as temp_dir:
|
226 |
+
print(f"Created temporary directory for sky segmentation: {temp_dir}")
|
227 |
+
temp_images_dir = os.path.join(temp_dir, "images")
|
228 |
+
sky_masks_dir = os.path.join(temp_dir, "sky_masks")
|
229 |
+
os.makedirs(temp_images_dir, exist_ok=True)
|
230 |
+
os.makedirs(sky_masks_dir, exist_ok=True)
|
231 |
+
|
232 |
+
image_list = []
|
233 |
+
for i, img in enumerate(colors_rgb):
|
234 |
+
img_path = os.path.join(temp_images_dir, f"image_{i:04d}.png")
|
235 |
+
image_list.append(img_path)
|
236 |
+
cv2.imwrite(img_path, cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
|
237 |
+
|
238 |
+
|
239 |
+
skyseg_path = os.path.join(temp_dir, "skyseg.onnx")
|
240 |
+
if not os.path.exists("skyseg.onnx"):
|
241 |
+
print("Downloading skyseg.onnx...")
|
242 |
+
download_success = download_file_from_url(
|
243 |
+
"https://huggingface.co/JianyuanWang/skyseg/resolve/main/skyseg.onnx",
|
244 |
+
skyseg_path
|
245 |
+
)
|
246 |
+
if not download_success:
|
247 |
+
print("Failed to download skyseg model, skipping sky filtering")
|
248 |
+
mask_sky = False
|
249 |
+
else:
|
250 |
+
|
251 |
+
import shutil
|
252 |
+
shutil.copy("skyseg.onnx", skyseg_path)
|
253 |
+
|
254 |
+
if mask_sky:
|
255 |
+
skyseg_session = onnxruntime.InferenceSession(skyseg_path)
|
256 |
+
sky_mask_list = []
|
257 |
+
|
258 |
+
for img_path in image_list:
|
259 |
+
mask_path = os.path.join(sky_masks_dir, os.path.basename(img_path))
|
260 |
+
sky_mask = segment_sky(img_path, skyseg_session, mask_path)
|
261 |
+
|
262 |
+
if sky_mask.shape[0] != H or sky_mask.shape[1] != W:
|
263 |
+
sky_mask = cv2.resize(sky_mask, (W, H))
|
264 |
+
|
265 |
+
sky_mask_list.append(sky_mask)
|
266 |
+
|
267 |
+
sky_mask_array = np.array(sky_mask_list)
|
268 |
+
|
269 |
+
sky_mask_binary = (sky_mask_array > 0.1).astype(np.float32)
|
270 |
+
pred_world_points_conf = pred_world_points_conf * sky_mask_binary
|
271 |
+
print(f"Applied sky mask, shape: {sky_mask_binary.shape}")
|
272 |
+
|
273 |
+
except Exception as e:  # Exception already covers ImportError (e.g., onnxruntime missing)
|
274 |
+
print(f"Error in sky segmentation: {e}")
|
275 |
+
mask_sky = False
|
276 |
+
|
277 |
+
vertices_3d = pred_world_points.reshape(-1, 3)
|
278 |
+
conf = pred_world_points_conf.reshape(-1)
|
279 |
+
colors_rgb_flat = colors_rgb.reshape(-1, 3)
|
280 |
+
|
281 |
+
|
282 |
+
|
283 |
+
if len(conf) != len(colors_rgb_flat):
|
284 |
+
print(f"WARNING: Shape mismatch between confidence ({len(conf)}) and colors ({len(colors_rgb_flat)})")
|
285 |
+
min_size = min(len(conf), len(colors_rgb_flat))
|
286 |
+
conf = conf[:min_size]
|
287 |
+
vertices_3d = vertices_3d[:min_size]
|
288 |
+
colors_rgb_flat = colors_rgb_flat[:min_size]
|
289 |
+
|
290 |
+
if conf_threshold == 0.0:
|
291 |
+
conf_thres_value = 0.0
|
292 |
+
else:
|
293 |
+
conf_thres_value = np.percentile(conf, conf_threshold)
|
294 |
+
|
295 |
+
print(f"Using confidence threshold: {conf_threshold}% (value: {conf_thres_value:.4f})")
|
296 |
+
conf_mask = (conf >= conf_thres_value) & (conf > 1e-5)
|
297 |
+
|
298 |
+
if mask_black_bg:
|
299 |
+
print("Filtering black background")
|
300 |
+
black_bg_mask = colors_rgb_flat.sum(axis=1) >= 16
|
301 |
+
conf_mask = conf_mask & black_bg_mask
|
302 |
+
|
303 |
+
if mask_white_bg:
|
304 |
+
print("Filtering white background")
|
305 |
+
white_bg_mask = ~((colors_rgb_flat[:, 0] > 240) & (colors_rgb_flat[:, 1] > 240) & (colors_rgb_flat[:, 2] > 240))
|
306 |
+
conf_mask = conf_mask & white_bg_mask
|
307 |
+
|
308 |
+
filtered_vertices = vertices_3d[conf_mask]
|
309 |
+
filtered_colors = colors_rgb_flat[conf_mask]
|
310 |
+
|
311 |
+
if len(filtered_vertices) == 0:
|
312 |
+
print("Warning: No points remaining after filtering. Using default point.")
|
313 |
+
filtered_vertices = np.array([[0, 0, 0]])
|
314 |
+
filtered_colors = np.array([[200, 200, 200]])
|
315 |
+
|
316 |
+
print(f"Filtered to {len(filtered_vertices)} points")
|
317 |
+
|
318 |
+
points3D = []
|
319 |
+
point_indices = {}
|
320 |
+
image_points2D = [[] for _ in range(len(pred_world_points))]
|
321 |
+
|
322 |
+
print(f"Preparing points for COLMAP format with stride {stride}...")
|
323 |
+
|
324 |
+
total_points = 0
|
325 |
+
for img_idx in range(S):
|
326 |
+
for y in range(0, H, stride):
|
327 |
+
for x in range(0, W, stride):
|
328 |
+
flat_idx = img_idx * H * W + y * W + x
|
329 |
+
|
330 |
+
if flat_idx >= len(conf):
|
331 |
+
continue
|
332 |
+
|
333 |
+
if conf[flat_idx] < conf_thres_value or conf[flat_idx] <= 1e-5:
|
334 |
+
continue
|
335 |
+
|
336 |
+
if mask_black_bg and colors_rgb_flat[flat_idx].sum() < 16:
|
337 |
+
continue
|
338 |
+
|
339 |
+
if mask_white_bg and all(colors_rgb_flat[flat_idx] > 240):
|
340 |
+
continue
|
341 |
+
|
342 |
+
point3D = vertices_3d[flat_idx]
|
343 |
+
rgb = colors_rgb_flat[flat_idx]
|
344 |
+
|
345 |
+
if not np.all(np.isfinite(point3D)):
|
346 |
+
continue
|
347 |
+
|
348 |
+
point_hash = hash_point(point3D, scale=100)
|
349 |
+
|
350 |
+
if point_hash not in point_indices:
|
351 |
+
point_idx = len(points3D)
|
352 |
+
point_indices[point_hash] = point_idx
|
353 |
+
|
354 |
+
point_entry = {
|
355 |
+
"id": point_idx,
|
356 |
+
"xyz": point3D,
|
357 |
+
"rgb": rgb,
|
358 |
+
"error": 1.0,
|
359 |
+
"track": [(img_idx, len(image_points2D[img_idx]))]
|
360 |
+
}
|
361 |
+
points3D.append(point_entry)
|
362 |
+
total_points += 1
|
363 |
+
else:
|
364 |
+
point_idx = point_indices[point_hash]
|
365 |
+
points3D[point_idx]["track"].append((img_idx, len(image_points2D[img_idx])))
|
366 |
+
|
367 |
+
image_points2D[img_idx].append((x, y, point_indices[point_hash]))
|
368 |
+
|
369 |
+
print(f"Prepared {len(points3D)} 3D points with {sum(len(pts) for pts in image_points2D)} observations for COLMAP")
|
370 |
+
return points3D, image_points2D
|
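Note that conf_threshold is interpreted as a percentile of the confidence distribution, not an absolute cutoff; a two-line illustration on synthetic confidences:

    import numpy as np

    conf = np.random.rand(10_000)               # synthetic confidence scores
    thr = np.percentile(conf, 50.0)             # conf_threshold=50 keeps roughly the top half
    print(round((conf >= thr).mean(), 2))       # ~0.5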
371 |
+
|
372 |
+
def hash_point(point, scale=100):
|
373 |
+
"""Create a hash for a 3D point by quantizing coordinates."""
|
374 |
+
quantized = tuple(np.round(point * scale).astype(int))
|
375 |
+
return hash(quantized)
|
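hash_point merges points that agree once quantized to 1/scale units, so nearby samples from different views collapse into a single COLMAP track; a tiny check using the function above:

    import numpy as np

    p_a = np.array([0.101, 0.200, 0.300])
    p_b = np.array([0.099, 0.201, 0.300])
    print(hash_point(p_a) == hash_point(p_b))   # True: both quantize to (10, 20, 30) at scale=100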
376 |
+
|
377 |
+
def write_colmap_cameras_txt(file_path, intrinsics, image_width, image_height):
|
378 |
+
"""Write camera intrinsics to COLMAP cameras.txt format."""
|
379 |
+
with open(file_path, 'w') as f:
|
380 |
+
f.write("# Camera list with one line of data per camera:\n")
|
381 |
+
f.write("# CAMERA_ID, MODEL, WIDTH, HEIGHT, PARAMS[]\n")
|
382 |
+
f.write(f"# Number of cameras: {len(intrinsics)}\n")
|
383 |
+
|
384 |
+
for i, intrinsic in enumerate(intrinsics):
|
385 |
+
camera_id = i + 1 # COLMAP uses 1-indexed camera IDs
|
386 |
+
model = "PINHOLE"
|
387 |
+
|
388 |
+
fx = intrinsic[0, 0]
|
389 |
+
fy = intrinsic[1, 1]
|
390 |
+
cx = intrinsic[0, 2]
|
391 |
+
cy = intrinsic[1, 2]
|
392 |
+
|
393 |
+
f.write(f"{camera_id} {model} {image_width} {image_height} {fx} {fy} {cx} {cy}\n")
|
394 |
+
|
395 |
+
def write_colmap_images_txt(file_path, quaternions, translations, image_points2D, image_names):
|
396 |
+
"""Write camera poses and keypoints to COLMAP images.txt format."""
|
397 |
+
with open(file_path, 'w') as f:
|
398 |
+
f.write("# Image list with two lines of data per image:\n")
|
399 |
+
f.write("# IMAGE_ID, QW, QX, QY, QZ, TX, TY, TZ, CAMERA_ID, NAME\n")
|
400 |
+
f.write("# POINTS2D[] as (X, Y, POINT3D_ID)\n")
|
401 |
+
|
402 |
+
num_points = sum(len(points) for points in image_points2D)
|
403 |
+
avg_points = num_points / len(image_points2D) if image_points2D else 0
|
404 |
+
f.write(f"# Number of images: {len(quaternions)}, mean observations per image: {avg_points:.1f}\n")
|
405 |
+
|
406 |
+
for i in range(len(quaternions)):
|
407 |
+
image_id = i + 1
|
408 |
+
camera_id = i + 1
|
409 |
+
|
410 |
+
qw, qx, qy, qz = quaternions[i]
|
411 |
+
tx, ty, tz = translations[i]
|
412 |
+
|
413 |
+
f.write(f"{image_id} {qw} {qx} {qy} {qz} {tx} {ty} {tz} {camera_id} {os.path.basename(image_names[i])}\n")
|
414 |
+
|
415 |
+
points_line = " ".join([f"{x} {y} {point3d_id+1}" for x, y, point3d_id in image_points2D[i]])
|
416 |
+
f.write(f"{points_line}\n")
|
417 |
+
|
418 |
+
def write_colmap_points3D_txt(file_path, points3D):
|
419 |
+
"""Write 3D points and tracks to COLMAP points3D.txt format."""
|
420 |
+
with open(file_path, 'w') as f:
|
421 |
+
f.write("# 3D point list with one line of data per point:\n")
|
422 |
+
f.write("# POINT3D_ID, X, Y, Z, R, G, B, ERROR, TRACK[] as (IMAGE_ID, POINT2D_IDX)\n")
|
423 |
+
|
424 |
+
avg_track_length = sum(len(point["track"]) for point in points3D) / len(points3D) if points3D else 0
|
425 |
+
f.write(f"# Number of points: {len(points3D)}, mean track length: {avg_track_length:.4f}\n")
|
426 |
+
|
427 |
+
for point in points3D:
|
428 |
+
point_id = point["id"] + 1
|
429 |
+
x, y, z = point["xyz"]
|
430 |
+
r, g, b = point["rgb"]
|
431 |
+
error = point["error"]
|
432 |
+
|
433 |
+
track = " ".join([f"{img_id+1} {point2d_idx}" for img_id, point2d_idx in point["track"]])
|
434 |
+
|
435 |
+
f.write(f"{point_id} {x} {y} {z} {int(r)} {int(g)} {int(b)} {error} {track}\n")
|
436 |
+
|
437 |
+
def write_colmap_cameras_bin(file_path, intrinsics, image_width, image_height):
|
438 |
+
"""Write camera intrinsics to COLMAP cameras.bin format."""
|
439 |
+
with open(file_path, 'wb') as fid:
|
440 |
+
# Write number of cameras (uint64)
|
441 |
+
fid.write(struct.pack('<Q', len(intrinsics)))
|
442 |
+
|
443 |
+
for i, intrinsic in enumerate(intrinsics):
|
444 |
+
camera_id = i + 1
|
445 |
+
model_id = 1
|
446 |
+
|
447 |
+
fx = float(intrinsic[0, 0])
|
448 |
+
fy = float(intrinsic[1, 1])
|
449 |
+
cx = float(intrinsic[0, 2])
|
450 |
+
cy = float(intrinsic[1, 2])
|
451 |
+
|
452 |
+
# Camera ID (uint32)
|
453 |
+
fid.write(struct.pack('<I', camera_id))
|
454 |
+
# Model ID (uint32)
|
455 |
+
fid.write(struct.pack('<I', model_id))
|
456 |
+
# Width (uint64)
|
457 |
+
fid.write(struct.pack('<Q', image_width))
|
458 |
+
# Height (uint64)
|
459 |
+
fid.write(struct.pack('<Q', image_height))
|
460 |
+
|
461 |
+
# Parameters (double)
|
462 |
+
fid.write(struct.pack('<dddd', fx, fy, cx, cy))
|
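For reference, the per-camera record produced by the packs above can be sized with struct.calcsize (a throwaway check, not part of the writer):

    import struct

    # uint32 camera_id + uint32 model_id + uint64 width + uint64 height + 4 double PINHOLE params
    print(struct.calcsize('<IIQQdddd'))   # 56 bytes per camera, after the single 8-byte camera-count header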
463 |
+
|
464 |
+
def write_colmap_images_bin(file_path, quaternions, translations, image_points2D, image_names):
|
465 |
+
"""Write camera poses and keypoints to COLMAP images.bin format."""
|
466 |
+
with open(file_path, 'wb') as fid:
|
467 |
+
# Write number of images (uint64)
|
468 |
+
fid.write(struct.pack('<Q', len(quaternions)))
|
469 |
+
|
470 |
+
for i in range(len(quaternions)):
|
471 |
+
image_id = i + 1
|
472 |
+
camera_id = i + 1
|
473 |
+
|
474 |
+
qw, qx, qy, qz = quaternions[i].astype(float)
|
475 |
+
tx, ty, tz = translations[i].astype(float)
|
476 |
+
|
477 |
+
image_name = os.path.basename(image_names[i]).encode()
|
478 |
+
points = image_points2D[i]
|
479 |
+
|
480 |
+
# Image ID (uint32)
|
481 |
+
fid.write(struct.pack('<I', image_id))
|
482 |
+
# Quaternion (double): qw, qx, qy, qz
|
483 |
+
fid.write(struct.pack('<dddd', qw, qx, qy, qz))
|
484 |
+
# Translation (double): tx, ty, tz
|
485 |
+
fid.write(struct.pack('<ddd', tx, ty, tz))
|
486 |
+
# Camera ID (uint32)
|
487 |
+
fid.write(struct.pack('<I', camera_id))
|
488 |
+
# Image name
|
489 |
+
fid.write(image_name + b"\x00")  # COLMAP reads the name as a null-terminated string, not length-prefixed
|
491 |
+
|
492 |
+
# Write number of 2D points (uint64)
|
493 |
+
fid.write(struct.pack('<Q', len(points)))
|
494 |
+
|
495 |
+
# Write 2D points: x, y, point3D_id
|
496 |
+
for x, y, point3d_id in points:
|
497 |
+
fid.write(struct.pack('<dd', float(x), float(y)))
|
498 |
+
fid.write(struct.pack('<Q', point3d_id + 1))
|
499 |
+
|
500 |
+
def write_colmap_points3D_bin(file_path, points3D):
|
501 |
+
"""Write 3D points and tracks to COLMAP points3D.bin format."""
|
502 |
+
with open(file_path, 'wb') as fid:
|
503 |
+
# Write number of points (uint64)
|
504 |
+
fid.write(struct.pack('<Q', len(points3D)))
|
505 |
+
|
506 |
+
for point in points3D:
|
507 |
+
point_id = point["id"] + 1
|
508 |
+
x, y, z = point["xyz"].astype(float)
|
509 |
+
r, g, b = point["rgb"].astype(np.uint8)
|
510 |
+
error = float(point["error"])
|
511 |
+
track = point["track"]
|
512 |
+
|
513 |
+
# Point ID (uint64)
|
514 |
+
fid.write(struct.pack('<Q', point_id))
|
515 |
+
# Position (double): x, y, z
|
516 |
+
fid.write(struct.pack('<ddd', x, y, z))
|
517 |
+
# Color (uint8): r, g, b
|
518 |
+
fid.write(struct.pack('<BBB', int(r), int(g), int(b)))
|
519 |
+
# Error (double)
|
520 |
+
fid.write(struct.pack('<d', error))
|
521 |
+
|
522 |
+
# Track: list of (image_id, point2D_idx)
|
523 |
+
fid.write(struct.pack('<Q', len(track)))
|
524 |
+
for img_id, point2d_idx in track:
|
525 |
+
fid.write(struct.pack('<II', img_id + 1, point2d_idx))
|
526 |
+
|
527 |
+
def main():
|
528 |
+
parser = argparse.ArgumentParser(description="Convert images to COLMAP format using VGGT")
|
529 |
+
parser.add_argument("--image_dir", type=str, required=True,
|
530 |
+
help="Directory containing input images")
|
531 |
+
parser.add_argument("--output_dir", type=str, default="colmap_output",
|
532 |
+
help="Directory to save COLMAP files")
|
533 |
+
parser.add_argument("--conf_threshold", type=float, default=50.0,
|
534 |
+
help="Confidence threshold (0-100%) for including points")
|
535 |
+
parser.add_argument("--mask_sky", action="store_true",
|
536 |
+
help="Filter out points likely to be sky")
|
537 |
+
parser.add_argument("--mask_black_bg", action="store_true",
|
538 |
+
help="Filter out points with very dark/black color")
|
539 |
+
parser.add_argument("--mask_white_bg", action="store_true",
|
540 |
+
help="Filter out points with very bright/white color")
|
541 |
+
parser.add_argument("--binary", action="store_true",
|
542 |
+
help="Output binary COLMAP files instead of text")
|
543 |
+
parser.add_argument("--stride", type=int, default=1,
|
544 |
+
help="Stride for point sampling (higher = fewer points)")
|
545 |
+
parser.add_argument("--prediction_mode", type=str, default="Depthmap and Camera Branch",
|
546 |
+
choices=["Depthmap and Camera Branch", "Pointmap Branch"],
|
547 |
+
help="Which prediction branch to use")
|
548 |
+
|
549 |
+
args = parser.parse_args()
|
550 |
+
|
551 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
552 |
+
|
553 |
+
model, device = load_model()
|
554 |
+
|
555 |
+
predictions, image_names = process_images(args.image_dir, model, device)
|
556 |
+
|
557 |
+
print("Converting camera parameters to COLMAP format...")
|
558 |
+
quaternions, translations = extrinsic_to_colmap_format(predictions["extrinsic"])
|
559 |
+
|
560 |
+
print(f"Filtering points with confidence threshold {args.conf_threshold}% and stride {args.stride}...")
|
561 |
+
points3D, image_points2D = filter_and_prepare_points(
|
562 |
+
predictions,
|
563 |
+
args.conf_threshold,
|
564 |
+
mask_sky=args.mask_sky,
|
565 |
+
mask_black_bg=args.mask_black_bg,
|
566 |
+
mask_white_bg=args.mask_white_bg,
|
567 |
+
stride=args.stride,
|
568 |
+
prediction_mode=args.prediction_mode
|
569 |
+
)
|
570 |
+
|
571 |
+
height, width = predictions["depth"].shape[1:3]
|
572 |
+
|
573 |
+
print(f"Writing {'binary' if args.binary else 'text'} COLMAP files to {args.output_dir}...")
|
574 |
+
if args.binary:
|
575 |
+
write_colmap_cameras_bin(
|
576 |
+
os.path.join(args.output_dir, "cameras.bin"),
|
577 |
+
predictions["intrinsic"], width, height)
|
578 |
+
write_colmap_images_bin(
|
579 |
+
os.path.join(args.output_dir, "images.bin"),
|
580 |
+
quaternions, translations, image_points2D, image_names)
|
581 |
+
write_colmap_points3D_bin(
|
582 |
+
os.path.join(args.output_dir, "points3D.bin"),
|
583 |
+
points3D)
|
584 |
+
else:
|
585 |
+
write_colmap_cameras_txt(
|
586 |
+
os.path.join(args.output_dir, "cameras.txt"),
|
587 |
+
predictions["intrinsic"], width, height)
|
588 |
+
write_colmap_images_txt(
|
589 |
+
os.path.join(args.output_dir, "images.txt"),
|
590 |
+
quaternions, translations, image_points2D, image_names)
|
591 |
+
write_colmap_points3D_txt(
|
592 |
+
os.path.join(args.output_dir, "points3D.txt"),
|
593 |
+
points3D)
|
594 |
+
|
595 |
+
print(f"COLMAP files successfully written to {args.output_dir}")
|
596 |
+
|
597 |
+
if __name__ == "__main__":
|
598 |
+
main()
|
source/visualization.py
ADDED
@@ -0,0 +1,1072 @@
1 |
+
from matplotlib import pyplot as plt
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
from typing import List
|
7 |
+
import sys
|
8 |
+
sys.path.append('./submodules/gaussian-splatting/')
|
9 |
+
from scene.cameras import Camera
|
10 |
+
from PIL import Image
|
11 |
+
import imageio
|
12 |
+
from scipy.interpolate import splprep, splev
|
13 |
+
|
14 |
+
import cv2
|
15 |
+
import numpy as np
|
16 |
+
import plotly.graph_objects as go
|
17 |
+
import numpy as np
|
18 |
+
from scipy.spatial.transform import Rotation as R, Slerp
|
19 |
+
from scipy.spatial import distance_matrix
|
20 |
+
from sklearn.decomposition import PCA
|
21 |
+
from scipy.interpolate import splprep, splev
|
22 |
+
from typing import List
|
23 |
+
from sklearn.mixture import GaussianMixture
|
24 |
+
|
25 |
+
def render_gaussians_rgb(generator3DGS, viewpoint_cam, visualize=False):
|
26 |
+
"""
|
27 |
+
Simply render gaussians from the generator3DGS from the viewpoint_cam.
|
28 |
+
Args:
|
29 |
+
generator3DGS : instance of the Generator3DGS class from the networks.py file
|
30 |
+
viewpoint_cam : camera instance
|
31 |
+
visualize : boolean flag. If True, will call pyplot function and render image inplace
|
32 |
+
Returns:
|
33 |
+
uint8 numpy array with shape (H, W, 3) representing the image
|
34 |
+
"""
|
35 |
+
with torch.no_grad():
|
36 |
+
render_pkg = generator3DGS(viewpoint_cam)
|
37 |
+
image = render_pkg["render"]
|
38 |
+
image_np = image.clone().detach().cpu().numpy().transpose(1, 2, 0)
|
39 |
+
|
40 |
+
# Clip values to be in the range [0, 1]
|
41 |
+
image_np = np.clip(image_np * 255, 0, 255).astype(np.uint8)
|
42 |
+
if visualize:
|
43 |
+
plt.figure(figsize=(12, 8))
|
44 |
+
plt.imshow(image_np)
|
45 |
+
plt.show()
|
46 |
+
|
47 |
+
return image_np
|
48 |
+
|
49 |
+
def render_gaussians_D_scores(generator3DGS, viewpoint_cam, mask=None, mask_channel=0, visualize=False):
|
50 |
+
"""
|
51 |
+
Simply render D_scores of gaussians from the generator3DGS from the viewpoint_cam.
|
52 |
+
Args:
|
53 |
+
generator3DGS : instance of the Generator3DGS class from the networks.py file
|
54 |
+
viewpoint_cam : camera instance
|
55 |
+
visualize : boolean flag. If True, will call pyplot function and render image inplace
|
56 |
+
mask : optional mask to highlight specific gaussians. Must be of shape (N) where N is the number
|
57 |
+
of gaussians in generator3DGS.gaussians. Must be a torch tensor of floats, please scale according
|
58 |
+
to how much color you want to have. Recommended mask value is 10.
|
59 |
+
mask_channel: to which color channel should we add mask
|
60 |
+
Returns:
|
61 |
+
uint8 numpy array with shape (H, W, 3) representing the generator3DGS.gaussians.D_scores rendered as colors
|
62 |
+
"""
|
63 |
+
with torch.no_grad():
|
64 |
+
# Visualize D_scores
|
65 |
+
generator3DGS.gaussians._features_dc = generator3DGS.gaussians._features_dc * 1e-4 + \
|
66 |
+
torch.stack([generator3DGS.gaussians.D_scores] * 3, axis=-1)
|
67 |
+
generator3DGS.gaussians._features_rest = generator3DGS.gaussians._features_rest * 1e-4
|
68 |
+
if mask is not None:
|
69 |
+
generator3DGS.gaussians._features_dc[..., mask_channel] += mask.unsqueeze(-1)
|
70 |
+
render_pkg = generator3DGS(viewpoint_cam)
|
71 |
+
image = render_pkg["render"]
|
72 |
+
image_np = image.clone().detach().cpu().numpy().transpose(1, 2, 0)
|
73 |
+
|
74 |
+
# Clip values to be in the range [0, 1]
|
75 |
+
image_np = np.clip(image_np * 255, 0, 255).astype(np.uint8)
|
76 |
+
if visualize:
|
77 |
+
plt.figure(figsize=(12, 8))
|
78 |
+
plt.imshow(image_np)
|
79 |
+
plt.show()
|
80 |
+
|
81 |
+
if mask is not None:
|
82 |
+
generator3DGS.gaussians._features_dc[..., mask_channel] -= mask.unsqueeze(-1)
|
83 |
+
|
84 |
+
generator3DGS.gaussians._features_dc = (generator3DGS.gaussians._features_dc - \
|
85 |
+
torch.stack([generator3DGS.gaussians.D_scores] * 3, axis=-1)) * 1e4
|
86 |
+
generator3DGS.gaussians._features_rest = generator3DGS.gaussians._features_rest * 1e4
|
87 |
+
|
88 |
+
return image_np
|
89 |
+
|
90 |
+
|
91 |
+
|
92 |
+
def normalize(v):
|
93 |
+
"""
|
94 |
+
Normalize a vector to unit length.
|
95 |
+
|
96 |
+
Parameters:
|
97 |
+
v (np.ndarray): Input vector.
|
98 |
+
|
99 |
+
Returns:
|
100 |
+
np.ndarray: Unit vector in the same direction as `v`.
|
101 |
+
"""
|
102 |
+
return v / np.linalg.norm(v)
|
103 |
+
|
104 |
+
def look_at_rotation(camera_position: np.ndarray, target: np.ndarray, world_up=np.array([0, 1, 0])):
|
105 |
+
"""
|
106 |
+
Compute a rotation matrix for a camera looking at a target point.
|
107 |
+
|
108 |
+
Parameters:
|
109 |
+
camera_position (np.ndarray): The 3D position of the camera.
|
110 |
+
target (np.ndarray): The point the camera should look at.
|
111 |
+
world_up (np.ndarray): A vector that defines the global 'up' direction.
|
112 |
+
|
113 |
+
Returns:
|
114 |
+
np.ndarray: A 3x3 rotation matrix (camera-to-world) with columns [right, up, forward].
|
115 |
+
"""
|
116 |
+
z_axis = normalize(target - camera_position) # Forward direction
|
117 |
+
x_axis = normalize(np.cross(world_up, z_axis)) # Right direction
|
118 |
+
y_axis = np.cross(z_axis, x_axis) # Recomputed up
|
119 |
+
return np.stack([x_axis, y_axis, z_axis], axis=1)
|
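A quick property check for normalize and look_at_rotation defined above (throwaway code): the result should be orthonormal and its third column should point from the camera toward the target.

    import numpy as np

    cam_pos = np.array([0.0, 0.0, -5.0])
    R_c2w = look_at_rotation(cam_pos, target=np.zeros(3))
    print(np.allclose(R_c2w.T @ R_c2w, np.eye(3)))    # True: orthonormal columns
    print(np.allclose(R_c2w[:, 2], [0.0, 0.0, 1.0]))  # True: forward axis looks at the origin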
120 |
+
|
121 |
+
|
122 |
+
def generate_circular_camera_path(existing_cameras: List[Camera], N: int = 12, radius_scale: float = 1.0, d: float = 2.0) -> List[Camera]:
|
123 |
+
"""
|
124 |
+
Generate a circular path of cameras around an existing camera group,
|
125 |
+
with each new camera oriented to look at the average viewing direction.
|
126 |
+
|
127 |
+
Parameters:
|
128 |
+
existing_cameras (List[Camera]): List of existing camera objects to estimate average orientation and layout.
|
129 |
+
N (int): Number of new cameras to generate along the circular path.
|
130 |
+
radius_scale (float): Scale factor to adjust the radius of the circle.
|
131 |
+
d (float): Distance ahead of each camera used to estimate its look-at point.
|
132 |
+
|
133 |
+
Returns:
|
134 |
+
List[Camera]: A list of newly generated Camera objects forming a circular path and oriented toward a shared view center.
|
135 |
+
"""
|
136 |
+
# Step 1: Compute average camera position
|
137 |
+
center = np.mean([cam.T for cam in existing_cameras], axis=0)
|
138 |
+
|
139 |
+
# Estimate where each camera is looking
|
140 |
+
# d denotes how far ahead each camera sees — you can scale this
|
141 |
+
look_targets = [cam.T + cam.R[:, 2] * d for cam in existing_cameras]
|
142 |
+
center_of_view = np.mean(look_targets, axis=0)
|
143 |
+
|
144 |
+
# Step 2: Define circular plane basis using fixed up vector
|
145 |
+
avg_forward = normalize(np.mean([cam.R[:, 2] for cam in existing_cameras], axis=0))
|
146 |
+
up_guess = np.array([0, 1, 0])
|
147 |
+
right = normalize(np.cross(avg_forward, up_guess))
|
148 |
+
up = normalize(np.cross(right, avg_forward))
|
149 |
+
|
150 |
+
# Step 3: Estimate radius
|
151 |
+
avg_radius = np.mean([np.linalg.norm(cam.T - center) for cam in existing_cameras]) * radius_scale
|
152 |
+
|
153 |
+
# Step 4: Create cameras on a circular path
|
154 |
+
angles = np.linspace(0, 2 * np.pi, N, endpoint=False)
|
155 |
+
reference_cam = existing_cameras[0]
|
156 |
+
new_cameras = []
|
157 |
+
|
158 |
+
|
159 |
+
for i, a in enumerate(angles):
|
160 |
+
position = center + avg_radius * (np.cos(a) * right + np.sin(a) * up)
|
161 |
+
|
162 |
+
if d < 1e-5 or radius_scale < 1e-5:
|
163 |
+
# Use same orientation as the first camera
|
164 |
+
R = reference_cam.R.copy()
|
165 |
+
else:
|
166 |
+
# Change orientation
|
167 |
+
R = look_at_rotation(position, center_of_view)
|
168 |
+
new_cameras.append(Camera(
|
169 |
+
R=R,
|
170 |
+
T=position, # New position
|
171 |
+
FoVx=reference_cam.FoVx,
|
172 |
+
FoVy=reference_cam.FoVy,
|
173 |
+
resolution=(reference_cam.image_width, reference_cam.image_height),
|
174 |
+
colmap_id=-1,
|
175 |
+
depth_params=None,
|
176 |
+
image=Image.fromarray(np.zeros((reference_cam.image_height, reference_cam.image_width, 3), dtype=np.uint8)),
|
177 |
+
invdepthmap=None,
|
178 |
+
image_name=f"circular_a={a:.3f}",
|
179 |
+
uid=i
|
180 |
+
))
|
181 |
+
|
182 |
+
return new_cameras
|
183 |
+
|
184 |
+
|
185 |
+
def save_numpy_frames_as_gif(frames, output_path="animation.gif", duration=100):
|
186 |
+
"""
|
187 |
+
Save a list of RGB NumPy frames as a looping GIF animation.
|
188 |
+
|
189 |
+
Parameters:
|
190 |
+
frames (List[np.ndarray]): List of RGB images as uint8 NumPy arrays (shape HxWx3).
|
191 |
+
output_path (str): Path to save the output GIF.
|
192 |
+
duration (int): Duration per frame in milliseconds.
|
193 |
+
|
194 |
+
Returns:
|
195 |
+
None
|
196 |
+
"""
|
197 |
+
pil_frames = [Image.fromarray(f) for f in frames]
|
198 |
+
pil_frames[0].save(
|
199 |
+
output_path,
|
200 |
+
save_all=True,
|
201 |
+
append_images=pil_frames[1:],
|
202 |
+
duration=duration, # duration per frame in ms
|
203 |
+
loop=0
|
204 |
+
)
|
205 |
+
print(f"GIF saved to: {output_path}")
|
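A self-contained smoke test for the GIF helper above; it writes a short fade animation to a hypothetical test_fade.gif:

    import numpy as np

    frames = [np.full((64, 64, 3), i * 25, dtype=np.uint8) for i in range(10)]
    save_numpy_frames_as_gif(frames, output_path="test_fade.gif", duration=80)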
206 |
+
|
207 |
+
def center_crop_frame(frame: np.ndarray, crop_fraction: float) -> np.ndarray:
|
208 |
+
"""
|
209 |
+
Crop the central region of the frame by the given fraction.
|
210 |
+
|
211 |
+
Parameters:
|
212 |
+
frame (np.ndarray): Input RGB image (H, W, 3).
|
213 |
+
crop_fraction (float): Fraction of the original size to retain (e.g., 0.8 keeps 80%).
|
214 |
+
|
215 |
+
Returns:
|
216 |
+
np.ndarray: Cropped RGB image.
|
217 |
+
"""
|
218 |
+
if crop_fraction >= 1.0:
|
219 |
+
return frame
|
220 |
+
|
221 |
+
h, w, _ = frame.shape
|
222 |
+
new_h, new_w = int(h * crop_fraction), int(w * crop_fraction)
|
223 |
+
start_y = (h - new_h) // 2
|
224 |
+
start_x = (w - new_w) // 2
|
225 |
+
return frame[start_y:start_y + new_h, start_x:start_x + new_w, :]
|
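center_crop_frame keeps the central crop_fraction of each dimension, so a 100x200 frame cropped to 0.8 becomes 80x160; a minimal check with the function above:

    import numpy as np

    frame = np.zeros((100, 200, 3), dtype=np.uint8)
    cropped = center_crop_frame(frame, 0.8)
    print(cropped.shape)   # (80, 160, 3)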
226 |
+
|
227 |
+
|
228 |
+
|
229 |
+
def generate_smooth_closed_camera_path(existing_cameras: List[Camera], N: int = 120, d: float = 2.0, s=.25) -> List[Camera]:
|
230 |
+
"""
|
231 |
+
Generate a smooth, closed path interpolating the positions of existing cameras.
|
232 |
+
|
233 |
+
Parameters:
|
234 |
+
existing_cameras (List[Camera]): List of existing cameras.
|
235 |
+
N (int): Number of points (cameras) to sample along the smooth path.
|
236 |
+
d (float): Distance ahead for estimating the center of view.
s (float): Smoothing factor forwarded to scipy's splprep (0 interpolates the positions exactly; larger values smooth more).
|
237 |
+
|
238 |
+
Returns:
|
239 |
+
List[Camera]: A list of smoothly moving Camera objects along a closed loop.
|
240 |
+
"""
|
241 |
+
# Step 1: Extract camera positions
|
242 |
+
positions = np.array([cam.T for cam in existing_cameras])
|
243 |
+
|
244 |
+
# Step 2: Estimate center of view
|
245 |
+
look_targets = [cam.T + cam.R[:, 2] * d for cam in existing_cameras]
|
246 |
+
center_of_view = np.mean(look_targets, axis=0)
|
247 |
+
|
248 |
+
# Step 3: Fit a smooth closed spline through the positions
|
249 |
+
positions = np.vstack([positions, positions[0]]) # close the loop
|
250 |
+
tck, u = splprep(positions.T, s=s, per=True) # per=True fits a periodic (closed-loop) spline
|
251 |
+
|
252 |
+
# Step 4: Sample points along the spline
|
253 |
+
u_fine = np.linspace(0, 1, N)
|
254 |
+
smooth_path = np.stack(splev(u_fine, tck), axis=-1)
|
255 |
+
|
256 |
+
# Step 5: Generate cameras along the smooth path
|
257 |
+
reference_cam = existing_cameras[0]
|
258 |
+
new_cameras = []
|
259 |
+
|
260 |
+
for i, pos in enumerate(smooth_path):
|
261 |
+
R = look_at_rotation(pos, center_of_view)
|
262 |
+
new_cameras.append(Camera(
|
263 |
+
R=R,
|
264 |
+
T=pos,
|
265 |
+
FoVx=reference_cam.FoVx,
|
266 |
+
FoVy=reference_cam.FoVy,
|
267 |
+
resolution=(reference_cam.image_width, reference_cam.image_height),
|
268 |
+
colmap_id=-1,
|
269 |
+
depth_params=None,
|
270 |
+
image=Image.fromarray(np.zeros((reference_cam.image_height, reference_cam.image_width, 3), dtype=np.uint8)),
|
271 |
+
invdepthmap=None,
|
272 |
+
image_name=f"smooth_path_i={i}",
|
273 |
+
uid=i
|
274 |
+
))
|
275 |
+
|
276 |
+
return new_cameras
|
277 |
+
|
278 |
+
|
279 |
+
def save_numpy_frames_as_mp4(frames, output_path="animation.mp4", fps=10, center_crop: float = 1.0):
|
280 |
+
"""
|
281 |
+
Save a list of RGB NumPy frames as an MP4 video with optional center cropping.
|
282 |
+
|
283 |
+
Parameters:
|
284 |
+
frames (List[np.ndarray]): List of RGB images as uint8 NumPy arrays (shape HxWx3).
|
285 |
+
output_path (str): Path to save the output MP4.
|
286 |
+
fps (int): Frames per second for playback speed.
|
287 |
+
center_crop (float): Fraction (0 < center_crop <= 1.0) of central region to retain.
|
288 |
+
Use 1.0 for no cropping; 0.8 to crop to 80% center region.
|
289 |
+
|
290 |
+
Returns:
|
291 |
+
None
|
292 |
+
"""
|
293 |
+
with imageio.get_writer(output_path, fps=fps, codec='libx264', quality=8) as writer:
|
294 |
+
for frame in frames:
|
295 |
+
cropped = center_crop_frame(frame, center_crop)
|
296 |
+
writer.append_data(cropped)
|
297 |
+
print(f"MP4 saved to: {output_path}")
|
298 |
+
|
299 |
+
|
300 |
+
|
301 |
+
def put_text_on_image(img: np.ndarray, text: str) -> np.ndarray:
|
302 |
+
"""
|
303 |
+
Draws multiline white text on a copy of the input image, positioned near the bottom
|
304 |
+
and around 80% of the image width. Handles '\n' characters to split text into multiple lines.
|
305 |
+
|
306 |
+
Args:
|
307 |
+
img (np.ndarray): Input image as a (H, W, 3) uint8 numpy array.
|
308 |
+
text (str): Text string to draw on the image. Newlines '\n' are treated as line breaks.
|
309 |
+
|
310 |
+
Returns:
|
311 |
+
np.ndarray: The output image with the text drawn on it.
|
312 |
+
|
313 |
+
Notes:
|
314 |
+
- The function automatically adjusts line spacing and prevents text from going outside the image.
|
315 |
+
- Text is drawn in white with font scale 1.0 and thickness 2, kept readable without dominating the frame.
|
316 |
+
"""
|
317 |
+
img = img.copy()
|
318 |
+
height, width, _ = img.shape
|
319 |
+
|
320 |
+
font = cv2.FONT_HERSHEY_SIMPLEX
|
321 |
+
font_scale = 1.
|
322 |
+
color = (255, 255, 255)
|
323 |
+
thickness = 2
|
324 |
+
line_spacing = 5 # extra pixels between lines
|
325 |
+
|
326 |
+
lines = text.split('\n')
|
327 |
+
|
328 |
+
# Precompute the maximum text width to adjust starting x
|
329 |
+
max_text_width = max(cv2.getTextSize(line, font, font_scale, thickness)[0][0] for line in lines)
|
330 |
+
|
331 |
+
x = int(0.8 * width)
|
332 |
+
x = min(x, width - max_text_width - 30) # margin on right
|
333 |
+
#x = int(0.03 * width)
|
334 |
+
|
335 |
+
# Start near the bottom, but move up depending on number of lines
|
336 |
+
total_text_height = len(lines) * (cv2.getTextSize('A', font, font_scale, thickness)[0][1] + line_spacing)
|
337 |
+
y_start = int(height * 0.9) - total_text_height # start the text block at ~90% of the image height
|
338 |
+
|
339 |
+
for i, line in enumerate(lines):
|
340 |
+
y = y_start + i * (cv2.getTextSize(line, font, font_scale, thickness)[0][1] + line_spacing)
|
341 |
+
cv2.putText(img, line, (x, y), font, font_scale, color, thickness, cv2.LINE_AA)
|
342 |
+
|
343 |
+
return img
|
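A minimal usage example of put_text_on_image on a blank canvas (the label text is arbitrary):

    import numpy as np

    canvas = np.zeros((480, 640, 3), dtype=np.uint8)
    labelled = put_text_on_image(canvas, "frame 12\ncircular path")
    print(labelled.shape, labelled.dtype)   # (480, 640, 3) uint8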
344 |
+
|
345 |
+
|
346 |
+
|
347 |
+
|
348 |
+
def catmull_rom_spline(P0, P1, P2, P3, n_points=20):
|
349 |
+
"""
|
350 |
+
Compute Catmull-Rom spline segment between P1 and P2.
|
351 |
+
"""
|
352 |
+
t = np.linspace(0, 1, n_points)[:, None]
|
353 |
+
|
354 |
+
M = 0.5 * np.array([
|
355 |
+
[-1, 3, -3, 1],
|
356 |
+
[ 2, -5, 4, -1],
|
357 |
+
[-1, 0, 1, 0],
|
358 |
+
[ 0, 2, 0, 0]
|
359 |
+
])
|
360 |
+
|
361 |
+
G = np.stack([P0, P1, P2, P3], axis=0)
|
362 |
+
T = np.concatenate([t**3, t**2, t, np.ones_like(t)], axis=1)
|
363 |
+
|
364 |
+
return T @ M @ G
|
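A useful property of the Catmull-Rom basis above is that each segment starts exactly at P1 and ends exactly at P2, which is what makes chaining segments seamless; a quick numerical check:

    import numpy as np

    P0, P1, P2, P3 = np.array([[0, 0, 0], [1, 0, 0], [2, 1, 0], [3, 1, 0]], dtype=float)
    seg = catmull_rom_spline(P0, P1, P2, P3, n_points=5)
    print(np.allclose(seg[0], P1), np.allclose(seg[-1], P2))   # True True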
365 |
+
|
366 |
+
def sort_cameras_pca(existing_cameras: List[Camera]):
|
367 |
+
"""
|
368 |
+
Sort cameras along the main PCA axis.
|
369 |
+
"""
|
370 |
+
positions = np.array([cam.T for cam in existing_cameras])
|
371 |
+
pca = PCA(n_components=1)
|
372 |
+
scores = pca.fit_transform(positions)
|
373 |
+
sorted_indices = np.argsort(scores[:, 0])
|
374 |
+
return sorted_indices
|
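The PCA ordering amounts to sorting positions by their projection onto the dominant axis; a stripped-down sketch on raw points (no Camera objects needed):

    import numpy as np
    from sklearn.decomposition import PCA

    positions = np.array([[0, 0, 0], [5, 0.1, 0], [2, -0.1, 0], [9, 0, 0]], dtype=float)
    scores = PCA(n_components=1).fit_transform(positions)
    order = np.argsort(scores[:, 0])
    print(positions[order])   # sorted along the main axis (possibly reversed: the PCA sign is arbitrary)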
375 |
+
|
376 |
+
def generate_fully_smooth_cameras(existing_cameras: List[Camera],
|
377 |
+
n_selected: int = 30,
|
378 |
+
n_points_per_segment: int = 20,
|
379 |
+
d: float = 2.0,
|
380 |
+
closed: bool = False) -> List[Camera]:
|
381 |
+
"""
|
382 |
+
Generate a fully smooth camera path using PCA ordering, global Catmull-Rom spline for positions, and global SLERP for orientations.
|
383 |
+
|
384 |
+
Args:
|
385 |
+
existing_cameras (List[Camera]): List of input cameras.
|
386 |
+
n_selected (int): Number of cameras to select after sorting.
|
387 |
+
n_points_per_segment (int): Number of interpolated points per spline segment.
|
388 |
+
d (float): Distance ahead for estimating center of view.
|
389 |
+
closed (bool): Whether to close the path.
|
390 |
+
|
391 |
+
Returns:
|
392 |
+
List[Camera]: List of smoothly moving Camera objects.
|
393 |
+
"""
|
394 |
+
# 1. Sort cameras along PCA axis
|
395 |
+
sorted_indices = sort_cameras_pca(existing_cameras)
|
396 |
+
sorted_cameras = [existing_cameras[i] for i in sorted_indices]
|
397 |
+
positions = np.array([cam.T for cam in sorted_cameras])
|
398 |
+
|
399 |
+
# 2. Subsample uniformly
|
400 |
+
idx = np.linspace(0, len(positions) - 1, n_selected).astype(int)
|
401 |
+
sampled_positions = positions[idx]
|
402 |
+
sampled_cameras = [sorted_cameras[i] for i in idx]
|
403 |
+
|
404 |
+
# 3. Prepare for Catmull-Rom
|
405 |
+
if closed:
|
406 |
+
sampled_positions = np.vstack([sampled_positions[-1], sampled_positions, sampled_positions[0], sampled_positions[1]])
|
407 |
+
else:
|
408 |
+
sampled_positions = np.vstack([sampled_positions[0], sampled_positions, sampled_positions[-1], sampled_positions[-1]])
|
409 |
+
|
410 |
+
# 4. Generate smooth path positions
|
411 |
+
path_positions = []
|
412 |
+
for i in range(1, len(sampled_positions) - 2):
|
413 |
+
segment = catmull_rom_spline(sampled_positions[i-1], sampled_positions[i], sampled_positions[i+1], sampled_positions[i+2], n_points_per_segment)
|
414 |
+
path_positions.append(segment)
|
415 |
+
path_positions = np.concatenate(path_positions, axis=0)
|
416 |
+
|
417 |
+
# 5. Global SLERP for rotations
|
418 |
+
rotations = R.from_matrix([cam.R for cam in sampled_cameras])
|
419 |
+
key_times = np.linspace(0, 1, len(rotations))
|
420 |
+
slerp = Slerp(key_times, rotations)
|
421 |
+
|
422 |
+
query_times = np.linspace(0, 1, len(path_positions))
|
423 |
+
interpolated_rotations = slerp(query_times)
|
424 |
+
|
425 |
+
# 6. Generate Camera objects
|
426 |
+
reference_cam = existing_cameras[0]
|
427 |
+
smooth_cameras = []
|
428 |
+
|
429 |
+
for i, pos in enumerate(path_positions):
|
430 |
+
R_interp = interpolated_rotations[i].as_matrix()
|
431 |
+
|
432 |
+
smooth_cameras.append(Camera(
|
433 |
+
R=R_interp,
|
434 |
+
T=pos,
|
435 |
+
FoVx=reference_cam.FoVx,
|
436 |
+
FoVy=reference_cam.FoVy,
|
437 |
+
resolution=(reference_cam.image_width, reference_cam.image_height),
|
438 |
+
colmap_id=-1,
|
439 |
+
depth_params=None,
|
440 |
+
image=Image.fromarray(np.zeros((reference_cam.image_height, reference_cam.image_width, 3), dtype=np.uint8)),
|
441 |
+
invdepthmap=None,
|
442 |
+
image_name=f"fully_smooth_path_i={i}",
|
443 |
+
uid=i
|
444 |
+
))
|
445 |
+
|
446 |
+
return smooth_cameras
|
447 |
+
|
448 |
+
|
449 |
+
def plot_cameras_and_smooth_path_with_orientation(existing_cameras: List[Camera], smooth_cameras: List[Camera], scale: float = 0.1):
|
450 |
+
"""
|
451 |
+
Plot input cameras and smooth path cameras with their orientations in 3D.
|
452 |
+
|
453 |
+
Args:
|
454 |
+
existing_cameras (List[Camera]): List of original input cameras.
|
455 |
+
smooth_cameras (List[Camera]): List of smooth path cameras.
|
456 |
+
scale (float): Length of orientation arrows.
|
457 |
+
|
458 |
+
Returns:
|
459 |
+
None
|
460 |
+
"""
|
461 |
+
# Input cameras
|
462 |
+
input_positions = np.array([cam.T for cam in existing_cameras])
|
463 |
+
|
464 |
+
# Smooth cameras
|
465 |
+
smooth_positions = np.array([cam.T for cam in smooth_cameras])
|
466 |
+
|
467 |
+
fig = go.Figure()
|
468 |
+
|
469 |
+
# Plot input camera positions
|
470 |
+
fig.add_trace(go.Scatter3d(
|
471 |
+
x=input_positions[:, 0], y=input_positions[:, 1], z=input_positions[:, 2],
|
472 |
+
mode='markers',
|
473 |
+
marker=dict(size=4, color='blue'),
|
474 |
+
name='Input Cameras'
|
475 |
+
))
|
476 |
+
|
477 |
+
# Plot smooth path positions
|
478 |
+
fig.add_trace(go.Scatter3d(
|
479 |
+
x=smooth_positions[:, 0], y=smooth_positions[:, 1], z=smooth_positions[:, 2],
|
480 |
+
mode='lines+markers',
|
481 |
+
line=dict(color='red', width=3),
|
482 |
+
marker=dict(size=2, color='red'),
|
483 |
+
name='Smooth Path Cameras'
|
484 |
+
))
|
485 |
+
|
486 |
+
# Plot input camera orientations
|
487 |
+
for cam in existing_cameras:
|
488 |
+
origin = cam.T
|
489 |
+
forward = cam.R[:, 2] # Forward direction
|
490 |
+
|
491 |
+
fig.add_trace(go.Cone(
|
492 |
+
x=[origin[0]], y=[origin[1]], z=[origin[2]],
|
493 |
+
u=[forward[0]], v=[forward[1]], w=[forward[2]],
|
494 |
+
colorscale=[[0, 'blue'], [1, 'blue']],
|
495 |
+
sizemode="absolute",
|
496 |
+
sizeref=scale,
|
497 |
+
anchor="tail",
|
498 |
+
showscale=False,
|
499 |
+
name='Input Camera Direction'
|
500 |
+
))
|
501 |
+
|
502 |
+
# Plot smooth camera orientations
|
503 |
+
for cam in smooth_cameras:
|
504 |
+
origin = cam.T
|
505 |
+
forward = cam.R[:, 2] # Forward direction
|
506 |
+
|
507 |
+
fig.add_trace(go.Cone(
|
508 |
+
x=[origin[0]], y=[origin[1]], z=[origin[2]],
|
509 |
+
u=[forward[0]], v=[forward[1]], w=[forward[2]],
|
510 |
+
colorscale=[[0, 'red'], [1, 'red']],
|
511 |
+
sizemode="absolute",
|
512 |
+
sizeref=scale,
|
513 |
+
anchor="tail",
|
514 |
+
showscale=False,
|
515 |
+
name='Smooth Camera Direction'
|
516 |
+
))
|
517 |
+
|
518 |
+
fig.update_layout(
|
519 |
+
scene=dict(
|
520 |
+
xaxis_title='X',
|
521 |
+
yaxis_title='Y',
|
522 |
+
zaxis_title='Z',
|
523 |
+
aspectmode='data'
|
524 |
+
),
|
525 |
+
title="Input Cameras and Smooth Path with Orientations",
|
526 |
+
margin=dict(l=0, r=0, b=0, t=30)
|
527 |
+
)
|
528 |
+
|
529 |
+
fig.show()
|
530 |
+
|
531 |
+
|
532 |
+
def solve_tsp_nearest_neighbor(points: np.ndarray):
|
533 |
+
"""
|
534 |
+
Solve TSP approximately using nearest neighbor heuristic.
|
535 |
+
|
536 |
+
Args:
|
537 |
+
points (np.ndarray): (N, 3) array of points.
|
538 |
+
|
539 |
+
Returns:
|
540 |
+
List[int]: Optimal visiting order of points.
|
541 |
+
"""
|
542 |
+
N = points.shape[0]
|
543 |
+
dist = distance_matrix(points, points)
|
544 |
+
visited = [0]
|
545 |
+
unvisited = set(range(1, N))
|
546 |
+
|
547 |
+
while unvisited:
|
548 |
+
last = visited[-1]
|
549 |
+
next_city = min(unvisited, key=lambda city: dist[last, city])
|
550 |
+
visited.append(next_city)
|
551 |
+
unvisited.remove(next_city)
|
552 |
+
|
553 |
+
return visited
|
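A tiny check of the nearest-neighbour heuristic above on four collinear points; starting from index 0 it should visit them left to right:

    import numpy as np

    pts = np.array([[0, 0, 0], [3, 0, 0], [1, 0, 0], [2, 0, 0]], dtype=float)
    print(solve_tsp_nearest_neighbor(pts))   # [0, 2, 3, 1]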
554 |
+
|
555 |
+
def solve_tsp_2opt(points: np.ndarray, n_iter: int = 1000) -> np.ndarray:
|
556 |
+
"""
|
557 |
+
Solve TSP approximately using Nearest Neighbor + 2-Opt.
|
558 |
+
|
559 |
+
Args:
|
560 |
+
points (np.ndarray): Array of shape (N, D) with points.
|
561 |
+
n_iter (int): Number of 2-opt iterations.
|
562 |
+
|
563 |
+
Returns:
|
564 |
+
np.ndarray: Ordered list of indices.
|
565 |
+
"""
|
566 |
+
n_points = points.shape[0]
|
567 |
+
|
568 |
+
# === 1. Start with Nearest Neighbor
|
569 |
+
unvisited = list(range(n_points))
|
570 |
+
current = unvisited.pop(0)
|
571 |
+
path = [current]
|
572 |
+
|
573 |
+
while unvisited:
|
574 |
+
dists = np.linalg.norm(points[unvisited] - points[current], axis=1)
|
575 |
+
next_idx = unvisited[np.argmin(dists)]
|
576 |
+
unvisited.remove(next_idx)
|
577 |
+
path.append(next_idx)
|
578 |
+
current = next_idx
|
579 |
+
|
580 |
+
# === 2. Apply 2-Opt improvements
|
581 |
+
def path_length(path):
|
582 |
+
return sum(np.linalg.norm(points[path[i]] - points[path[i + 1]]) for i in range(len(path) - 1))  # builtin sum: np.sum over a generator is deprecated
|
583 |
+
|
584 |
+
best_length = path_length(path)
|
585 |
+
improved = True
|
586 |
+
|
587 |
+
for _ in range(n_iter):
|
588 |
+
if not improved:
|
589 |
+
break
|
590 |
+
improved = False
|
591 |
+
for i in range(1, n_points - 2):
|
592 |
+
for j in range(i + 1, n_points):
|
593 |
+
if j - i == 1: continue
|
594 |
+
new_path = path[:i] + path[i:j][::-1] + path[j:]
|
595 |
+
new_length = path_length(new_path)
|
596 |
+
if new_length < best_length:
|
597 |
+
path = new_path
|
598 |
+
best_length = new_length
|
599 |
+
improved = True
|
600 |
+
break
|
601 |
+
if improved:
|
602 |
+
break
|
603 |
+
|
604 |
+
return np.array(path)
|
605 |
+
|
606 |
+
def generate_fully_smooth_cameras_with_tsp(existing_cameras: List[Camera],
|
607 |
+
n_selected: int = 30,
|
608 |
+
n_points_per_segment: int = 20,
|
609 |
+
d: float = 2.0,
|
610 |
+
closed: bool = False) -> List[Camera]:
|
611 |
+
"""
|
612 |
+
Generate a fully smooth camera path using TSP ordering, global Catmull-Rom spline for positions, and global SLERP for orientations.
|
613 |
+
|
614 |
+
Args:
|
615 |
+
existing_cameras (List[Camera]): List of input cameras.
|
616 |
+
n_selected (int): Number of cameras to select after ordering.
|
617 |
+
n_points_per_segment (int): Number of interpolated points per spline segment.
|
618 |
+
d (float): Distance ahead for estimating center of view.
|
619 |
+
closed (bool): Whether to close the path.
|
620 |
+
|
621 |
+
Returns:
|
622 |
+
List[Camera]: List of smoothly moving Camera objects.
|
623 |
+
"""
|
624 |
+
positions = np.array([cam.T for cam in existing_cameras])
|
625 |
+
|
626 |
+
# 1. Solve approximate TSP
|
627 |
+
order = solve_tsp_nearest_neighbor(positions)
|
628 |
+
ordered_cameras = [existing_cameras[i] for i in order]
|
629 |
+
ordered_positions = positions[order]
|
630 |
+
|
631 |
+
# 2. Subsample uniformly
|
632 |
+
idx = np.linspace(0, len(ordered_positions) - 1, n_selected).astype(int)
|
633 |
+
sampled_positions = ordered_positions[idx]
|
634 |
+
sampled_cameras = [ordered_cameras[i] for i in idx]
|
635 |
+
|
636 |
+
# 3. Prepare for Catmull-Rom
|
637 |
+
if closed:
|
638 |
+
sampled_positions = np.vstack([sampled_positions[-1], sampled_positions, sampled_positions[0], sampled_positions[1]])
|
639 |
+
else:
|
640 |
+
sampled_positions = np.vstack([sampled_positions[0], sampled_positions, sampled_positions[-1], sampled_positions[-1]])
|
641 |
+
|
642 |
+
# 4. Generate smooth path positions
|
643 |
+
path_positions = []
|
644 |
+
for i in range(1, len(sampled_positions) - 2):
|
645 |
+
segment = catmull_rom_spline(sampled_positions[i-1], sampled_positions[i], sampled_positions[i+1], sampled_positions[i+2], n_points_per_segment)
|
646 |
+
path_positions.append(segment)
|
647 |
+
path_positions = np.concatenate(path_positions, axis=0)
|
648 |
+
|
649 |
+
# 5. Global SLERP for rotations
|
650 |
+
rotations = R.from_matrix([cam.R for cam in sampled_cameras])
|
651 |
+
key_times = np.linspace(0, 1, len(rotations))
|
652 |
+
slerp = Slerp(key_times, rotations)
|
653 |
+
|
654 |
+
query_times = np.linspace(0, 1, len(path_positions))
|
655 |
+
interpolated_rotations = slerp(query_times)
|
656 |
+
|
657 |
+
# 6. Generate Camera objects
|
658 |
+
reference_cam = existing_cameras[0]
|
659 |
+
smooth_cameras = []
|
660 |
+
|
661 |
+
for i, pos in enumerate(path_positions):
|
662 |
+
R_interp = interpolated_rotations[i].as_matrix()
|
663 |
+
|
664 |
+
smooth_cameras.append(Camera(
|
665 |
+
R=R_interp,
|
666 |
+
T=pos,
|
667 |
+
FoVx=reference_cam.FoVx,
|
668 |
+
FoVy=reference_cam.FoVy,
|
669 |
+
resolution=(reference_cam.image_width, reference_cam.image_height),
|
670 |
+
colmap_id=-1,
|
671 |
+
depth_params=None,
|
672 |
+
image=Image.fromarray(np.zeros((reference_cam.image_height, reference_cam.image_width, 3), dtype=np.uint8)),
|
673 |
+
invdepthmap=None,
|
674 |
+
image_name=f"fully_smooth_path_i={i}",
|
675 |
+
uid=i
|
676 |
+
))
|
677 |
+
|
678 |
+
return smooth_cameras
|
679 |
+
|
680 |
+
from typing import List
|
681 |
+
import numpy as np
|
682 |
+
from sklearn.mixture import GaussianMixture
|
683 |
+
from scipy.spatial.transform import Rotation as R, Slerp
|
684 |
+
from PIL import Image
|
685 |
+
|
686 |
+
def generate_clustered_smooth_cameras_with_tsp(existing_cameras: List[Camera],
|
687 |
+
n_selected: int = 30,
|
688 |
+
n_points_per_segment: int = 20,
|
689 |
+
d: float = 2.0,
|
690 |
+
n_clusters: int = 5,
|
691 |
+
closed: bool = False) -> List[Camera]:
|
692 |
+
"""
|
693 |
+
Generate a fully smooth camera path using clustering + TSP between nearest cluster centers + TSP inside clusters.
|
694 |
+
Positions are normalized before clustering and denormalized before generating final cameras.
|
695 |
+
|
696 |
+
Args:
|
697 |
+
existing_cameras (List[Camera]): List of input cameras.
|
698 |
+
n_selected (int): Number of cameras to select after ordering.
|
699 |
+
n_points_per_segment (int): Number of interpolated points per spline segment.
|
700 |
+
d (float): Distance ahead for estimating center of view.
|
701 |
+
n_clusters (int): Number of GMM clusters.
|
702 |
+
closed (bool): Whether to close the path.
|
703 |
+
|
704 |
+
Returns:
|
705 |
+
List[Camera]: Smooth path of Camera objects.
|
706 |
+
"""
|
707 |
+
# Extract positions and rotations
|
708 |
+
positions = np.array([cam.T for cam in existing_cameras])
|
709 |
+
rotations = np.array([R.from_matrix(cam.R).as_quat() for cam in existing_cameras])
|
710 |
+
|
711 |
+
# === Normalize positions
|
712 |
+
mean_pos = np.mean(positions, axis=0)
|
713 |
+
scale_pos = np.std(positions, axis=0)
|
714 |
+
scale_pos[scale_pos == 0] = 1.0 # avoid division by zero
|
715 |
+
|
716 |
+
positions_normalized = (positions - mean_pos) / scale_pos
|
717 |
+
|
718 |
+
# === Features for clustering (only positions, not rotations)
|
719 |
+
features = positions_normalized
|
720 |
+
|
721 |
+
# === 1. GMM clustering
|
722 |
+
gmm = GaussianMixture(n_components=n_clusters, covariance_type='full', random_state=42)
|
723 |
+
cluster_labels = gmm.fit_predict(features)
|
724 |
+
|
725 |
+
clusters = {}
|
726 |
+
cluster_centers = []
|
727 |
+
|
728 |
+
for cluster_id in range(n_clusters):
|
729 |
+
cluster_indices = np.where(cluster_labels == cluster_id)[0]
|
730 |
+
if len(cluster_indices) == 0:
|
731 |
+
continue
|
732 |
+
clusters[cluster_id] = cluster_indices
|
733 |
+
cluster_center = np.mean(features[cluster_indices], axis=0)
|
734 |
+
cluster_centers.append(cluster_center)
|
735 |
+
|
736 |
+
cluster_centers = np.stack(cluster_centers)
|
737 |
+
|
738 |
+
# === 2. Remap cluster centers to nearest existing cameras
|
739 |
+
if False:
|
740 |
+
mapped_centers = []
|
741 |
+
for center in cluster_centers:
|
742 |
+
dists = np.linalg.norm(features - center, axis=1)
|
743 |
+
nearest_idx = np.argmin(dists)
|
744 |
+
mapped_centers.append(features[nearest_idx])
|
745 |
+
mapped_centers = np.stack(mapped_centers)
|
746 |
+
cluster_centers = mapped_centers
|
747 |
+
# === 3. Solve TSP between mapped cluster centers
|
748 |
+
cluster_order = solve_tsp_2opt(cluster_centers)
|
749 |
+
|
750 |
+
# === 4. For each cluster, solve TSP inside cluster
|
751 |
+
final_indices = []
|
752 |
+
for cluster_id in cluster_order:
|
753 |
+
cluster_indices = clusters[cluster_id]
|
754 |
+
cluster_positions = features[cluster_indices]
|
755 |
+
|
756 |
+
if len(cluster_positions) == 1:
|
757 |
+
final_indices.append(cluster_indices[0])
|
758 |
+
continue
|
759 |
+
|
760 |
+
local_order = solve_tsp_nearest_neighbor(cluster_positions)
|
761 |
+
ordered_cluster_indices = cluster_indices[local_order]
|
762 |
+
final_indices.extend(ordered_cluster_indices)
|
763 |
+
|
764 |
+
ordered_cameras = [existing_cameras[i] for i in final_indices]
|
765 |
+
ordered_positions = positions_normalized[final_indices]
|
766 |
+
|
767 |
+
# === 5. Subsample uniformly
|
768 |
+
idx = np.linspace(0, len(ordered_positions) - 1, n_selected).astype(int)
|
769 |
+
sampled_positions = ordered_positions[idx]
|
770 |
+
sampled_cameras = [ordered_cameras[i] for i in idx]
|
771 |
+
|
772 |
+
# === 6. Prepare for Catmull-Rom spline
|
773 |
+
if closed:
|
774 |
+
sampled_positions = np.vstack([sampled_positions[-1], sampled_positions, sampled_positions[0], sampled_positions[1]])
|
775 |
+
else:
|
776 |
+
sampled_positions = np.vstack([sampled_positions[0], sampled_positions, sampled_positions[-1], sampled_positions[-1]])
|
777 |
+
|
778 |
+
# === 7. Smooth path positions
|
779 |
+
path_positions = []
|
780 |
+
for i in range(1, len(sampled_positions) - 2):
|
781 |
+
segment = catmull_rom_spline(sampled_positions[i-1], sampled_positions[i], sampled_positions[i+1], sampled_positions[i+2], n_points_per_segment)
|
782 |
+
path_positions.append(segment)
|
783 |
+
path_positions = np.concatenate(path_positions, axis=0)
|
784 |
+
|
785 |
+
# === 8. Denormalize
|
786 |
+
path_positions = path_positions * scale_pos + mean_pos
|
787 |
+
|
788 |
+
# === 9. SLERP for rotations
|
789 |
+
rotations = R.from_matrix([cam.R for cam in sampled_cameras])
|
790 |
+
key_times = np.linspace(0, 1, len(rotations))
|
791 |
+
slerp = Slerp(key_times, rotations)
|
792 |
+
|
793 |
+
query_times = np.linspace(0, 1, len(path_positions))
|
794 |
+
interpolated_rotations = slerp(query_times)
|
795 |
+
|
796 |
+
# === 10. Generate Camera objects
|
797 |
+
reference_cam = existing_cameras[0]
|
798 |
+
smooth_cameras = []
|
799 |
+
|
800 |
+
for i, pos in enumerate(path_positions):
|
801 |
+
R_interp = interpolated_rotations[i].as_matrix()
|
802 |
+
|
803 |
+
smooth_cameras.append(Camera(
|
804 |
+
R=R_interp,
|
805 |
+
T=pos,
|
806 |
+
FoVx=reference_cam.FoVx,
|
807 |
+
FoVy=reference_cam.FoVy,
|
808 |
+
resolution=(reference_cam.image_width, reference_cam.image_height),
|
809 |
+
colmap_id=-1,
|
810 |
+
depth_params=None,
|
811 |
+
image=Image.fromarray(np.zeros((reference_cam.image_height, reference_cam.image_width, 3), dtype=np.uint8)),
|
812 |
+
invdepthmap=None,
|
813 |
+
image_name=f"clustered_smooth_path_i={i}",
|
814 |
+
uid=i
|
815 |
+
))
|
816 |
+
|
817 |
+
return smooth_cameras
|
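The clustered path builders in this file lean on helpers that are defined elsewhere in `source/visualization.py` and are not shown in this diff, notably a Catmull-Rom segment interpolator. As a point of reference only, here is a minimal sketch of a uniform Catmull-Rom segment helper with the same call shape as the `catmull_rom_spline(p0, p1, p2, p3, n_points_per_segment)` used above; the repository's actual helper may differ in details such as tension or endpoint handling.

```python
import numpy as np

def catmull_rom_spline_sketch(p0, p1, p2, p3, n_points):
    """Interpolate n_points between p1 and p2 on a uniform Catmull-Rom segment.

    p0..p3 are 3D control points (np.ndarray of shape (3,)); the segment passes
    through p1 and p2, while p0 and p3 only shape the tangents.
    """
    t = np.linspace(0.0, 1.0, n_points)[:, None]  # (n_points, 1) for broadcasting
    # Standard uniform Catmull-Rom basis (tension 0.5).
    return 0.5 * (
        (2.0 * p1)
        + (-p0 + p2) * t
        + (2.0 * p0 - 5.0 * p1 + 4.0 * p2 - p3) * t ** 2
        + (-p0 + 3.0 * p1 - 3.0 * p2 + p3) * t ** 3
    )
```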
818 |
+
|
819 |
+
|
820 |
+
# def generate_clustered_path(existing_cameras: List[Camera],
|
821 |
+
# n_points_per_segment: int = 20,
|
822 |
+
# d: float = 2.0,
|
823 |
+
# n_clusters: int = 5,
|
824 |
+
# closed: bool = False) -> List[Camera]:
|
825 |
+
# """
|
826 |
+
# Generate a smooth camera path using GMM clustering and TSP on cluster centers.
|
827 |
+
|
828 |
+
# Args:
|
829 |
+
# existing_cameras (List[Camera]): List of input cameras.
|
830 |
+
# n_points_per_segment (int): Number of interpolated points per spline segment.
|
831 |
+
# d (float): Distance ahead for estimating center of view.
|
832 |
+
# n_clusters (int): Number of GMM clusters (zones).
|
833 |
+
# closed (bool): Whether to close the path.
|
834 |
+
|
835 |
+
# Returns:
|
836 |
+
# List[Camera]: Smooth path of Camera objects.
|
837 |
+
# """
|
838 |
+
# # Extract positions and rotations
|
839 |
+
# positions = np.array([cam.T for cam in existing_cameras])
|
840 |
+
|
841 |
+
# # === Normalize positions
|
842 |
+
# mean_pos = np.mean(positions, axis=0)
|
843 |
+
# scale_pos = np.std(positions, axis=0)
|
844 |
+
# scale_pos[scale_pos == 0] = 1.0
|
845 |
+
|
846 |
+
# positions_normalized = (positions - mean_pos) / scale_pos
|
847 |
+
|
848 |
+
# # === 1. GMM clustering (only positions)
|
849 |
+
# gmm = GaussianMixture(n_components=n_clusters, covariance_type='full', random_state=42)
|
850 |
+
# cluster_labels = gmm.fit_predict(positions_normalized)
|
851 |
+
|
852 |
+
# cluster_centers = []
|
853 |
+
# for cluster_id in range(n_clusters):
|
854 |
+
# cluster_indices = np.where(cluster_labels == cluster_id)[0]
|
855 |
+
# if len(cluster_indices) == 0:
|
856 |
+
# continue
|
857 |
+
# cluster_center = np.mean(positions_normalized[cluster_indices], axis=0)
|
858 |
+
# cluster_centers.append(cluster_center)
|
859 |
+
|
860 |
+
# cluster_centers = np.stack(cluster_centers)
|
861 |
+
|
862 |
+
# # === 2. Solve TSP between cluster centers
|
863 |
+
# cluster_order = solve_tsp_2opt(cluster_centers)
|
864 |
+
|
865 |
+
# # === 3. Reorder cluster centers
|
866 |
+
# ordered_centers = cluster_centers[cluster_order]
|
867 |
+
|
868 |
+
# # === 4. Prepare Catmull-Rom spline
|
869 |
+
# if closed:
|
870 |
+
# ordered_centers = np.vstack([ordered_centers[-1], ordered_centers, ordered_centers[0], ordered_centers[1]])
|
871 |
+
# else:
|
872 |
+
# ordered_centers = np.vstack([ordered_centers[0], ordered_centers, ordered_centers[-1], ordered_centers[-1]])
|
873 |
+
|
874 |
+
# # === 5. Generate smooth path positions
|
875 |
+
# path_positions = []
|
876 |
+
# for i in range(1, len(ordered_centers) - 2):
|
877 |
+
# segment = catmull_rom_spline(ordered_centers[i-1], ordered_centers[i], ordered_centers[i+1], ordered_centers[i+2], n_points_per_segment)
|
878 |
+
# path_positions.append(segment)
|
879 |
+
# path_positions = np.concatenate(path_positions, axis=0)
|
880 |
+
|
881 |
+
# # === 6. Denormalize back
|
882 |
+
# path_positions = path_positions * scale_pos + mean_pos
|
883 |
+
|
884 |
+
# # === 7. Generate dummy rotations (constant forward facing)
|
885 |
+
# reference_cam = existing_cameras[0]
|
886 |
+
# default_rotation = R.from_matrix(reference_cam.R)
|
887 |
+
|
888 |
+
# # For simplicity, fixed rotation for all
|
889 |
+
# smooth_cameras = []
|
890 |
+
|
891 |
+
# for i, pos in enumerate(path_positions):
|
892 |
+
# R_interp = default_rotation.as_matrix()
|
893 |
+
|
894 |
+
# smooth_cameras.append(Camera(
|
895 |
+
# R=R_interp,
|
896 |
+
# T=pos,
|
897 |
+
# FoVx=reference_cam.FoVx,
|
898 |
+
# FoVy=reference_cam.FoVy,
|
899 |
+
# resolution=(reference_cam.image_width, reference_cam.image_height),
|
900 |
+
# colmap_id=-1,
|
901 |
+
# depth_params=None,
|
902 |
+
# image=Image.fromarray(np.zeros((reference_cam.image_height, reference_cam.image_width, 3), dtype=np.uint8)),
|
903 |
+
# invdepthmap=None,
|
904 |
+
# image_name=f"cluster_path_i={i}",
|
905 |
+
# uid=i
|
906 |
+
# ))
|
907 |
+
|
908 |
+
# return smooth_cameras
|
909 |
+
|
910 |
+
from typing import List
|
911 |
+
import numpy as np
|
912 |
+
from sklearn.cluster import KMeans
|
913 |
+
from scipy.spatial.transform import Rotation as R, Slerp
|
914 |
+
from PIL import Image
|
915 |
+
|
916 |
+
def generate_clustered_path(existing_cameras: List[Camera],
|
917 |
+
n_points_per_segment: int = 20,
|
918 |
+
d: float = 2.0,
|
919 |
+
n_clusters: int = 5,
|
920 |
+
closed: bool = False) -> List[Camera]:
|
921 |
+
"""
|
922 |
+
Generate a smooth camera path using K-Means clustering and TSP on cluster centers.
|
923 |
+
|
924 |
+
Args:
|
925 |
+
existing_cameras (List[Camera]): List of input cameras.
|
926 |
+
n_points_per_segment (int): Number of interpolated points per spline segment.
|
927 |
+
d (float): Distance ahead for estimating center of view.
|
928 |
+
n_clusters (int): Number of KMeans clusters (zones).
|
929 |
+
closed (bool): Whether to close the path.
|
930 |
+
|
931 |
+
Returns:
|
932 |
+
List[Camera]: Smooth path of Camera objects.
|
933 |
+
"""
|
934 |
+
# Extract positions
|
935 |
+
positions = np.array([cam.T for cam in existing_cameras])
|
936 |
+
|
937 |
+
# === Normalize positions
|
938 |
+
mean_pos = np.mean(positions, axis=0)
|
939 |
+
scale_pos = np.std(positions, axis=0)
|
940 |
+
scale_pos[scale_pos == 0] = 1.0
|
941 |
+
|
942 |
+
positions_normalized = (positions - mean_pos) / scale_pos
|
943 |
+
|
944 |
+
# === 1. K-Means clustering (only positions)
|
945 |
+
kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init='auto')
|
946 |
+
cluster_labels = kmeans.fit_predict(positions_normalized)
|
947 |
+
|
948 |
+
cluster_centers = []
|
949 |
+
for cluster_id in range(n_clusters):
|
950 |
+
cluster_indices = np.where(cluster_labels == cluster_id)[0]
|
951 |
+
if len(cluster_indices) == 0:
|
952 |
+
continue
|
953 |
+
cluster_center = np.mean(positions_normalized[cluster_indices], axis=0)
|
954 |
+
cluster_centers.append(cluster_center)
|
955 |
+
|
956 |
+
cluster_centers = np.stack(cluster_centers)
|
957 |
+
|
958 |
+
# === 2. Solve TSP between cluster centers
|
959 |
+
cluster_order = solve_tsp_2opt(cluster_centers)
|
960 |
+
|
961 |
+
# === 3. Reorder cluster centers
|
962 |
+
ordered_centers = cluster_centers[cluster_order]
|
963 |
+
|
964 |
+
# === 4. Prepare Catmull-Rom spline
|
965 |
+
if closed:
|
966 |
+
ordered_centers = np.vstack([ordered_centers[-1], ordered_centers, ordered_centers[0], ordered_centers[1]])
|
967 |
+
else:
|
968 |
+
ordered_centers = np.vstack([ordered_centers[0], ordered_centers, ordered_centers[-1], ordered_centers[-1]])
|
969 |
+
|
970 |
+
# === 5. Generate smooth path positions
|
971 |
+
path_positions = []
|
972 |
+
for i in range(1, len(ordered_centers) - 2):
|
973 |
+
segment = catmull_rom_spline(ordered_centers[i-1], ordered_centers[i], ordered_centers[i+1], ordered_centers[i+2], n_points_per_segment)
|
974 |
+
path_positions.append(segment)
|
975 |
+
path_positions = np.concatenate(path_positions, axis=0)
|
976 |
+
|
977 |
+
# === 6. Denormalize back
|
978 |
+
path_positions = path_positions * scale_pos + mean_pos
|
979 |
+
|
980 |
+
# === 7. Generate dummy rotations (constant forward facing)
|
981 |
+
reference_cam = existing_cameras[0]
|
982 |
+
default_rotation = R.from_matrix(reference_cam.R)
|
983 |
+
|
984 |
+
# For simplicity, fixed rotation for all
|
985 |
+
smooth_cameras = []
|
986 |
+
|
987 |
+
for i, pos in enumerate(path_positions):
|
988 |
+
R_interp = default_rotation.as_matrix()
|
989 |
+
|
990 |
+
smooth_cameras.append(Camera(
|
991 |
+
R=R_interp,
|
992 |
+
T=pos,
|
993 |
+
FoVx=reference_cam.FoVx,
|
994 |
+
FoVy=reference_cam.FoVy,
|
995 |
+
resolution=(reference_cam.image_width, reference_cam.image_height),
|
996 |
+
colmap_id=-1,
|
997 |
+
depth_params=None,
|
998 |
+
image=Image.fromarray(np.zeros((reference_cam.image_height, reference_cam.image_width, 3), dtype=np.uint8)),
|
999 |
+
invdepthmap=None,
|
1000 |
+
image_name=f"cluster_path_i={i}",
|
1001 |
+
uid=i
|
1002 |
+
))
|
1003 |
+
|
1004 |
+
return smooth_cameras
|
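Likewise, `solve_tsp_nearest_neighbor` and `solve_tsp_2opt` are assumed to return an index order over the input points; their definitions are not part of this diff. A minimal greedy nearest-neighbour sketch (an illustration of the idea, not the repository's implementation) could look like this:

```python
import numpy as np

def solve_tsp_nearest_neighbor_sketch(points):
    """Greedy nearest-neighbour tour: returns an index order visiting every point once."""
    n = len(points)
    visited = np.zeros(n, dtype=bool)
    order = [0]          # start the tour at the first point
    visited[0] = True
    for _ in range(n - 1):
        last = points[order[-1]]
        dists = np.linalg.norm(points - last, axis=1)
        dists[visited] = np.inf          # never revisit a point
        nxt = int(np.argmin(dists))
        order.append(nxt)
        visited[nxt] = True
    return np.array(order)
```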
1005 |
+
|
1006 |
+
|
1007 |
+
|
1008 |
+
|
1009 |
+
def visualize_image_with_points(image, points):
|
1010 |
+
"""
|
1011 |
+
Visualize an image with points overlaid on top. This is useful for correspondence visualizations.
|
1012 |
+
|
1013 |
+
Parameters:
|
1014 |
+
- image: PIL Image object
|
1015 |
+
- points: Numpy array of shape [N, 2] containing (x, y) coordinates of points
|
1016 |
+
|
1017 |
+
Returns:
|
1018 |
+
- None (displays the visualization)
|
1019 |
+
"""
|
1020 |
+
|
1021 |
+
# Convert PIL image to numpy array
|
1022 |
+
img_array = np.array(image)
|
1023 |
+
|
1024 |
+
# Create a figure and axis
|
1025 |
+
fig, ax = plt.subplots(figsize=(7,7))
|
1026 |
+
|
1027 |
+
# Display the image
|
1028 |
+
ax.imshow(img_array)
|
1029 |
+
|
1030 |
+
# Scatter plot the points on top of the image
|
1031 |
+
ax.scatter(points[:, 0], points[:, 1], color='red', marker='o', s=1)
|
1032 |
+
|
1033 |
+
# Show the plot
|
1034 |
+
plt.show()
|
1035 |
+
|
1036 |
+
|
1037 |
+
def visualize_correspondences(image1, points1, image2, points2):
|
1038 |
+
"""
|
1039 |
+
Visualize two images concatenated horizontally with key points and correspondences.
|
1040 |
+
|
1041 |
+
Parameters:
|
1042 |
+
- image1: PIL Image object (left image)
|
1043 |
+
- points1: Numpy array of shape [N, 2] containing (x, y) coordinates of key points for image1
|
1044 |
+
- image2: PIL Image object (right image)
|
1045 |
+
- points2: Numpy array of shape [N, 2] containing (x, y) coordinates of key points for image2
|
1046 |
+
|
1047 |
+
Returns:
|
1048 |
+
- None (displays the visualization)
|
1049 |
+
"""
|
1050 |
+
|
1051 |
+
# Concatenate images horizontally
|
1052 |
+
concatenated_image = np.concatenate((np.array(image1), np.array(image2)), axis=1)
|
1053 |
+
|
1054 |
+
# Create a figure and axis
|
1055 |
+
fig, ax = plt.subplots(figsize=(10,10))
|
1056 |
+
|
1057 |
+
# Display the concatenated image
|
1058 |
+
ax.imshow(concatenated_image)
|
1059 |
+
|
1060 |
+
# Plot key points on the left image
|
1061 |
+
ax.scatter(points1[:, 0], points1[:, 1], color='red', marker='o', s=10)
|
1062 |
+
|
1063 |
+
# Plot key points on the right image
|
1064 |
+
ax.scatter(points2[:, 0] + image1.width, points2[:, 1], color='blue', marker='o', s=10)
|
1065 |
+
|
1066 |
+
# Draw lines connecting corresponding key points
|
1067 |
+
for i in range(len(points1)):
|
1068 |
+
ax.plot([points1[i, 0], points2[i, 0] + image1.width], [points1[i, 1], points2[i, 1]])#, color='green')
|
1069 |
+
|
1070 |
+
# Show the plot
|
1071 |
+
plt.show()
|
1072 |
+
|
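For context, a hypothetical way to exercise the path and correspondence utilities defined in this file; `scene_cameras`, `image1`, `pts1`, `image2`, and `pts2` are placeholder names (a list of `Camera` objects, two PIL images, and N x 2 arrays of (x, y) pixel coordinates), not objects provided by the repository:

```python
# Build a smooth fly-through from existing scene cameras (K-Means + TSP + spline).
path_cameras = generate_clustered_path(
    scene_cameras, n_points_per_segment=20, n_clusters=5, closed=False
)
print(f"Generated {len(path_cameras)} path cameras")

# Overlay matched keypoints on a single image, then on a side-by-side pair.
visualize_image_with_points(image1, pts1)
visualize_correspondences(image1, pts1, image2, pts2)
```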
submodules/RoMa/.gitignore
ADDED
@@ -0,0 +1,11 @@
1 |
+
*.egg-info*
|
2 |
+
*.vscode*
|
3 |
+
*__pycache__*
|
4 |
+
vis*
|
5 |
+
workspace*
|
6 |
+
.venv
|
7 |
+
.DS_Store
|
8 |
+
jobs/*
|
9 |
+
*ignore_me*
|
10 |
+
*.pth
|
11 |
+
wandb*
|
submodules/RoMa/LICENSE
ADDED
@@ -0,0 +1,21 @@
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) 2023 Johan Edstedt
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
submodules/RoMa/README.md
ADDED
@@ -0,0 +1,123 @@
1 |
+
#
|
2 |
+
<p align="center">
|
3 |
+
<h1 align="center"> <ins>RoMa</ins> 🏛️:<br> Robust Dense Feature Matching <br> ⭐CVPR 2024⭐</h1>
|
4 |
+
<p align="center">
|
5 |
+
<a href="https://scholar.google.com/citations?user=Ul-vMR0AAAAJ">Johan Edstedt</a>
|
6 |
+
·
|
7 |
+
<a href="https://scholar.google.com/citations?user=HS2WuHkAAAAJ">Qiyu Sun</a>
|
8 |
+
·
|
9 |
+
<a href="https://scholar.google.com/citations?user=FUE3Wd0AAAAJ">Georg Bökman</a>
|
10 |
+
·
|
11 |
+
<a href="https://scholar.google.com/citations?user=6WRQpCQAAAAJ">Mårten Wadenbäck</a>
|
12 |
+
·
|
13 |
+
<a href="https://scholar.google.com/citations?user=lkWfR08AAAAJ">Michael Felsberg</a>
|
14 |
+
</p>
|
15 |
+
<h2 align="center"><p>
|
16 |
+
<a href="https://arxiv.org/abs/2305.15404" align="center">Paper</a> |
|
17 |
+
<a href="https://parskatt.github.io/RoMa" align="center">Project Page</a>
|
18 |
+
</p></h2>
|
19 |
+
<div align="center"></div>
|
20 |
+
</p>
|
21 |
+
<br/>
|
22 |
+
<p align="center">
|
23 |
+
<img src="https://github.com/Parskatt/RoMa/assets/22053118/15d8fea7-aa6d-479f-8a93-350d950d006b" alt="example" width=80%>
|
24 |
+
<br>
|
25 |
+
<em>RoMa is the robust dense feature matcher capable of estimating pixel-dense warps and reliable certainties for almost any image pair.</em>
|
26 |
+
</p>
|
27 |
+
|
28 |
+
## Setup/Install
|
29 |
+
In your python environment (tested on Linux python 3.10), run:
|
30 |
+
```bash
|
31 |
+
pip install -e .
|
32 |
+
```
|
33 |
+
## Demo / How to Use
|
34 |
+
We provide two demos in the [demos folder](demo).
|
35 |
+
Here's the gist of it:
|
36 |
+
```python
|
37 |
+
from romatch import roma_outdoor
|
38 |
+
roma_model = roma_outdoor(device=device)
|
39 |
+
# Match
|
40 |
+
warp, certainty = roma_model.match(imA_path, imB_path, device=device)
|
41 |
+
# Sample matches for estimation
|
42 |
+
matches, certainty = roma_model.sample(warp, certainty)
|
43 |
+
# Convert to pixel coordinates (RoMa produces matches in [-1,1]x[-1,1])
|
44 |
+
kptsA, kptsB = roma_model.to_pixel_coordinates(matches, H_A, W_A, H_B, W_B)
|
45 |
+
# Find a fundamental matrix (or anything else of interest)
|
46 |
+
F, mask = cv2.findFundamentalMat(
|
47 |
+
kptsA.cpu().numpy(), kptsB.cpu().numpy(), ransacReprojThreshold=0.2, method=cv2.USAC_MAGSAC, confidence=0.999999, maxIters=10000
|
48 |
+
)
|
49 |
+
```
|
50 |
+
|
51 |
+
**New**: You can also match arbitrary keypoints with RoMa. See [match_keypoints](romatch/models/matcher.py) in RegressionMatcher.
|
52 |
+
|
53 |
+
## Settings
|
54 |
+
|
55 |
+
### Resolution
|
56 |
+
By default RoMa uses an initial resolution of (560,560) which is then upsampled to (864,864).
|
57 |
+
You can change this at construction (see roma_outdoor kwargs).
|
58 |
+
You can also change this later, by changing the roma_model.w_resized, roma_model.h_resized, and roma_model.upsample_res.
|
59 |
+
|
60 |
+
### Sampling
|
61 |
+
roma_model.sample_thresh controls the thresholding used when sampling matches for estimation. In certain cases a lower or higher threshold may improve results.
|
62 |
+
|
63 |
+
|
64 |
+
## Reproducing Results
|
65 |
+
The experiments in the paper are provided in the [experiments folder](experiments).
|
66 |
+
|
67 |
+
### Training
|
68 |
+
1. First follow the instructions provided here: https://github.com/Parskatt/DKM for downloading and preprocessing datasets.
|
69 |
+
2. Run the relevant experiment, e.g.,
|
70 |
+
```bash
|
71 |
+
torchrun --nproc_per_node=4 --nnodes=1 --rdzv_backend=c10d experiments/roma_outdoor.py
|
72 |
+
```
|
73 |
+
### Testing
|
74 |
+
```bash
|
75 |
+
python experiments/roma_outdoor.py --only_test --benchmark mega-1500
|
76 |
+
```
|
77 |
+
## License
|
78 |
+
All our code except DINOv2 is MIT license.
|
79 |
+
DINOv2 has an Apache 2 license [DINOv2](https://github.com/facebookresearch/dinov2/blob/main/LICENSE).
|
80 |
+
|
81 |
+
## Acknowledgement
|
82 |
+
Our codebase builds on the code in [DKM](https://github.com/Parskatt/DKM).
|
83 |
+
|
84 |
+
## Tiny RoMa
|
85 |
+
If you find that RoMa is too heavy, you might want to try Tiny RoMa which is built on top of XFeat.
|
86 |
+
```python
|
87 |
+
from romatch import tiny_roma_v1_outdoor
|
88 |
+
tiny_roma_model = tiny_roma_v1_outdoor(device=device)
|
89 |
+
```
|
90 |
+
Mega1500:
|
91 |
+
| | AUC@5 | AUC@10 | AUC@20 |
|
92 |
+
|----------|----------|----------|----------|
|
93 |
+
| XFeat | 46.4 | 58.9 | 69.2 |
|
94 |
+
| XFeat* | 51.9 | 67.2 | 78.9 |
|
95 |
+
| Tiny RoMa v1 | 56.4 | 69.5 | 79.5 |
|
96 |
+
| RoMa | - | - | - |
|
97 |
+
|
98 |
+
Mega-8-Scenes (See DKM):
|
99 |
+
| | AUC@5 | AUC@10 | AUC@20 |
|
100 |
+
|----------|----------|----------|----------|
|
101 |
+
| XFeat | - | - | - |
|
102 |
+
| XFeat* | 50.1 | 64.4 | 75.2 |
|
103 |
+
| Tiny RoMa v1 | 57.7 | 70.5 | 79.6 |
|
104 |
+
| RoMa | - | - | - |
|
105 |
+
|
106 |
+
IMC22 :'):
|
107 |
+
| | mAA@10 |
|
108 |
+
|----------|----------|
|
109 |
+
| XFeat | 42.1 |
|
110 |
+
| XFeat* | - |
|
111 |
+
| Tiny RoMa v1 | 42.2 |
|
112 |
+
| RoMa | - |
|
113 |
+
|
114 |
+
## BibTeX
|
115 |
+
If you find our models useful, please consider citing our paper!
|
116 |
+
```
|
117 |
+
@article{edstedt2024roma,
|
118 |
+
title={{RoMa: Robust Dense Feature Matching}},
|
119 |
+
author={Edstedt, Johan and Sun, Qiyu and Bökman, Georg and Wadenbäck, Mårten and Felsberg, Michael},
|
120 |
+
journal={IEEE Conference on Computer Vision and Pattern Recognition},
|
121 |
+
year={2024}
|
122 |
+
}
|
123 |
+
```
|
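The Settings section of the README above names the attributes that control RoMa's working resolution and match sampling. A hedged sketch of adjusting them after construction follows; only the attribute names come from the README, and the values are illustrative placeholders rather than recommended defaults.

```python
import torch
from romatch import roma_outdoor

device = "cuda" if torch.cuda.is_available() else "cpu"
roma_model = roma_outdoor(device=device)

# Attribute names are taken from the README's Settings section; values are placeholders.
roma_model.w_resized = 560             # coarse matching width
roma_model.h_resized = 560             # coarse matching height
roma_model.upsample_res = (864, 864)   # resolution the warp is upsampled to
roma_model.sample_thresh = 0.05        # certainty threshold used by roma_model.sample(...)
```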
submodules/RoMa/data/.gitignore
ADDED
@@ -0,0 +1,2 @@
1 |
+
*
|
2 |
+
!.gitignore
|
submodules/RoMa/demo/demo_3D_effect.py
ADDED
@@ -0,0 +1,47 @@
1 |
+
from PIL import Image
|
2 |
+
import torch
|
3 |
+
import torch.nn.functional as F
|
4 |
+
import numpy as np
|
5 |
+
from romatch.utils.utils import tensor_to_pil
|
6 |
+
|
7 |
+
from romatch import roma_outdoor
|
8 |
+
|
9 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
10 |
+
if torch.backends.mps.is_available():
|
11 |
+
device = torch.device('mps')
|
12 |
+
|
13 |
+
if __name__ == "__main__":
|
14 |
+
from argparse import ArgumentParser
|
15 |
+
parser = ArgumentParser()
|
16 |
+
parser.add_argument("--im_A_path", default="assets/toronto_A.jpg", type=str)
|
17 |
+
parser.add_argument("--im_B_path", default="assets/toronto_B.jpg", type=str)
|
18 |
+
parser.add_argument("--save_path", default="demo/gif/roma_warp_toronto", type=str)
|
19 |
+
|
20 |
+
args, _ = parser.parse_known_args()
|
21 |
+
im1_path = args.im_A_path
|
22 |
+
im2_path = args.im_B_path
|
23 |
+
save_path = args.save_path
|
24 |
+
|
25 |
+
# Create model
|
26 |
+
roma_model = roma_outdoor(device=device, coarse_res=560, upsample_res=(864, 1152))
|
27 |
+
roma_model.symmetric = False
|
28 |
+
|
29 |
+
H, W = roma_model.get_output_resolution()
|
30 |
+
|
31 |
+
im1 = Image.open(im1_path).resize((W, H))
|
32 |
+
im2 = Image.open(im2_path).resize((W, H))
|
33 |
+
|
34 |
+
# Match
|
35 |
+
warp, certainty = roma_model.match(im1_path, im2_path, device=device)
|
36 |
+
# Sampling not needed, but can be done with model.sample(warp, certainty)
|
37 |
+
x1 = (torch.tensor(np.array(im1)) / 255).to(device).permute(2, 0, 1)
|
38 |
+
x2 = (torch.tensor(np.array(im2)) / 255).to(device).permute(2, 0, 1)
|
39 |
+
|
40 |
+
coords_A, coords_B = warp[...,:2], warp[...,2:]
|
41 |
+
for i, x in enumerate(np.linspace(0,2*np.pi,200)):
|
42 |
+
t = (1 + np.cos(x))/2
|
43 |
+
interp_warp = (1-t)*coords_A + t*coords_B
|
44 |
+
im2_transfer_rgb = F.grid_sample(
|
45 |
+
x2[None], interp_warp[None], mode="bilinear", align_corners=False
|
46 |
+
)[0]
|
47 |
+
tensor_to_pil(im2_transfer_rgb, unnormalize=False).save(f"{save_path}_{i:03d}.jpg")
|
submodules/RoMa/demo/demo_fundamental.py
ADDED
@@ -0,0 +1,34 @@
1 |
+
from PIL import Image
|
2 |
+
import torch
|
3 |
+
import cv2
|
4 |
+
from romatch import roma_outdoor
|
5 |
+
|
6 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
7 |
+
if torch.backends.mps.is_available():
|
8 |
+
device = torch.device('mps')
|
9 |
+
|
10 |
+
if __name__ == "__main__":
|
11 |
+
from argparse import ArgumentParser
|
12 |
+
parser = ArgumentParser()
|
13 |
+
parser.add_argument("--im_A_path", default="assets/sacre_coeur_A.jpg", type=str)
|
14 |
+
parser.add_argument("--im_B_path", default="assets/sacre_coeur_B.jpg", type=str)
|
15 |
+
|
16 |
+
args, _ = parser.parse_known_args()
|
17 |
+
im1_path = args.im_A_path
|
18 |
+
im2_path = args.im_B_path
|
19 |
+
|
20 |
+
# Create model
|
21 |
+
roma_model = roma_outdoor(device=device)
|
22 |
+
|
23 |
+
|
24 |
+
W_A, H_A = Image.open(im1_path).size
|
25 |
+
W_B, H_B = Image.open(im2_path).size
|
26 |
+
|
27 |
+
# Match
|
28 |
+
warp, certainty = roma_model.match(im1_path, im2_path, device=device)
|
29 |
+
# Sample matches for estimation
|
30 |
+
matches, certainty = roma_model.sample(warp, certainty)
|
31 |
+
kpts1, kpts2 = roma_model.to_pixel_coordinates(matches, H_A, W_A, H_B, W_B)
|
32 |
+
F, mask = cv2.findFundamentalMat(
|
33 |
+
kpts1.cpu().numpy(), kpts2.cpu().numpy(), ransacReprojThreshold=0.2, method=cv2.USAC_MAGSAC, confidence=0.999999, maxIters=10000
|
34 |
+
)
|
submodules/RoMa/demo/demo_match.py
ADDED
@@ -0,0 +1,50 @@
1 |
+
import os
|
2 |
+
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'
|
3 |
+
import torch
|
4 |
+
from PIL import Image
|
5 |
+
import torch.nn.functional as F
|
6 |
+
import numpy as np
|
7 |
+
from romatch.utils.utils import tensor_to_pil
|
8 |
+
|
9 |
+
from romatch import roma_outdoor
|
10 |
+
|
11 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
12 |
+
if torch.backends.mps.is_available():
|
13 |
+
device = torch.device('mps')
|
14 |
+
|
15 |
+
if __name__ == "__main__":
|
16 |
+
from argparse import ArgumentParser
|
17 |
+
parser = ArgumentParser()
|
18 |
+
parser.add_argument("--im_A_path", default="assets/toronto_A.jpg", type=str)
|
19 |
+
parser.add_argument("--im_B_path", default="assets/toronto_B.jpg", type=str)
|
20 |
+
parser.add_argument("--save_path", default="demo/roma_warp_toronto.jpg", type=str)
|
21 |
+
|
22 |
+
args, _ = parser.parse_known_args()
|
23 |
+
im1_path = args.im_A_path
|
24 |
+
im2_path = args.im_B_path
|
25 |
+
save_path = args.save_path
|
26 |
+
|
27 |
+
# Create model
|
28 |
+
roma_model = roma_outdoor(device=device, coarse_res=560, upsample_res=(864, 1152))
|
29 |
+
|
30 |
+
H, W = roma_model.get_output_resolution()
|
31 |
+
|
32 |
+
im1 = Image.open(im1_path).resize((W, H))
|
33 |
+
im2 = Image.open(im2_path).resize((W, H))
|
34 |
+
|
35 |
+
# Match
|
36 |
+
warp, certainty = roma_model.match(im1_path, im2_path, device=device)
|
37 |
+
# Sampling not needed, but can be done with model.sample(warp, certainty)
|
38 |
+
x1 = (torch.tensor(np.array(im1)) / 255).to(device).permute(2, 0, 1)
|
39 |
+
x2 = (torch.tensor(np.array(im2)) / 255).to(device).permute(2, 0, 1)
|
40 |
+
|
41 |
+
im2_transfer_rgb = F.grid_sample(
|
42 |
+
x2[None], warp[:,:W, 2:][None], mode="bilinear", align_corners=False
|
43 |
+
)[0]
|
44 |
+
im1_transfer_rgb = F.grid_sample(
|
45 |
+
x1[None], warp[:, W:, :2][None], mode="bilinear", align_corners=False
|
46 |
+
)[0]
|
47 |
+
warp_im = torch.cat((im2_transfer_rgb,im1_transfer_rgb),dim=2)
|
48 |
+
white_im = torch.ones((H,2*W),device=device)
|
49 |
+
vis_im = certainty * warp_im + (1 - certainty) * white_im
|
50 |
+
tensor_to_pil(vis_im, unnormalize=False).save(save_path)
|
submodules/RoMa/demo/demo_match_opencv_sift.py
ADDED
@@ -0,0 +1,43 @@
1 |
+
from PIL import Image
|
2 |
+
import numpy as np
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
import cv2 as cv
|
6 |
+
import matplotlib.pyplot as plt
|
7 |
+
|
8 |
+
|
9 |
+
|
10 |
+
if __name__ == "__main__":
|
11 |
+
from argparse import ArgumentParser
|
12 |
+
parser = ArgumentParser()
|
13 |
+
parser.add_argument("--im_A_path", default="assets/toronto_A.jpg", type=str)
|
14 |
+
parser.add_argument("--im_B_path", default="assets/toronto_B.jpg", type=str)
|
15 |
+
parser.add_argument("--save_path", default="demo/roma_warp_toronto.jpg", type=str)
|
16 |
+
|
17 |
+
args, _ = parser.parse_known_args()
|
18 |
+
im1_path = args.im_A_path
|
19 |
+
im2_path = args.im_B_path
|
20 |
+
save_path = args.save_path
|
21 |
+
|
22 |
+
img1 = cv.imread(im1_path,cv.IMREAD_GRAYSCALE) # queryImage
|
23 |
+
img2 = cv.imread(im2_path,cv.IMREAD_GRAYSCALE) # trainImage
|
24 |
+
# Initiate SIFT detector
|
25 |
+
sift = cv.SIFT_create()
|
26 |
+
# find the keypoints and descriptors with SIFT
|
27 |
+
kp1, des1 = sift.detectAndCompute(img1,None)
|
28 |
+
kp2, des2 = sift.detectAndCompute(img2,None)
|
29 |
+
# BFMatcher with default params
|
30 |
+
bf = cv.BFMatcher()
|
31 |
+
matches = bf.knnMatch(des1,des2,k=2)
|
32 |
+
# Apply ratio test
|
33 |
+
good = []
|
34 |
+
for m,n in matches:
|
35 |
+
if m.distance < 0.75*n.distance:
|
36 |
+
good.append([m])
|
37 |
+
# cv.drawMatchesKnn expects list of lists as matches.
|
38 |
+
draw_params = dict(matchColor = (255,0,0), # draw matches in red color
|
39 |
+
singlePointColor = None,
|
40 |
+
flags = 2)
|
41 |
+
|
42 |
+
img3 = cv.drawMatchesKnn(img1,kp1,img2,kp2,good,None,**draw_params)
|
43 |
+
Image.fromarray(img3).save("demo/sift_matches.png")
|
submodules/RoMa/demo/demo_match_tiny.py
ADDED
@@ -0,0 +1,77 @@
1 |
+
import os
|
2 |
+
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'
|
3 |
+
import torch
|
4 |
+
from PIL import Image
|
5 |
+
import torch.nn.functional as F
|
6 |
+
import numpy as np
|
7 |
+
from romatch.utils.utils import tensor_to_pil
|
8 |
+
|
9 |
+
from romatch import tiny_roma_v1_outdoor
|
10 |
+
|
11 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
12 |
+
if torch.backends.mps.is_available():
|
13 |
+
device = torch.device('mps')
|
14 |
+
|
15 |
+
if __name__ == "__main__":
|
16 |
+
from argparse import ArgumentParser
|
17 |
+
parser = ArgumentParser()
|
18 |
+
parser.add_argument("--im_A_path", default="assets/sacre_coeur_A.jpg", type=str)
|
19 |
+
parser.add_argument("--im_B_path", default="assets/sacre_coeur_B.jpg", type=str)
|
20 |
+
parser.add_argument("--save_A_path", default="demo/tiny_roma_warp_A.jpg", type=str)
|
21 |
+
parser.add_argument("--save_B_path", default="demo/tiny_roma_warp_B.jpg", type=str)
|
22 |
+
|
23 |
+
args, _ = parser.parse_known_args()
|
24 |
+
im1_path = args.im_A_path
|
25 |
+
im2_path = args.im_B_path
|
26 |
+
|
27 |
+
# Create model
|
28 |
+
roma_model = tiny_roma_v1_outdoor(device=device)
|
29 |
+
|
30 |
+
# Match
|
31 |
+
warp, certainty1 = roma_model.match(im1_path, im2_path)
|
32 |
+
|
33 |
+
h1, w1 = warp.shape[:2]
|
34 |
+
|
35 |
+
# maybe im1.size != im2.size
|
36 |
+
im1 = Image.open(im1_path).resize((w1, h1))
|
37 |
+
im2 = Image.open(im2_path)
|
38 |
+
x1 = (torch.tensor(np.array(im1)) / 255).to(device).permute(2, 0, 1)
|
39 |
+
x2 = (torch.tensor(np.array(im2)) / 255).to(device).permute(2, 0, 1)
|
40 |
+
|
41 |
+
h2, w2 = x2.shape[1:]
|
42 |
+
g1_p2x = w2 / 2 * (warp[..., 2] + 1)
|
43 |
+
g1_p2y = h2 / 2 * (warp[..., 3] + 1)
|
44 |
+
g2_p1x = torch.zeros((h2, w2), dtype=torch.float32).to(device) - 2
|
45 |
+
g2_p1y = torch.zeros((h2, w2), dtype=torch.float32).to(device) - 2
|
46 |
+
|
47 |
+
x, y = torch.meshgrid(
|
48 |
+
torch.arange(w1, device=device),
|
49 |
+
torch.arange(h1, device=device),
|
50 |
+
indexing="xy",
|
51 |
+
)
|
52 |
+
g2x = torch.round(g1_p2x[y, x]).long()
|
53 |
+
g2y = torch.round(g1_p2y[y, x]).long()
|
54 |
+
idx_x = torch.bitwise_and(0 <= g2x, g2x < w2)
|
55 |
+
idx_y = torch.bitwise_and(0 <= g2y, g2y < h2)
|
56 |
+
idx = torch.bitwise_and(idx_x, idx_y)
|
57 |
+
g2_p1x[g2y[idx], g2x[idx]] = x[idx].float() * 2 / w1 - 1
|
58 |
+
g2_p1y[g2y[idx], g2x[idx]] = y[idx].float() * 2 / h1 - 1
|
59 |
+
|
60 |
+
certainty2 = F.grid_sample(
|
61 |
+
certainty1[None][None],
|
62 |
+
torch.stack([g2_p1x, g2_p1y], dim=2)[None],
|
63 |
+
mode="bilinear",
|
64 |
+
align_corners=False,
|
65 |
+
)[0]
|
66 |
+
|
67 |
+
white_im1 = torch.ones((h1, w1), device = device)
|
68 |
+
white_im2 = torch.ones((h2, w2), device = device)
|
69 |
+
|
70 |
+
certainty1 = F.avg_pool2d(certainty1[None], kernel_size=5, stride=1, padding=2)[0]
|
71 |
+
certainty2 = F.avg_pool2d(certainty2[None], kernel_size=5, stride=1, padding=2)[0]
|
72 |
+
|
73 |
+
vis_im1 = certainty1 * x1 + (1 - certainty1) * white_im1
|
74 |
+
vis_im2 = certainty2 * x2 + (1 - certainty2) * white_im2
|
75 |
+
|
76 |
+
tensor_to_pil(vis_im1, unnormalize=False).save(args.save_A_path)
|
77 |
+
tensor_to_pil(vis_im2, unnormalize=False).save(args.save_B_path)
|
submodules/RoMa/demo/gif/.gitignore
ADDED
@@ -0,0 +1,2 @@
1 |
+
*
|
2 |
+
!.gitignore
|
submodules/RoMa/experiments/eval_roma_outdoor.py
ADDED
@@ -0,0 +1,57 @@
1 |
+
import json
|
2 |
+
|
3 |
+
from romatch.benchmarks import MegadepthDenseBenchmark
|
4 |
+
from romatch.benchmarks import MegaDepthPoseEstimationBenchmark, HpatchesHomogBenchmark
|
5 |
+
from romatch.benchmarks import Mega1500PoseLibBenchmark
|
6 |
+
|
7 |
+
def test_mega_8_scenes(model, name):
|
8 |
+
mega_8_scenes_benchmark = MegaDepthPoseEstimationBenchmark("data/megadepth",
|
9 |
+
scene_names=['mega_8_scenes_0019_0.1_0.3.npz',
|
10 |
+
'mega_8_scenes_0025_0.1_0.3.npz',
|
11 |
+
'mega_8_scenes_0021_0.1_0.3.npz',
|
12 |
+
'mega_8_scenes_0008_0.1_0.3.npz',
|
13 |
+
'mega_8_scenes_0032_0.1_0.3.npz',
|
14 |
+
'mega_8_scenes_1589_0.1_0.3.npz',
|
15 |
+
'mega_8_scenes_0063_0.1_0.3.npz',
|
16 |
+
'mega_8_scenes_0024_0.1_0.3.npz',
|
17 |
+
'mega_8_scenes_0019_0.3_0.5.npz',
|
18 |
+
'mega_8_scenes_0025_0.3_0.5.npz',
|
19 |
+
'mega_8_scenes_0021_0.3_0.5.npz',
|
20 |
+
'mega_8_scenes_0008_0.3_0.5.npz',
|
21 |
+
'mega_8_scenes_0032_0.3_0.5.npz',
|
22 |
+
'mega_8_scenes_1589_0.3_0.5.npz',
|
23 |
+
'mega_8_scenes_0063_0.3_0.5.npz',
|
24 |
+
'mega_8_scenes_0024_0.3_0.5.npz'])
|
25 |
+
mega_8_scenes_results = mega_8_scenes_benchmark.benchmark(model, model_name=name)
|
26 |
+
print(mega_8_scenes_results)
|
27 |
+
json.dump(mega_8_scenes_results, open(f"results/mega_8_scenes_{name}.json", "w"))
|
28 |
+
|
29 |
+
def test_mega1500(model, name):
|
30 |
+
mega1500_benchmark = MegaDepthPoseEstimationBenchmark("data/megadepth")
|
31 |
+
mega1500_results = mega1500_benchmark.benchmark(model, model_name=name)
|
32 |
+
json.dump(mega1500_results, open(f"results/mega1500_{name}.json", "w"))
|
33 |
+
|
34 |
+
def test_mega1500_poselib(model, name):
|
35 |
+
mega1500_benchmark = Mega1500PoseLibBenchmark("data/megadepth")
|
36 |
+
mega1500_results = mega1500_benchmark.benchmark(model, model_name=name)
|
37 |
+
json.dump(mega1500_results, open(f"results/mega1500_{name}.json", "w"))
|
38 |
+
|
39 |
+
def test_mega_dense(model, name):
|
40 |
+
megadense_benchmark = MegadepthDenseBenchmark("data/megadepth", num_samples = 1000)
|
41 |
+
megadense_results = megadense_benchmark.benchmark(model)
|
42 |
+
json.dump(megadense_results, open(f"results/mega_dense_{name}.json", "w"))
|
43 |
+
|
44 |
+
def test_hpatches(model, name):
|
45 |
+
hpatches_benchmark = HpatchesHomogBenchmark("data/hpatches")
|
46 |
+
hpatches_results = hpatches_benchmark.benchmark(model)
|
47 |
+
json.dump(hpatches_results, open(f"results/hpatches_{name}.json", "w"))
|
48 |
+
|
49 |
+
|
50 |
+
if __name__ == "__main__":
|
51 |
+
from romatch import roma_outdoor
|
52 |
+
device = "cuda"
|
53 |
+
model = roma_outdoor(device = device, coarse_res = 672, upsample_res = 1344)
|
54 |
+
experiment_name = "roma_latest"
|
55 |
+
test_mega1500(model, experiment_name)
|
56 |
+
#test_mega1500_poselib(model, experiment_name)
|
57 |
+
|
submodules/RoMa/experiments/eval_tiny_roma_v1_outdoor.py
ADDED
@@ -0,0 +1,84 @@
1 |
+
import torch
|
2 |
+
import os
|
3 |
+
from pathlib import Path
|
4 |
+
import json
|
5 |
+
from romatch.benchmarks import ScanNetBenchmark
|
6 |
+
from romatch.benchmarks import Mega1500PoseLibBenchmark, ScanNetPoselibBenchmark
|
7 |
+
from romatch.benchmarks import MegaDepthPoseEstimationBenchmark
|
8 |
+
|
9 |
+
def test_mega_8_scenes(model, name):
|
10 |
+
mega_8_scenes_benchmark = MegaDepthPoseEstimationBenchmark("data/megadepth",
|
11 |
+
scene_names=['mega_8_scenes_0019_0.1_0.3.npz',
|
12 |
+
'mega_8_scenes_0025_0.1_0.3.npz',
|
13 |
+
'mega_8_scenes_0021_0.1_0.3.npz',
|
14 |
+
'mega_8_scenes_0008_0.1_0.3.npz',
|
15 |
+
'mega_8_scenes_0032_0.1_0.3.npz',
|
16 |
+
'mega_8_scenes_1589_0.1_0.3.npz',
|
17 |
+
'mega_8_scenes_0063_0.1_0.3.npz',
|
18 |
+
'mega_8_scenes_0024_0.1_0.3.npz',
|
19 |
+
'mega_8_scenes_0019_0.3_0.5.npz',
|
20 |
+
'mega_8_scenes_0025_0.3_0.5.npz',
|
21 |
+
'mega_8_scenes_0021_0.3_0.5.npz',
|
22 |
+
'mega_8_scenes_0008_0.3_0.5.npz',
|
23 |
+
'mega_8_scenes_0032_0.3_0.5.npz',
|
24 |
+
'mega_8_scenes_1589_0.3_0.5.npz',
|
25 |
+
'mega_8_scenes_0063_0.3_0.5.npz',
|
26 |
+
'mega_8_scenes_0024_0.3_0.5.npz'])
|
27 |
+
mega_8_scenes_results = mega_8_scenes_benchmark.benchmark(model, model_name=name)
|
28 |
+
print(mega_8_scenes_results)
|
29 |
+
json.dump(mega_8_scenes_results, open(f"results/mega_8_scenes_{name}.json", "w"))
|
30 |
+
|
31 |
+
def test_mega1500(model, name):
|
32 |
+
mega1500_benchmark = MegaDepthPoseEstimationBenchmark("data/megadepth")
|
33 |
+
mega1500_results = mega1500_benchmark.benchmark(model, model_name=name)
|
34 |
+
json.dump(mega1500_results, open(f"results/mega1500_{name}.json", "w"))
|
35 |
+
|
36 |
+
def test_mega1500_poselib(model, name):
|
37 |
+
#model.exact_softmax = True
|
38 |
+
mega1500_benchmark = Mega1500PoseLibBenchmark("data/megadepth", num_ransac_iter = 1, test_every = 1)
|
39 |
+
mega1500_results = mega1500_benchmark.benchmark(model, model_name=name)
|
40 |
+
json.dump(mega1500_results, open(f"results/mega1500_poselib_{name}.json", "w"))
|
41 |
+
|
42 |
+
def test_mega_8_scenes_poselib(model, name):
|
43 |
+
mega1500_benchmark = Mega1500PoseLibBenchmark("data/megadepth", num_ransac_iter = 1, test_every = 1,
|
44 |
+
scene_names=['mega_8_scenes_0019_0.1_0.3.npz',
|
45 |
+
'mega_8_scenes_0025_0.1_0.3.npz',
|
46 |
+
'mega_8_scenes_0021_0.1_0.3.npz',
|
47 |
+
'mega_8_scenes_0008_0.1_0.3.npz',
|
48 |
+
'mega_8_scenes_0032_0.1_0.3.npz',
|
49 |
+
'mega_8_scenes_1589_0.1_0.3.npz',
|
50 |
+
'mega_8_scenes_0063_0.1_0.3.npz',
|
51 |
+
'mega_8_scenes_0024_0.1_0.3.npz',
|
52 |
+
'mega_8_scenes_0019_0.3_0.5.npz',
|
53 |
+
'mega_8_scenes_0025_0.3_0.5.npz',
|
54 |
+
'mega_8_scenes_0021_0.3_0.5.npz',
|
55 |
+
'mega_8_scenes_0008_0.3_0.5.npz',
|
56 |
+
'mega_8_scenes_0032_0.3_0.5.npz',
|
57 |
+
'mega_8_scenes_1589_0.3_0.5.npz',
|
58 |
+
'mega_8_scenes_0063_0.3_0.5.npz',
|
59 |
+
'mega_8_scenes_0024_0.3_0.5.npz'])
|
60 |
+
mega1500_results = mega1500_benchmark.benchmark(model, model_name=name)
|
61 |
+
json.dump(mega1500_results, open(f"results/mega_8_scenes_poselib_{name}.json", "w"))
|
62 |
+
|
63 |
+
def test_scannet_poselib(model, name):
|
64 |
+
scannet_benchmark = ScanNetPoselibBenchmark("data/scannet")
|
65 |
+
scannet_results = scannet_benchmark.benchmark(model)
|
66 |
+
json.dump(scannet_results, open(f"results/scannet_{name}.json", "w"))
|
67 |
+
|
68 |
+
def test_scannet(model, name):
|
69 |
+
scannet_benchmark = ScanNetBenchmark("data/scannet")
|
70 |
+
scannet_results = scannet_benchmark.benchmark(model)
|
71 |
+
json.dump(scannet_results, open(f"results/scannet_{name}.json", "w"))
|
72 |
+
|
73 |
+
if __name__ == "__main__":
|
74 |
+
os.environ["TORCH_CUDNN_V8_API_ENABLED"] = "1" # For BF16 computations
|
75 |
+
os.environ["OMP_NUM_THREADS"] = "16"
|
76 |
+
torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
|
77 |
+
from romatch import tiny_roma_v1_outdoor
|
78 |
+
|
79 |
+
experiment_name = Path(__file__).stem
|
80 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
81 |
+
model = tiny_roma_v1_outdoor(device)
|
82 |
+
#test_mega1500_poselib(model, experiment_name)
|
83 |
+
test_mega_8_scenes_poselib(model, experiment_name)
|
84 |
+
|
submodules/RoMa/experiments/roma_indoor.py
ADDED
@@ -0,0 +1,320 @@
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
from argparse import ArgumentParser
|
4 |
+
|
5 |
+
from torch import nn
|
6 |
+
from torch.utils.data import ConcatDataset
|
7 |
+
import torch.distributed as dist
|
8 |
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
9 |
+
|
10 |
+
import json
|
11 |
+
import wandb
|
12 |
+
from tqdm import tqdm
|
13 |
+
|
14 |
+
from romatch.benchmarks import MegadepthDenseBenchmark
|
15 |
+
from romatch.datasets.megadepth import MegadepthBuilder
|
16 |
+
from romatch.datasets.scannet import ScanNetBuilder
|
17 |
+
from romatch.losses.robust_loss import RobustLosses
|
18 |
+
from romatch.benchmarks import MegadepthDenseBenchmark, ScanNetBenchmark
|
19 |
+
from romatch.train.train import train_k_steps
|
20 |
+
from romatch.models.matcher import *
|
21 |
+
from romatch.models.transformer import Block, TransformerDecoder, MemEffAttention
|
22 |
+
from romatch.models.encoders import *
|
23 |
+
from romatch.checkpointing import CheckPoint
|
24 |
+
|
25 |
+
resolutions = {"low":(448, 448), "medium":(14*8*5, 14*8*5), "high":(14*8*6, 14*8*6)}
|
26 |
+
|
27 |
+
def get_model(pretrained_backbone=True, resolution = "medium", **kwargs):
|
28 |
+
gp_dim = 512
|
29 |
+
feat_dim = 512
|
30 |
+
decoder_dim = gp_dim + feat_dim
|
31 |
+
cls_to_coord_res = 64
|
32 |
+
coordinate_decoder = TransformerDecoder(
|
33 |
+
nn.Sequential(*[Block(decoder_dim, 8, attn_class=MemEffAttention) for _ in range(5)]),
|
34 |
+
decoder_dim,
|
35 |
+
cls_to_coord_res**2 + 1,
|
36 |
+
is_classifier=True,
|
37 |
+
amp = True,
|
38 |
+
pos_enc = False,)
|
39 |
+
dw = True
|
40 |
+
hidden_blocks = 8
|
41 |
+
kernel_size = 5
|
42 |
+
displacement_emb = "linear"
|
43 |
+
disable_local_corr_grad = True
|
44 |
+
|
45 |
+
conv_refiner = nn.ModuleDict(
|
46 |
+
{
|
47 |
+
"16": ConvRefiner(
|
48 |
+
2 * 512+128+(2*7+1)**2,
|
49 |
+
2 * 512+128+(2*7+1)**2,
|
50 |
+
2 + 1,
|
51 |
+
kernel_size=kernel_size,
|
52 |
+
dw=dw,
|
53 |
+
hidden_blocks=hidden_blocks,
|
54 |
+
displacement_emb=displacement_emb,
|
55 |
+
displacement_emb_dim=128,
|
56 |
+
local_corr_radius = 7,
|
57 |
+
corr_in_other = True,
|
58 |
+
amp = True,
|
59 |
+
disable_local_corr_grad = disable_local_corr_grad,
|
60 |
+
bn_momentum = 0.01,
|
61 |
+
),
|
62 |
+
"8": ConvRefiner(
|
63 |
+
2 * 512+64+(2*3+1)**2,
|
64 |
+
2 * 512+64+(2*3+1)**2,
|
65 |
+
2 + 1,
|
66 |
+
kernel_size=kernel_size,
|
67 |
+
dw=dw,
|
68 |
+
hidden_blocks=hidden_blocks,
|
69 |
+
displacement_emb=displacement_emb,
|
70 |
+
displacement_emb_dim=64,
|
71 |
+
local_corr_radius = 3,
|
72 |
+
corr_in_other = True,
|
73 |
+
amp = True,
|
74 |
+
disable_local_corr_grad = disable_local_corr_grad,
|
75 |
+
bn_momentum = 0.01,
|
76 |
+
),
|
77 |
+
"4": ConvRefiner(
|
78 |
+
2 * 256+32+(2*2+1)**2,
|
79 |
+
2 * 256+32+(2*2+1)**2,
|
80 |
+
2 + 1,
|
81 |
+
kernel_size=kernel_size,
|
82 |
+
dw=dw,
|
83 |
+
hidden_blocks=hidden_blocks,
|
84 |
+
displacement_emb=displacement_emb,
|
85 |
+
displacement_emb_dim=32,
|
86 |
+
local_corr_radius = 2,
|
87 |
+
corr_in_other = True,
|
88 |
+
amp = True,
|
89 |
+
disable_local_corr_grad = disable_local_corr_grad,
|
90 |
+
bn_momentum = 0.01,
|
91 |
+
),
|
92 |
+
"2": ConvRefiner(
|
93 |
+
2 * 64+16,
|
94 |
+
128+16,
|
95 |
+
2 + 1,
|
96 |
+
kernel_size=kernel_size,
|
97 |
+
dw=dw,
|
98 |
+
hidden_blocks=hidden_blocks,
|
99 |
+
displacement_emb=displacement_emb,
|
100 |
+
displacement_emb_dim=16,
|
101 |
+
amp = True,
|
102 |
+
disable_local_corr_grad = disable_local_corr_grad,
|
103 |
+
bn_momentum = 0.01,
|
104 |
+
),
|
105 |
+
"1": ConvRefiner(
|
106 |
+
2 * 9 + 6,
|
107 |
+
24,
|
108 |
+
2 + 1,
|
109 |
+
kernel_size=kernel_size,
|
110 |
+
dw=dw,
|
111 |
+
hidden_blocks = hidden_blocks,
|
112 |
+
displacement_emb = displacement_emb,
|
113 |
+
displacement_emb_dim = 6,
|
114 |
+
amp = True,
|
115 |
+
disable_local_corr_grad = disable_local_corr_grad,
|
116 |
+
bn_momentum = 0.01,
|
117 |
+
),
|
118 |
+
}
|
119 |
+
)
|
120 |
+
kernel_temperature = 0.2
|
121 |
+
learn_temperature = False
|
122 |
+
no_cov = True
|
123 |
+
kernel = CosKernel
|
124 |
+
only_attention = False
|
125 |
+
basis = "fourier"
|
126 |
+
gp16 = GP(
|
127 |
+
kernel,
|
128 |
+
T=kernel_temperature,
|
129 |
+
learn_temperature=learn_temperature,
|
130 |
+
only_attention=only_attention,
|
131 |
+
gp_dim=gp_dim,
|
132 |
+
basis=basis,
|
133 |
+
no_cov=no_cov,
|
134 |
+
)
|
135 |
+
gps = nn.ModuleDict({"16": gp16})
|
136 |
+
proj16 = nn.Sequential(nn.Conv2d(1024, 512, 1, 1), nn.BatchNorm2d(512))
|
137 |
+
proj8 = nn.Sequential(nn.Conv2d(512, 512, 1, 1), nn.BatchNorm2d(512))
|
138 |
+
proj4 = nn.Sequential(nn.Conv2d(256, 256, 1, 1), nn.BatchNorm2d(256))
|
139 |
+
proj2 = nn.Sequential(nn.Conv2d(128, 64, 1, 1), nn.BatchNorm2d(64))
|
140 |
+
proj1 = nn.Sequential(nn.Conv2d(64, 9, 1, 1), nn.BatchNorm2d(9))
|
141 |
+
proj = nn.ModuleDict({
|
142 |
+
"16": proj16,
|
143 |
+
"8": proj8,
|
144 |
+
"4": proj4,
|
145 |
+
"2": proj2,
|
146 |
+
"1": proj1,
|
147 |
+
})
|
148 |
+
displacement_dropout_p = 0.0
|
149 |
+
gm_warp_dropout_p = 0.0
|
150 |
+
decoder = Decoder(coordinate_decoder,
|
151 |
+
gps,
|
152 |
+
proj,
|
153 |
+
conv_refiner,
|
154 |
+
detach=True,
|
155 |
+
scales=["16", "8", "4", "2", "1"],
|
156 |
+
displacement_dropout_p = displacement_dropout_p,
|
157 |
+
gm_warp_dropout_p = gm_warp_dropout_p)
|
158 |
+
h,w = resolutions[resolution]
|
159 |
+
encoder = CNNandDinov2(
|
160 |
+
cnn_kwargs = dict(
|
161 |
+
pretrained=pretrained_backbone,
|
162 |
+
amp = True),
|
163 |
+
amp = True,
|
164 |
+
use_vgg = True,
|
165 |
+
)
|
166 |
+
matcher = RegressionMatcher(encoder, decoder, h=h, w=w, alpha=1, beta=0,**kwargs)
|
167 |
+
return matcher
|
168 |
+
|
169 |
+
def train(args):
|
170 |
+
dist.init_process_group('nccl')
|
171 |
+
#torch._dynamo.config.verbose=True
|
172 |
+
gpus = int(os.environ['WORLD_SIZE'])
|
173 |
+
# create model and move it to GPU with id rank
|
174 |
+
rank = dist.get_rank()
|
175 |
+
print(f"Start running DDP on rank {rank}")
|
176 |
+
device_id = rank % torch.cuda.device_count()
|
177 |
+
romatch.LOCAL_RANK = device_id
|
178 |
+
torch.cuda.set_device(device_id)
|
179 |
+
|
180 |
+
resolution = args.train_resolution
|
181 |
+
wandb_log = not args.dont_log_wandb
|
182 |
+
experiment_name = os.path.splitext(os.path.basename(__file__))[0]
|
183 |
+
wandb_mode = "online" if wandb_log and rank == 0 and False else "disabled"
|
184 |
+
wandb.init(project="romatch", entity=args.wandb_entity, name=experiment_name, reinit=False, mode = wandb_mode)
|
185 |
+
checkpoint_dir = "workspace/checkpoints/"
|
186 |
+
h,w = resolutions[resolution]
|
187 |
+
model = get_model(pretrained_backbone=True, resolution=resolution, attenuate_cert = False).to(device_id)
|
188 |
+
# Num steps
|
189 |
+
global_step = 0
|
190 |
+
batch_size = args.gpu_batch_size
|
191 |
+
step_size = gpus*batch_size
|
192 |
+
romatch.STEP_SIZE = step_size
|
193 |
+
|
194 |
+
N = (32 * 250000) # 250k steps of batch size 32
|
195 |
+
# checkpoint every
|
196 |
+
k = 25000 // romatch.STEP_SIZE
|
197 |
+
|
198 |
+
# Data
|
199 |
+
mega = MegadepthBuilder(data_root="data/megadepth", loftr_ignore=True, imc21_ignore = True)
|
200 |
+
use_horizontal_flip_aug = True
|
201 |
+
rot_prob = 0
|
202 |
+
depth_interpolation_mode = "bilinear"
|
203 |
+
megadepth_train1 = mega.build_scenes(
|
204 |
+
split="train_loftr", min_overlap=0.01, shake_t=32, use_horizontal_flip_aug = use_horizontal_flip_aug, rot_prob = rot_prob,
|
205 |
+
ht=h,wt=w,
|
206 |
+
)
|
207 |
+
megadepth_train2 = mega.build_scenes(
|
208 |
+
split="train_loftr", min_overlap=0.35, shake_t=32, use_horizontal_flip_aug = use_horizontal_flip_aug, rot_prob = rot_prob,
|
209 |
+
ht=h,wt=w,
|
210 |
+
)
|
211 |
+
megadepth_train = ConcatDataset(megadepth_train1 + megadepth_train2)
|
212 |
+
mega_ws = mega.weight_scenes(megadepth_train, alpha=0.75)
|
213 |
+
|
214 |
+
scannet = ScanNetBuilder(data_root="data/scannet")
|
215 |
+
scannet_train = scannet.build_scenes(split="train", ht=h, wt=w, use_horizontal_flip_aug = use_horizontal_flip_aug)
|
216 |
+
scannet_train = ConcatDataset(scannet_train)
|
217 |
+
scannet_ws = scannet.weight_scenes(scannet_train, alpha=0.75)
|
218 |
+
|
219 |
+
# Loss and optimizer
|
220 |
+
depth_loss_scannet = RobustLosses(
|
221 |
+
ce_weight=0.0,
|
222 |
+
local_dist={1:4, 2:4, 4:8, 8:8},
|
223 |
+
local_largest_scale=8,
|
224 |
+
depth_interpolation_mode=depth_interpolation_mode,
|
225 |
+
alpha = 0.5,
|
226 |
+
c = 1e-4,)
|
227 |
+
# Loss and optimizer
|
228 |
+
depth_loss_mega = RobustLosses(
|
229 |
+
ce_weight=0.01,
|
230 |
+
local_dist={1:4, 2:4, 4:8, 8:8},
|
231 |
+
local_largest_scale=8,
|
232 |
+
depth_interpolation_mode=depth_interpolation_mode,
|
233 |
+
alpha = 0.5,
|
234 |
+
c = 1e-4,)
|
235 |
+
parameters = [
|
236 |
+
{"params": model.encoder.parameters(), "lr": romatch.STEP_SIZE * 5e-6 / 8},
|
237 |
+
{"params": model.decoder.parameters(), "lr": romatch.STEP_SIZE * 1e-4 / 8},
|
238 |
+
]
|
239 |
+
optimizer = torch.optim.AdamW(parameters, weight_decay=0.01)
|
240 |
+
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
|
241 |
+
optimizer, milestones=[(9*N/romatch.STEP_SIZE)//10])
|
242 |
+
megadense_benchmark = MegadepthDenseBenchmark("data/megadepth", num_samples = 1000, h=h,w=w)
|
243 |
+
checkpointer = CheckPoint(checkpoint_dir, experiment_name)
|
244 |
+
model, optimizer, lr_scheduler, global_step = checkpointer.load(model, optimizer, lr_scheduler, global_step)
|
245 |
+
romatch.GLOBAL_STEP = global_step
|
246 |
+
ddp_model = DDP(model, device_ids=[device_id], find_unused_parameters = False, gradient_as_bucket_view=True)
|
247 |
+
grad_scaler = torch.cuda.amp.GradScaler(growth_interval=1_000_000)
|
248 |
+
grad_clip_norm = 0.01
|
249 |
+
for n in range(romatch.GLOBAL_STEP, N, k * romatch.STEP_SIZE):
|
250 |
+
mega_sampler = torch.utils.data.WeightedRandomSampler(
|
251 |
+
mega_ws, num_samples = batch_size * k, replacement=False
|
252 |
+
)
|
253 |
+
mega_dataloader = iter(
|
254 |
+
torch.utils.data.DataLoader(
|
255 |
+
megadepth_train,
|
256 |
+
batch_size = batch_size,
|
257 |
+
sampler = mega_sampler,
|
258 |
+
num_workers = 8,
|
259 |
+
)
|
260 |
+
)
|
261 |
+
scannet_ws_sampler = torch.utils.data.WeightedRandomSampler(
|
262 |
+
scannet_ws, num_samples=batch_size * k, replacement=False
|
263 |
+
)
|
264 |
+
scannet_dataloader = iter(
|
265 |
+
torch.utils.data.DataLoader(
|
266 |
+
scannet_train,
|
267 |
+
batch_size=batch_size,
|
268 |
+
sampler=scannet_ws_sampler,
|
269 |
+
num_workers=gpus * 8,
|
270 |
+
)
|
271 |
+
)
|
272 |
+
for n_k in tqdm(range(n, n + 2 * k, 2),disable = romatch.RANK > 0):
|
273 |
+
train_k_steps(
|
274 |
+
n_k, 1, mega_dataloader, ddp_model, depth_loss_mega, optimizer, lr_scheduler, grad_scaler, grad_clip_norm = grad_clip_norm, progress_bar=False
|
275 |
+
)
|
276 |
+
train_k_steps(
|
277 |
+
n_k + 1, 1, scannet_dataloader, ddp_model, depth_loss_scannet, optimizer, lr_scheduler, grad_scaler, grad_clip_norm = grad_clip_norm, progress_bar=False
|
278 |
+
)
|
279 |
+
checkpointer.save(model, optimizer, lr_scheduler, romatch.GLOBAL_STEP)
|
280 |
+
wandb.log(megadense_benchmark.benchmark(model), step = romatch.GLOBAL_STEP)
|
281 |
+
|
282 |
+
def test_scannet(model, name, resolution, sample_mode):
|
283 |
+
scannet_benchmark = ScanNetBenchmark("data/scannet")
|
284 |
+
scannet_results = scannet_benchmark.benchmark(model)
|
285 |
+
json.dump(scannet_results, open(f"results/scannet_{name}.json", "w"))
|
286 |
+
|
287 |
+
if __name__ == "__main__":
|
288 |
+
import warnings
|
289 |
+
warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
|
290 |
+
warnings.filterwarnings('ignore')#, category=UserWarning)#, message='WARNING batched routines are designed for small sizes.')
|
291 |
+
os.environ["TORCH_CUDNN_V8_API_ENABLED"] = "1" # For BF16 computations
|
292 |
+
os.environ["OMP_NUM_THREADS"] = "16"
|
293 |
+
|
294 |
+
import romatch
|
295 |
+
parser = ArgumentParser()
|
296 |
+
parser.add_argument("--test", action='store_true')
|
297 |
+
parser.add_argument("--debug_mode", action='store_true')
|
298 |
+
parser.add_argument("--dont_log_wandb", action='store_true')
|
299 |
+
parser.add_argument("--train_resolution", default='medium')
|
300 |
+
parser.add_argument("--gpu_batch_size", default=4, type=int)
|
301 |
+
parser.add_argument("--wandb_entity", required = False)
|
302 |
+
|
303 |
+
args, _ = parser.parse_known_args()
|
304 |
+
romatch.DEBUG_MODE = args.debug_mode
|
305 |
+
if not args.test:
|
306 |
+
train(args)
|
307 |
+
experiment_name = os.path.splitext(os.path.basename(__file__))[0]
|
308 |
+
checkpoint_dir = "workspace/"
|
309 |
+
checkpoint_name = checkpoint_dir + experiment_name + ".pth"
|
310 |
+
test_resolution = "medium"
|
311 |
+
sample_mode = "threshold_balanced"
|
312 |
+
symmetric = True
|
313 |
+
upsample_preds = False
|
314 |
+
attenuate_cert = True
|
315 |
+
|
316 |
+
model = get_model(pretrained_backbone=False, resolution = test_resolution, sample_mode = sample_mode, upsample_preds = upsample_preds, symmetric=symmetric, name=experiment_name, attenuate_cert = attenuate_cert)
|
317 |
+
model = model.cuda()
|
318 |
+
states = torch.load(checkpoint_name)
|
319 |
+
model.load_state_dict(states["model"])
|
320 |
+
test_scannet(model, experiment_name, resolution = test_resolution, sample_mode = sample_mode)
|
submodules/RoMa/experiments/train_roma_outdoor.py
ADDED
@@ -0,0 +1,307 @@
import os
import torch
from argparse import ArgumentParser

from torch import nn
from torch.utils.data import ConcatDataset
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
import json
import wandb

from romatch.benchmarks import MegadepthDenseBenchmark
from romatch.datasets.megadepth import MegadepthBuilder
from romatch.losses.robust_loss import RobustLosses
from romatch.benchmarks import MegaDepthPoseEstimationBenchmark, MegadepthDenseBenchmark, HpatchesHomogBenchmark

from romatch.train.train import train_k_steps
from romatch.models.matcher import *
from romatch.models.transformer import Block, TransformerDecoder, MemEffAttention
from romatch.models.encoders import *
from romatch.checkpointing import CheckPoint

resolutions = {"low": (448, 448), "medium": (14*8*5, 14*8*5), "high": (14*8*6, 14*8*6)}

def get_model(pretrained_backbone=True, resolution="medium", **kwargs):
    import warnings
    warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
    gp_dim = 512
    feat_dim = 512
    decoder_dim = gp_dim + feat_dim
    cls_to_coord_res = 64
    coordinate_decoder = TransformerDecoder(
        nn.Sequential(*[Block(decoder_dim, 8, attn_class=MemEffAttention) for _ in range(5)]),
        decoder_dim,
        cls_to_coord_res**2 + 1,
        is_classifier=True,
        amp=True,
        pos_enc=False,)
    dw = True
    hidden_blocks = 8
    kernel_size = 5
    displacement_emb = "linear"
    disable_local_corr_grad = True

    conv_refiner = nn.ModuleDict(
        {
            "16": ConvRefiner(
                2 * 512+128+(2*7+1)**2,
                2 * 512+128+(2*7+1)**2,
                2 + 1,
                kernel_size=kernel_size,
                dw=dw,
                hidden_blocks=hidden_blocks,
                displacement_emb=displacement_emb,
                displacement_emb_dim=128,
                local_corr_radius=7,
                corr_in_other=True,
                amp=True,
                disable_local_corr_grad=disable_local_corr_grad,
                bn_momentum=0.01,
            ),
            "8": ConvRefiner(
                2 * 512+64+(2*3+1)**2,
                2 * 512+64+(2*3+1)**2,
                2 + 1,
                kernel_size=kernel_size,
                dw=dw,
                hidden_blocks=hidden_blocks,
                displacement_emb=displacement_emb,
                displacement_emb_dim=64,
                local_corr_radius=3,
                corr_in_other=True,
                amp=True,
                disable_local_corr_grad=disable_local_corr_grad,
                bn_momentum=0.01,
            ),
            "4": ConvRefiner(
                2 * 256+32+(2*2+1)**2,
                2 * 256+32+(2*2+1)**2,
                2 + 1,
                kernel_size=kernel_size,
                dw=dw,
                hidden_blocks=hidden_blocks,
                displacement_emb=displacement_emb,
                displacement_emb_dim=32,
                local_corr_radius=2,
                corr_in_other=True,
                amp=True,
                disable_local_corr_grad=disable_local_corr_grad,
                bn_momentum=0.01,
            ),
            "2": ConvRefiner(
                2 * 64+16,
                128+16,
                2 + 1,
                kernel_size=kernel_size,
                dw=dw,
                hidden_blocks=hidden_blocks,
                displacement_emb=displacement_emb,
                displacement_emb_dim=16,
                amp=True,
                disable_local_corr_grad=disable_local_corr_grad,
                bn_momentum=0.01,
            ),
            "1": ConvRefiner(
                2 * 9 + 6,
                24,
                2 + 1,
                kernel_size=kernel_size,
                dw=dw,
                hidden_blocks=hidden_blocks,
                displacement_emb=displacement_emb,
                displacement_emb_dim=6,
                amp=True,
                disable_local_corr_grad=disable_local_corr_grad,
                bn_momentum=0.01,
            ),
        }
    )
    kernel_temperature = 0.2
    learn_temperature = False
    no_cov = True
    kernel = CosKernel
    only_attention = False
    basis = "fourier"
    gp16 = GP(
        kernel,
        T=kernel_temperature,
        learn_temperature=learn_temperature,
        only_attention=only_attention,
        gp_dim=gp_dim,
        basis=basis,
        no_cov=no_cov,
    )
    gps = nn.ModuleDict({"16": gp16})
    proj16 = nn.Sequential(nn.Conv2d(1024, 512, 1, 1), nn.BatchNorm2d(512))
    proj8 = nn.Sequential(nn.Conv2d(512, 512, 1, 1), nn.BatchNorm2d(512))
    proj4 = nn.Sequential(nn.Conv2d(256, 256, 1, 1), nn.BatchNorm2d(256))
    proj2 = nn.Sequential(nn.Conv2d(128, 64, 1, 1), nn.BatchNorm2d(64))
    proj1 = nn.Sequential(nn.Conv2d(64, 9, 1, 1), nn.BatchNorm2d(9))
    proj = nn.ModuleDict({
        "16": proj16,
        "8": proj8,
        "4": proj4,
        "2": proj2,
        "1": proj1,
        })
    displacement_dropout_p = 0.0
    gm_warp_dropout_p = 0.0
    decoder = Decoder(coordinate_decoder,
                      gps,
                      proj,
                      conv_refiner,
                      detach=True,
                      scales=["16", "8", "4", "2", "1"],
                      displacement_dropout_p=displacement_dropout_p,
                      gm_warp_dropout_p=gm_warp_dropout_p)
    h, w = resolutions[resolution]
    encoder = CNNandDinov2(
        cnn_kwargs=dict(
            pretrained=pretrained_backbone,
            amp=True),
        amp=True,
        use_vgg=True,
    )
    matcher = RegressionMatcher(encoder, decoder, h=h, w=w, **kwargs)
    return matcher

def train(args):
    dist.init_process_group('nccl')
    #torch._dynamo.config.verbose=True
    gpus = int(os.environ['WORLD_SIZE'])
    # create model and move it to GPU with id rank
    rank = dist.get_rank()
    print(f"Start running DDP on rank {rank}")
    device_id = rank % torch.cuda.device_count()
    romatch.LOCAL_RANK = device_id
    torch.cuda.set_device(device_id)

    resolution = args.train_resolution
    wandb_log = not args.dont_log_wandb
    experiment_name = os.path.splitext(os.path.basename(__file__))[0]
    wandb_mode = "online" if wandb_log and rank == 0 else "disabled"
    wandb.init(project="romatch", entity=args.wandb_entity, name=experiment_name, reinit=False, mode=wandb_mode)
    checkpoint_dir = "workspace/checkpoints/"
    h, w = resolutions[resolution]
    model = get_model(pretrained_backbone=True, resolution=resolution, attenuate_cert=False).to(device_id)
    # Num steps
    global_step = 0
    batch_size = args.gpu_batch_size
    step_size = gpus*batch_size
    romatch.STEP_SIZE = step_size

    N = (32 * 250000)  # 250k steps of batch size 32
    # checkpoint every
    k = 25000 // romatch.STEP_SIZE

    # Data
    mega = MegadepthBuilder(data_root="data/megadepth", loftr_ignore=True, imc21_ignore=True)
    use_horizontal_flip_aug = True
    rot_prob = 0
    depth_interpolation_mode = "bilinear"
    megadepth_train1 = mega.build_scenes(
        split="train_loftr", min_overlap=0.01, shake_t=32, use_horizontal_flip_aug=use_horizontal_flip_aug, rot_prob=rot_prob,
        ht=h, wt=w,
    )
    megadepth_train2 = mega.build_scenes(
        split="train_loftr", min_overlap=0.35, shake_t=32, use_horizontal_flip_aug=use_horizontal_flip_aug, rot_prob=rot_prob,
        ht=h, wt=w,
    )
    megadepth_train = ConcatDataset(megadepth_train1 + megadepth_train2)
    mega_ws = mega.weight_scenes(megadepth_train, alpha=0.75)
    # Loss and optimizer
    depth_loss = RobustLosses(
        ce_weight=0.01,
        local_dist={1:4, 2:4, 4:8, 8:8},
        local_largest_scale=8,
        depth_interpolation_mode=depth_interpolation_mode,
        alpha=0.5,
        c=1e-4,)
    parameters = [
        {"params": model.encoder.parameters(), "lr": romatch.STEP_SIZE * 5e-6 / 8},
        {"params": model.decoder.parameters(), "lr": romatch.STEP_SIZE * 1e-4 / 8},
    ]
    optimizer = torch.optim.AdamW(parameters, weight_decay=0.01)
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=[(9*N/romatch.STEP_SIZE)//10])
    megadense_benchmark = MegadepthDenseBenchmark("data/megadepth", num_samples=1000, h=h, w=w)
    checkpointer = CheckPoint(checkpoint_dir, experiment_name)
    model, optimizer, lr_scheduler, global_step = checkpointer.load(model, optimizer, lr_scheduler, global_step)
    romatch.GLOBAL_STEP = global_step
    ddp_model = DDP(model, device_ids=[device_id], find_unused_parameters=False, gradient_as_bucket_view=True)
    grad_scaler = torch.cuda.amp.GradScaler(growth_interval=1_000_000)
    grad_clip_norm = 0.01
    for n in range(romatch.GLOBAL_STEP, N, k * romatch.STEP_SIZE):
        mega_sampler = torch.utils.data.WeightedRandomSampler(
            mega_ws, num_samples=batch_size * k, replacement=False
        )
        mega_dataloader = iter(
            torch.utils.data.DataLoader(
                megadepth_train,
                batch_size=batch_size,
                sampler=mega_sampler,
                num_workers=8,
            )
        )
        train_k_steps(
            n, k, mega_dataloader, ddp_model, depth_loss, optimizer, lr_scheduler, grad_scaler, grad_clip_norm=grad_clip_norm,
        )
        checkpointer.save(model, optimizer, lr_scheduler, romatch.GLOBAL_STEP)
        wandb.log(megadense_benchmark.benchmark(model), step=romatch.GLOBAL_STEP)

def test_mega_8_scenes(model, name):
    mega_8_scenes_benchmark = MegaDepthPoseEstimationBenchmark("data/megadepth",
                                                scene_names=['mega_8_scenes_0019_0.1_0.3.npz',
                                                             'mega_8_scenes_0025_0.1_0.3.npz',
                                                             'mega_8_scenes_0021_0.1_0.3.npz',
                                                             'mega_8_scenes_0008_0.1_0.3.npz',
                                                             'mega_8_scenes_0032_0.1_0.3.npz',
                                                             'mega_8_scenes_1589_0.1_0.3.npz',
                                                             'mega_8_scenes_0063_0.1_0.3.npz',
                                                             'mega_8_scenes_0024_0.1_0.3.npz',
                                                             'mega_8_scenes_0019_0.3_0.5.npz',
                                                             'mega_8_scenes_0025_0.3_0.5.npz',
                                                             'mega_8_scenes_0021_0.3_0.5.npz',
                                                             'mega_8_scenes_0008_0.3_0.5.npz',
                                                             'mega_8_scenes_0032_0.3_0.5.npz',
                                                             'mega_8_scenes_1589_0.3_0.5.npz',
                                                             'mega_8_scenes_0063_0.3_0.5.npz',
                                                             'mega_8_scenes_0024_0.3_0.5.npz'])
    mega_8_scenes_results = mega_8_scenes_benchmark.benchmark(model, model_name=name)
    print(mega_8_scenes_results)
    json.dump(mega_8_scenes_results, open(f"results/mega_8_scenes_{name}.json", "w"))

def test_mega1500(model, name):
    mega1500_benchmark = MegaDepthPoseEstimationBenchmark("data/megadepth")
    mega1500_results = mega1500_benchmark.benchmark(model, model_name=name)
    json.dump(mega1500_results, open(f"results/mega1500_{name}.json", "w"))

def test_mega_dense(model, name):
    megadense_benchmark = MegadepthDenseBenchmark("data/megadepth", num_samples=1000)
    megadense_results = megadense_benchmark.benchmark(model)
    json.dump(megadense_results, open(f"results/mega_dense_{name}.json", "w"))

def test_hpatches(model, name):
    hpatches_benchmark = HpatchesHomogBenchmark("data/hpatches")
    hpatches_results = hpatches_benchmark.benchmark(model)
    json.dump(hpatches_results, open(f"results/hpatches_{name}.json", "w"))


if __name__ == "__main__":
    os.environ["TORCH_CUDNN_V8_API_ENABLED"] = "1" # For BF16 computations
    os.environ["OMP_NUM_THREADS"] = "16"
    torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
    import romatch
    parser = ArgumentParser()
    parser.add_argument("--only_test", action='store_true')
    parser.add_argument("--debug_mode", action='store_true')
    parser.add_argument("--dont_log_wandb", action='store_true')
    parser.add_argument("--train_resolution", default='medium')
    parser.add_argument("--gpu_batch_size", default=8, type=int)
    parser.add_argument("--wandb_entity", required=False)

    args, _ = parser.parse_known_args()
    romatch.DEBUG_MODE = args.debug_mode
    if not args.only_test:
        train(args)
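Editor's note: the training script above reads WORLD_SIZE from the environment and calls dist.init_process_group('nccl'), so it expects a distributed launcher to set the rendezvous variables. Below is a minimal single-GPU sketch of how it could be invoked; the environment values, argv, and path are illustrative assumptions, not part of the committed code.

# Hedged sketch: emulate the variables a launcher such as torchrun would set,
# then run the training entry point in-process on one GPU.
import os
import sys
import runpy

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
os.environ.setdefault("RANK", "0")
os.environ.setdefault("WORLD_SIZE", "1")
os.environ.setdefault("LOCAL_RANK", "0")

# Roughly equivalent to: torchrun --nproc_per_node=1 experiments/train_roma_outdoor.py --gpu_batch_size 8
sys.argv = ["train_roma_outdoor.py", "--gpu_batch_size", "8", "--dont_log_wandb"]
runpy.run_path("submodules/RoMa/experiments/train_roma_outdoor.py", run_name="__main__")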
submodules/RoMa/experiments/train_tiny_roma_v1_outdoor.py
ADDED
@@ -0,0 +1,498 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import torch
from argparse import ArgumentParser
from pathlib import Path
import math
import numpy as np

from torch import nn
from torch.utils.data import ConcatDataset
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
import json
import wandb
from PIL import Image
from torchvision.transforms import ToTensor

from romatch.benchmarks import MegadepthDenseBenchmark, ScanNetBenchmark
from romatch.benchmarks import Mega1500PoseLibBenchmark, ScanNetPoselibBenchmark
from romatch.datasets.megadepth import MegadepthBuilder
from romatch.losses.robust_loss_tiny_roma import RobustLosses
from romatch.benchmarks import MegaDepthPoseEstimationBenchmark, MegadepthDenseBenchmark, HpatchesHomogBenchmark
from romatch.train.train import train_k_steps
from romatch.checkpointing import CheckPoint

resolutions = {"low": (448, 448), "medium": (14*8*5, 14*8*5), "high": (14*8*6, 14*8*6), "xfeat": (600, 800), "big": (768, 1024)}

def kde(x, std=0.1):
    # use a gaussian kernel to estimate density
    x = x.half() # Do it in half precision TODO: remove hardcoding
    scores = (-torch.cdist(x, x)**2/(2*std**2)).exp()
    density = scores.sum(dim=-1)
    return density

class BasicLayer(nn.Module):
    """
    Basic Convolutional Layer: Conv2d -> BatchNorm -> ReLU
    """
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, dilation=1, bias=False, relu=True):
        super().__init__()
        self.layer = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding, stride=stride, dilation=dilation, bias=bias),
            nn.BatchNorm2d(out_channels, affine=False),
            nn.ReLU(inplace=True) if relu else nn.Identity()
        )

    def forward(self, x):
        return self.layer(x)

class XFeatModel(nn.Module):
    """
    Implementation of architecture described in
    "XFeat: Accelerated Features for Lightweight Image Matching, CVPR 2024."
    """

    def __init__(self, xfeat=None,
                 freeze_xfeat=True,
                 sample_mode="threshold_balanced",
                 symmetric=False,
                 exact_softmax=False):
        super().__init__()
        if xfeat is None:
            xfeat = torch.hub.load('verlab/accelerated_features', 'XFeat', pretrained=True, top_k=4096).net
            del xfeat.heatmap_head, xfeat.keypoint_head, xfeat.fine_matcher
        if freeze_xfeat:
            xfeat.train(False)
            self.xfeat = [xfeat] # hide params from ddp
        else:
            self.xfeat = nn.ModuleList([xfeat])
        self.freeze_xfeat = freeze_xfeat
        match_dim = 256
        self.coarse_matcher = nn.Sequential(
            BasicLayer(64+64+2, match_dim,),
            BasicLayer(match_dim, match_dim,),
            BasicLayer(match_dim, match_dim,),
            BasicLayer(match_dim, match_dim,),
            nn.Conv2d(match_dim, 3, kernel_size=1, bias=True, padding=0))
        fine_match_dim = 64
        self.fine_matcher = nn.Sequential(
            BasicLayer(24+24+2, fine_match_dim,),
            BasicLayer(fine_match_dim, fine_match_dim,),
            BasicLayer(fine_match_dim, fine_match_dim,),
            BasicLayer(fine_match_dim, fine_match_dim,),
            nn.Conv2d(fine_match_dim, 3, kernel_size=1, bias=True, padding=0),)
        self.sample_mode = sample_mode
        self.sample_thresh = 0.2
        self.symmetric = symmetric
        self.exact_softmax = exact_softmax

    @property
    def device(self):
        return self.fine_matcher[-1].weight.device

    def preprocess_tensor(self, x):
        """ Guarantee that image is divisible by 32 to avoid aliasing artifacts. """
        H, W = x.shape[-2:]
        _H, _W = (H//32) * 32, (W//32) * 32
        rh, rw = H/_H, W/_W

        x = F.interpolate(x, (_H, _W), mode='bilinear', align_corners=False)
        return x, rh, rw

    def forward_single(self, x):
        with torch.inference_mode(self.freeze_xfeat or not self.training):
            xfeat = self.xfeat[0]
            with torch.no_grad():
                x = x.mean(dim=1, keepdim=True)
                x = xfeat.norm(x)

            #main backbone
            x1 = xfeat.block1(x)
            x2 = xfeat.block2(x1 + xfeat.skip1(x))
            x3 = xfeat.block3(x2)
            x4 = xfeat.block4(x3)
            x5 = xfeat.block5(x4)
            x4 = F.interpolate(x4, (x3.shape[-2], x3.shape[-1]), mode='bilinear')
            x5 = F.interpolate(x5, (x3.shape[-2], x3.shape[-1]), mode='bilinear')
            feats = xfeat.block_fusion(x3 + x4 + x5)
        if self.freeze_xfeat:
            return x2.clone(), feats.clone()
        return x2, feats

    def to_pixel_coordinates(self, coords, H_A, W_A, H_B=None, W_B=None):
        if coords.shape[-1] == 2:
            return self._to_pixel_coordinates(coords, H_A, W_A)

        if isinstance(coords, (list, tuple)):
            kpts_A, kpts_B = coords[0], coords[1]
        else:
            kpts_A, kpts_B = coords[..., :2], coords[..., 2:]
        return self._to_pixel_coordinates(kpts_A, H_A, W_A), self._to_pixel_coordinates(kpts_B, H_B, W_B)

    def _to_pixel_coordinates(self, coords, H, W):
        kpts = torch.stack((W/2 * (coords[..., 0]+1), H/2 * (coords[..., 1]+1)), axis=-1)
        return kpts

    def pos_embed(self, corr_volume: torch.Tensor):
        B, H1, W1, H0, W0 = corr_volume.shape
        grid = torch.stack(
            torch.meshgrid(
                torch.linspace(-1+1/W1, 1-1/W1, W1),
                torch.linspace(-1+1/H1, 1-1/H1, H1),
                indexing="xy"),
            dim=-1).float().to(corr_volume).reshape(H1*W1, 2)
        down = 4
        if not self.training and not self.exact_softmax:
            grid_lr = torch.stack(
                torch.meshgrid(
                    torch.linspace(-1+down/W1, 1-down/W1, W1//down),
                    torch.linspace(-1+down/H1, 1-down/H1, H1//down),
                    indexing="xy"),
                dim=-1).float().to(corr_volume).reshape(H1*W1 //down**2, 2)
            cv = corr_volume
            best_match = cv.reshape(B, H1*W1, H0, W0).amax(dim=1) # B, HW, H, W
            P_lowres = torch.cat((cv[:, ::down, ::down].reshape(B, H1*W1 // down**2, H0, W0), best_match[:, None]), dim=1).softmax(dim=1)
            pos_embeddings = torch.einsum('bchw,cd->bdhw', P_lowres[:, :-1], grid_lr)
            pos_embeddings += P_lowres[:, -1] * grid[best_match].permute(0, 3, 1, 2)
        else:
            P = corr_volume.reshape(B, H1*W1, H0, W0).softmax(dim=1) # B, HW, H, W
            pos_embeddings = torch.einsum('bchw,cd->bdhw', P, grid)
        return pos_embeddings

    def visualize_warp(self, warp, certainty, im_A=None, im_B=None,
                       im_A_path=None, im_B_path=None, symmetric=True, save_path=None, unnormalize=False):
        device = warp.device
        H, W2, _ = warp.shape
        W = W2//2 if symmetric else W2
        if im_A is None:
            from PIL import Image
            im_A, im_B = Image.open(im_A_path).convert("RGB"), Image.open(im_B_path).convert("RGB")
        if not isinstance(im_A, torch.Tensor):
            im_A = im_A.resize((W, H))
            im_B = im_B.resize((W, H))
            x_B = (torch.tensor(np.array(im_B)) / 255).to(device).permute(2, 0, 1)
            if symmetric:
                x_A = (torch.tensor(np.array(im_A)) / 255).to(device).permute(2, 0, 1)
        else:
            if symmetric:
                x_A = im_A
            x_B = im_B
        im_A_transfer_rgb = F.grid_sample(
            x_B[None], warp[:, :W, 2:][None], mode="bilinear", align_corners=False
        )[0]
        if symmetric:
            im_B_transfer_rgb = F.grid_sample(
                x_A[None], warp[:, W:, :2][None], mode="bilinear", align_corners=False
            )[0]
            warp_im = torch.cat((im_A_transfer_rgb, im_B_transfer_rgb), dim=2)
            white_im = torch.ones((H, 2*W), device=device)
        else:
            warp_im = im_A_transfer_rgb
            white_im = torch.ones((H, W), device=device)
        vis_im = certainty * warp_im + (1 - certainty) * white_im
        if save_path is not None:
            from romatch.utils import tensor_to_pil
            tensor_to_pil(vis_im, unnormalize=unnormalize).save(save_path)
        return vis_im

    def corr_volume(self, feat0, feat1):
        """
            input:
                feat0 -> torch.Tensor(B, C, H, W)
                feat1 -> torch.Tensor(B, C, H, W)
            return:
                corr_volume -> torch.Tensor(B, H, W, H, W)
        """
        B, C, H0, W0 = feat0.shape
        B, C, H1, W1 = feat1.shape
        feat0 = feat0.view(B, C, H0*W0)
        feat1 = feat1.view(B, C, H1*W1)
        corr_volume = torch.einsum('bci,bcj->bji', feat0, feat1).reshape(B, H1, W1, H0, W0)/math.sqrt(C) #16*16*16
        return corr_volume

    @torch.inference_mode()
    def match_from_path(self, im0_path, im1_path):
        device = self.device
        im0 = ToTensor()(Image.open(im0_path))[None].to(device)
        im1 = ToTensor()(Image.open(im1_path))[None].to(device)
        return self.match(im0, im1, batched=False)

    @torch.inference_mode()
    def match(self, im0, im1, *args, batched=True):
        # stupid
        if isinstance(im0, (str, Path)):
            return self.match_from_path(im0, im1)
        elif isinstance(im0, Image.Image):
            batched = False
            device = self.device
            im0 = ToTensor()(im0)[None].to(device)
            im1 = ToTensor()(im1)[None].to(device)

        B, C, H0, W0 = im0.shape
        B, C, H1, W1 = im1.shape
        self.train(False)
        corresps = self.forward({"im_A": im0, "im_B": im1})
        #return 1,1
        flow = F.interpolate(
            corresps[4]["flow"],
            size=(H0, W0),
            mode="bilinear", align_corners=False).permute(0, 2, 3, 1).reshape(B, H0, W0, 2)
        grid = torch.stack(
            torch.meshgrid(
                torch.linspace(-1+1/W0, 1-1/W0, W0),
                torch.linspace(-1+1/H0, 1-1/H0, H0),
                indexing="xy"),
            dim=-1).float().to(flow.device).expand(B, H0, W0, 2)

        certainty = F.interpolate(corresps[4]["certainty"], size=(H0, W0), mode="bilinear", align_corners=False)
        warp, cert = torch.cat((grid, flow), dim=-1), certainty[:, 0].sigmoid()
        if batched:
            return warp, cert
        else:
            return warp[0], cert[0]

    def sample(
        self,
        matches,
        certainty,
        num=10000,
    ):
        if "threshold" in self.sample_mode:
            upper_thresh = self.sample_thresh
            certainty = certainty.clone()
            certainty[certainty > upper_thresh] = 1
        matches, certainty = (
            matches.reshape(-1, 4),
            certainty.reshape(-1),
        )
        expansion_factor = 4 if "balanced" in self.sample_mode else 1
        good_samples = torch.multinomial(certainty,
                                         num_samples=min(expansion_factor*num, len(certainty)),
                                         replacement=False)
        good_matches, good_certainty = matches[good_samples], certainty[good_samples]
        if "balanced" not in self.sample_mode:
            return good_matches, good_certainty
        density = kde(good_matches, std=0.1)
        p = 1 / (density+1)
        p[density < 10] = 1e-7 # Basically should have at least 10 perfect neighbours, or around 100 ok ones
        balanced_samples = torch.multinomial(p,
                                             num_samples=min(num, len(good_certainty)),
                                             replacement=False)
        return good_matches[balanced_samples], good_certainty[balanced_samples]

    def forward(self, batch):
        """
            input:
                x -> torch.Tensor(B, C, H, W) grayscale or rgb images
            return:

        """
        im0 = batch["im_A"]
        im1 = batch["im_B"]
        corresps = {}
        im0, rh0, rw0 = self.preprocess_tensor(im0)
        im1, rh1, rw1 = self.preprocess_tensor(im1)
        B, C, H0, W0 = im0.shape
        B, C, H1, W1 = im1.shape
        to_normalized = torch.tensor((2/W1, 2/H1, 1)).to(im0.device)[None, :, None, None]

        if im0.shape[-2:] == im1.shape[-2:]:
            x = torch.cat([im0, im1], dim=0)
            x = self.forward_single(x)
            feats_x0_c, feats_x1_c = x[1].chunk(2)
            feats_x0_f, feats_x1_f = x[0].chunk(2)
        else:
            feats_x0_f, feats_x0_c = self.forward_single(im0)
            feats_x1_f, feats_x1_c = self.forward_single(im1)
        corr_volume = self.corr_volume(feats_x0_c, feats_x1_c)
        coarse_warp = self.pos_embed(corr_volume)
        coarse_matches = torch.cat((coarse_warp, torch.zeros_like(coarse_warp[:, -1:])), dim=1)
        feats_x1_c_warped = F.grid_sample(feats_x1_c, coarse_matches.permute(0, 2, 3, 1)[..., :2], mode='bilinear', align_corners=False)
        coarse_matches_delta = self.coarse_matcher(torch.cat((feats_x0_c, feats_x1_c_warped, coarse_warp), dim=1))
        coarse_matches = coarse_matches + coarse_matches_delta * to_normalized
        corresps[8] = {"flow": coarse_matches[:, :2], "certainty": coarse_matches[:, 2:]}
        coarse_matches_up = F.interpolate(coarse_matches, size=feats_x0_f.shape[-2:], mode="bilinear", align_corners=False)
        coarse_matches_up_detach = coarse_matches_up.detach() #note the detach
        feats_x1_f_warped = F.grid_sample(feats_x1_f, coarse_matches_up_detach.permute(0, 2, 3, 1)[..., :2], mode='bilinear', align_corners=False)
        fine_matches_delta = self.fine_matcher(torch.cat((feats_x0_f, feats_x1_f_warped, coarse_matches_up_detach[:, :2]), dim=1))
        fine_matches = coarse_matches_up_detach + fine_matches_delta * to_normalized
        corresps[4] = {"flow": fine_matches[:, :2], "certainty": fine_matches[:, 2:]}
        return corresps


def train(args):
    rank = 0
    gpus = 1
    device_id = rank % torch.cuda.device_count()
    romatch.LOCAL_RANK = 0
    torch.cuda.set_device(device_id)

    resolution = "big"
    wandb_log = not args.dont_log_wandb
    experiment_name = Path(__file__).stem
    wandb_mode = "online" if wandb_log and rank == 0 else "disabled"
    wandb.init(project="romatch", entity=args.wandb_entity, name=experiment_name, reinit=False, mode=wandb_mode)
    checkpoint_dir = "workspace/checkpoints/"
    h, w = resolutions[resolution]
    model = XFeatModel(freeze_xfeat=False).to(device_id)
    # Num steps
    global_step = 0
    batch_size = args.gpu_batch_size
    step_size = gpus*batch_size
    romatch.STEP_SIZE = step_size

    N = 2_000_000  # 2M pairs
    # checkpoint every
    k = 25000 // romatch.STEP_SIZE

    # Data
    mega = MegadepthBuilder(data_root="data/megadepth", loftr_ignore=True, imc21_ignore=True)
    use_horizontal_flip_aug = True
    normalize = False # don't imgnet normalize
    rot_prob = 0
    depth_interpolation_mode = "bilinear"
    megadepth_train1 = mega.build_scenes(
        split="train_loftr", min_overlap=0.01, shake_t=32, use_horizontal_flip_aug=use_horizontal_flip_aug, rot_prob=rot_prob,
        ht=h, wt=w, normalize=normalize
    )
    megadepth_train2 = mega.build_scenes(
        split="train_loftr", min_overlap=0.35, shake_t=32, use_horizontal_flip_aug=use_horizontal_flip_aug, rot_prob=rot_prob,
        ht=h, wt=w, normalize=normalize
    )
    megadepth_train = ConcatDataset(megadepth_train1 + megadepth_train2)
    mega_ws = mega.weight_scenes(megadepth_train, alpha=0.75)
    # Loss and optimizer
    depth_loss = RobustLosses(
        ce_weight=0.01,
        local_dist={4:4},
        depth_interpolation_mode=depth_interpolation_mode,
        alpha={4:0.15, 8:0.15},
        c=1e-4,
        epe_mask_prob_th=0.001,
        )
    parameters = [
        {"params": model.parameters(), "lr": romatch.STEP_SIZE * 1e-4 / 8},
    ]
    optimizer = torch.optim.AdamW(parameters, weight_decay=0.01)
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=[(9*N/romatch.STEP_SIZE)//10])
    #megadense_benchmark = MegadepthDenseBenchmark("data/megadepth", num_samples = 1000, h=h,w=w)
    mega1500_benchmark = Mega1500PoseLibBenchmark("data/megadepth", num_ransac_iter=1, test_every=30)

    checkpointer = CheckPoint(checkpoint_dir, experiment_name)
    model, optimizer, lr_scheduler, global_step = checkpointer.load(model, optimizer, lr_scheduler, global_step)
    romatch.GLOBAL_STEP = global_step
    grad_scaler = torch.cuda.amp.GradScaler(growth_interval=1_000_000)
    grad_clip_norm = 0.01
    #megadense_benchmark.benchmark(model)
    for n in range(romatch.GLOBAL_STEP, N, k * romatch.STEP_SIZE):
        mega_sampler = torch.utils.data.WeightedRandomSampler(
            mega_ws, num_samples=batch_size * k, replacement=False
        )
        mega_dataloader = iter(
            torch.utils.data.DataLoader(
                megadepth_train,
                batch_size=batch_size,
                sampler=mega_sampler,
                num_workers=8,
            )
        )
        train_k_steps(
            n, k, mega_dataloader, model, depth_loss, optimizer, lr_scheduler, grad_scaler, grad_clip_norm=grad_clip_norm,
        )
        checkpointer.save(model, optimizer, lr_scheduler, romatch.GLOBAL_STEP)
        wandb.log(mega1500_benchmark.benchmark(model, model_name=experiment_name), step=romatch.GLOBAL_STEP)

def test_mega_8_scenes(model, name):
    mega_8_scenes_benchmark = MegaDepthPoseEstimationBenchmark("data/megadepth",
                                                scene_names=['mega_8_scenes_0019_0.1_0.3.npz',
                                                             'mega_8_scenes_0025_0.1_0.3.npz',
                                                             'mega_8_scenes_0021_0.1_0.3.npz',
                                                             'mega_8_scenes_0008_0.1_0.3.npz',
                                                             'mega_8_scenes_0032_0.1_0.3.npz',
                                                             'mega_8_scenes_1589_0.1_0.3.npz',
                                                             'mega_8_scenes_0063_0.1_0.3.npz',
                                                             'mega_8_scenes_0024_0.1_0.3.npz',
                                                             'mega_8_scenes_0019_0.3_0.5.npz',
                                                             'mega_8_scenes_0025_0.3_0.5.npz',
                                                             'mega_8_scenes_0021_0.3_0.5.npz',
                                                             'mega_8_scenes_0008_0.3_0.5.npz',
                                                             'mega_8_scenes_0032_0.3_0.5.npz',
                                                             'mega_8_scenes_1589_0.3_0.5.npz',
                                                             'mega_8_scenes_0063_0.3_0.5.npz',
                                                             'mega_8_scenes_0024_0.3_0.5.npz'])
    mega_8_scenes_results = mega_8_scenes_benchmark.benchmark(model, model_name=name)
    print(mega_8_scenes_results)
    json.dump(mega_8_scenes_results, open(f"results/mega_8_scenes_{name}.json", "w"))

def test_mega1500(model, name):
    mega1500_benchmark = MegaDepthPoseEstimationBenchmark("data/megadepth")
    mega1500_results = mega1500_benchmark.benchmark(model, model_name=name)
    json.dump(mega1500_results, open(f"results/mega1500_{name}.json", "w"))

def test_mega1500_poselib(model, name):
    mega1500_benchmark = Mega1500PoseLibBenchmark("data/megadepth", num_ransac_iter=1, test_every=1)
    mega1500_results = mega1500_benchmark.benchmark(model, model_name=name)
    json.dump(mega1500_results, open(f"results/mega1500_poselib_{name}.json", "w"))

def test_mega_8_scenes_poselib(model, name):
    mega1500_benchmark = Mega1500PoseLibBenchmark("data/megadepth", num_ransac_iter=1, test_every=1,
                                                  scene_names=['mega_8_scenes_0019_0.1_0.3.npz',
                                                               'mega_8_scenes_0025_0.1_0.3.npz',
                                                               'mega_8_scenes_0021_0.1_0.3.npz',
                                                               'mega_8_scenes_0008_0.1_0.3.npz',
                                                               'mega_8_scenes_0032_0.1_0.3.npz',
                                                               'mega_8_scenes_1589_0.1_0.3.npz',
                                                               'mega_8_scenes_0063_0.1_0.3.npz',
                                                               'mega_8_scenes_0024_0.1_0.3.npz',
                                                               'mega_8_scenes_0019_0.3_0.5.npz',
                                                               'mega_8_scenes_0025_0.3_0.5.npz',
                                                               'mega_8_scenes_0021_0.3_0.5.npz',
                                                               'mega_8_scenes_0008_0.3_0.5.npz',
                                                               'mega_8_scenes_0032_0.3_0.5.npz',
                                                               'mega_8_scenes_1589_0.3_0.5.npz',
                                                               'mega_8_scenes_0063_0.3_0.5.npz',
                                                               'mega_8_scenes_0024_0.3_0.5.npz'])
    mega1500_results = mega1500_benchmark.benchmark(model, model_name=name)
    json.dump(mega1500_results, open(f"results/mega_8_scenes_poselib_{name}.json", "w"))

def test_scannet_poselib(model, name):
    scannet_benchmark = ScanNetPoselibBenchmark("data/scannet")
    scannet_results = scannet_benchmark.benchmark(model)
    json.dump(scannet_results, open(f"results/scannet_{name}.json", "w"))

def test_scannet(model, name):
    scannet_benchmark = ScanNetBenchmark("data/scannet")
    scannet_results = scannet_benchmark.benchmark(model)
    json.dump(scannet_results, open(f"results/scannet_{name}.json", "w"))

if __name__ == "__main__":
    os.environ["TORCH_CUDNN_V8_API_ENABLED"] = "1" # For BF16 computations
    os.environ["OMP_NUM_THREADS"] = "16"
    torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
    import romatch
    parser = ArgumentParser()
    parser.add_argument("--only_test", action='store_true')
    parser.add_argument("--debug_mode", action='store_true')
    parser.add_argument("--dont_log_wandb", action='store_true')
    parser.add_argument("--train_resolution", default='medium')
    parser.add_argument("--gpu_batch_size", default=8, type=int)
    parser.add_argument("--wandb_entity", required=False)

    args, _ = parser.parse_known_args()
    romatch.DEBUG_MODE = args.debug_mode
    if not args.only_test:
        train(args)

    experiment_name = "tiny_roma_v1_outdoor" #Path(__file__).stem
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = XFeatModel(freeze_xfeat=False, exact_softmax=False).to(device)
    model.load_state_dict(torch.load(f"{experiment_name}.pth"))
    test_mega1500_poselib(model, experiment_name)
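Editor's note: for readers who only want inference with the tiny model defined above, a minimal usage sketch follows. The image paths and the checkpoint filename (mirroring the one loaded in the __main__ block) are placeholders, not a packaged API.

# Hedged sketch: dense matching with the XFeatModel class defined above.
import torch
from PIL import Image

device = "cuda" if torch.cuda.is_available() else "cpu"
tiny_model = XFeatModel(freeze_xfeat=False, exact_softmax=False).to(device)
tiny_model.load_state_dict(torch.load("tiny_roma_v1_outdoor.pth", map_location=device))

# match() accepts image paths and returns a dense warp plus per-pixel certainty.
warp, certainty = tiny_model.match("im_A.jpg", "im_B.jpg")
matches, match_certainty = tiny_model.sample(warp, certainty, num=5000)
W_A, H_A = Image.open("im_A.jpg").size
W_B, H_B = Image.open("im_B.jpg").size
kpts_A, kpts_B = tiny_model.to_pixel_coordinates(matches, H_A, W_A, H_B, W_B)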
submodules/RoMa/requirements.txt
ADDED
@@ -0,0 +1,14 @@
torch
einops
torchvision
opencv-python
kornia
albumentations
loguru
tqdm
matplotlib
h5py
wandb
timm
poselib
#xformers # Optional, used for memefficient attention
submodules/RoMa/romatch/__init__.py
ADDED
@@ -0,0 +1,8 @@
import os
from .models import roma_outdoor, tiny_roma_v1_outdoor, roma_indoor

DEBUG_MODE = False
RANK = int(os.environ.get('RANK', default=0))
GLOBAL_STEP = 0
STEP_SIZE = 1
LOCAL_RANK = -1
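Editor's note: this package-level module exposes the pretrained constructors imported above. A brief, hedged sketch of the typical call pattern, mirroring how the benchmarks in this commit call match(), sample(), and to_pixel_coordinates(); the image paths are placeholders.

# Hedged sketch: build a pretrained outdoor matcher and extract sparse pixel matches.
import torch
from PIL import Image
from romatch import roma_outdoor

device = "cuda" if torch.cuda.is_available() else "cpu"
roma_model = roma_outdoor(device=device)

warp, certainty = roma_model.match("im_A.jpg", "im_B.jpg")
matches, certainty = roma_model.sample(warp, certainty)
W_A, H_A = Image.open("im_A.jpg").size
W_B, H_B = Image.open("im_B.jpg").size
kpts_A, kpts_B = roma_model.to_pixel_coordinates(matches, H_A, W_A, H_B, W_B)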
submodules/RoMa/romatch/benchmarks/__init__.py
ADDED
@@ -0,0 +1,6 @@
from .hpatches_sequences_homog_benchmark import HpatchesHomogBenchmark
from .scannet_benchmark import ScanNetBenchmark
from .megadepth_pose_estimation_benchmark import MegaDepthPoseEstimationBenchmark
from .megadepth_dense_benchmark import MegadepthDenseBenchmark
from .megadepth_pose_estimation_benchmark_poselib import Mega1500PoseLibBenchmark
#from .scannet_benchmark_poselib import ScanNetPoselibBenchmark
submodules/RoMa/romatch/benchmarks/hpatches_sequences_homog_benchmark.py
ADDED
@@ -0,0 +1,113 @@
from PIL import Image
import numpy as np

import os

from tqdm import tqdm
from romatch.utils import pose_auc
import cv2


class HpatchesHomogBenchmark:
    """Hpatches grid goes from [0,n-1] instead of [0.5,n-0.5]"""

    def __init__(self, dataset_path) -> None:
        seqs_dir = "hpatches-sequences-release"
        self.seqs_path = os.path.join(dataset_path, seqs_dir)
        self.seq_names = sorted(os.listdir(self.seqs_path))
        # Ignore seqs is same as LoFTR.
        self.ignore_seqs = set(
            [
                "i_contruction",
                "i_crownnight",
                "i_dc",
                "i_pencils",
                "i_whitebuilding",
                "v_artisans",
                "v_astronautis",
                "v_talent",
            ]
        )

    def convert_coordinates(self, im_A_coords, im_A_to_im_B, wq, hq, wsup, hsup):
        offset = 0.5  # Hpatches assumes that the center of the top-left pixel is at [0,0] (I think)
        im_A_coords = (
            np.stack(
                (
                    wq * (im_A_coords[..., 0] + 1) / 2,
                    hq * (im_A_coords[..., 1] + 1) / 2,
                ),
                axis=-1,
            )
            - offset
        )
        im_A_to_im_B = (
            np.stack(
                (
                    wsup * (im_A_to_im_B[..., 0] + 1) / 2,
                    hsup * (im_A_to_im_B[..., 1] + 1) / 2,
                ),
                axis=-1,
            )
            - offset
        )
        return im_A_coords, im_A_to_im_B

    def benchmark(self, model, model_name=None):
        n_matches = []
        homog_dists = []
        for seq_idx, seq_name in tqdm(
            enumerate(self.seq_names), total=len(self.seq_names)
        ):
            im_A_path = os.path.join(self.seqs_path, seq_name, "1.ppm")
            im_A = Image.open(im_A_path)
            w1, h1 = im_A.size
            for im_idx in range(2, 7):
                im_B_path = os.path.join(self.seqs_path, seq_name, f"{im_idx}.ppm")
                im_B = Image.open(im_B_path)
                w2, h2 = im_B.size
                H = np.loadtxt(
                    os.path.join(self.seqs_path, seq_name, "H_1_" + str(im_idx))
                )
                dense_matches, dense_certainty = model.match(
                    im_A_path, im_B_path
                )
                good_matches, _ = model.sample(dense_matches, dense_certainty, 5000)
                pos_a, pos_b = self.convert_coordinates(
                    good_matches[:, :2], good_matches[:, 2:], w1, h1, w2, h2
                )
                try:
                    H_pred, inliers = cv2.findHomography(
                        pos_a,
                        pos_b,
                        method=cv2.RANSAC,
                        confidence=0.99999,
                        ransacReprojThreshold=3 * min(w2, h2) / 480,
                    )
                except:
                    H_pred = None
                if H_pred is None:
                    H_pred = np.zeros((3, 3))
                    H_pred[2, 2] = 1.0
                corners = np.array(
                    [[0, 0, 1], [0, h1 - 1, 1], [w1 - 1, 0, 1], [w1 - 1, h1 - 1, 1]]
                )
                real_warped_corners = np.dot(corners, np.transpose(H))
                real_warped_corners = (
                    real_warped_corners[:, :2] / real_warped_corners[:, 2:]
                )
                warped_corners = np.dot(corners, np.transpose(H_pred))
                warped_corners = warped_corners[:, :2] / warped_corners[:, 2:]
                mean_dist = np.mean(
                    np.linalg.norm(real_warped_corners - warped_corners, axis=1)
                ) / (min(w2, h2) / 480.0)
                homog_dists.append(mean_dist)

        n_matches = np.array(n_matches)
        thresholds = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
        auc = pose_auc(np.array(homog_dists), thresholds)
        return {
            "hpatches_homog_auc_3": auc[2],
            "hpatches_homog_auc_5": auc[4],
            "hpatches_homog_auc_10": auc[9],
        }
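Editor's note: a short, hedged sketch of driving this benchmark, in the same spirit as test_hpatches() in the experiment scripts above; the dataset path and choice of matcher are assumptions.

# Hedged sketch: evaluate a matcher exposing match()/sample() on HPatches homography estimation.
from romatch import roma_outdoor
from romatch.benchmarks import HpatchesHomogBenchmark

model = roma_outdoor(device="cuda")
bench = HpatchesHomogBenchmark("data/hpatches")
results = bench.benchmark(model)  # dict with hpatches_homog_auc_3 / _5 / _10
print(results)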
submodules/RoMa/romatch/benchmarks/megadepth_dense_benchmark.py
ADDED
@@ -0,0 +1,106 @@
import torch
import numpy as np
import tqdm
from romatch.datasets import MegadepthBuilder
from romatch.utils import warp_kpts
from torch.utils.data import ConcatDataset
import romatch

class MegadepthDenseBenchmark:
    def __init__(self, data_root="data/megadepth", h=384, w=512, num_samples=2000) -> None:
        mega = MegadepthBuilder(data_root=data_root)
        self.dataset = ConcatDataset(
            mega.build_scenes(split="test_loftr", ht=h, wt=w)
        )  # fixed resolution of 384,512
        self.num_samples = num_samples

    def geometric_dist(self, depth1, depth2, T_1to2, K1, K2, dense_matches):
        b, h1, w1, d = dense_matches.shape
        with torch.no_grad():
            x1 = dense_matches[..., :2].reshape(b, h1 * w1, 2)
            mask, x2 = warp_kpts(
                x1.double(),
                depth1.double(),
                depth2.double(),
                T_1to2.double(),
                K1.double(),
                K2.double(),
            )
            x2 = torch.stack(
                (w1 * (x2[..., 0] + 1) / 2, h1 * (x2[..., 1] + 1) / 2), dim=-1
            )
            prob = mask.float().reshape(b, h1, w1)
        x2_hat = dense_matches[..., 2:]
        x2_hat = torch.stack(
            (w1 * (x2_hat[..., 0] + 1) / 2, h1 * (x2_hat[..., 1] + 1) / 2), dim=-1
        )
        gd = (x2_hat - x2.reshape(b, h1, w1, 2)).norm(dim=-1)
        gd = gd[prob == 1]
        pck_1 = (gd < 1.0).float().mean()
        pck_3 = (gd < 3.0).float().mean()
        pck_5 = (gd < 5.0).float().mean()
        return gd, pck_1, pck_3, pck_5, prob

    def benchmark(self, model, batch_size=8):
        model.train(False)
        with torch.no_grad():
            gd_tot = 0.0
            pck_1_tot = 0.0
            pck_3_tot = 0.0
            pck_5_tot = 0.0
            sampler = torch.utils.data.WeightedRandomSampler(
                torch.ones(len(self.dataset)), replacement=False, num_samples=self.num_samples
            )
            B = batch_size
            dataloader = torch.utils.data.DataLoader(
                self.dataset, batch_size=B, num_workers=batch_size, sampler=sampler
            )
            for idx, data in tqdm.tqdm(enumerate(dataloader), disable=romatch.RANK > 0):
                im_A, im_B, depth1, depth2, T_1to2, K1, K2 = (
                    data["im_A"].cuda(),
                    data["im_B"].cuda(),
                    data["im_A_depth"].cuda(),
                    data["im_B_depth"].cuda(),
                    data["T_1to2"].cuda(),
                    data["K1"].cuda(),
                    data["K2"].cuda(),
                )
                matches, certainty = model.match(im_A, im_B, batched=True)
                gd, pck_1, pck_3, pck_5, prob = self.geometric_dist(
                    depth1, depth2, T_1to2, K1, K2, matches
                )
                if romatch.DEBUG_MODE:
                    from romatch.utils.utils import tensor_to_pil
                    import torch.nn.functional as F
                    path = "vis"
                    H, W = model.get_output_resolution()
                    white_im = torch.ones((B, 1, H, W), device="cuda")
                    im_B_transfer_rgb = F.grid_sample(
                        im_B.cuda(), matches[:, :, :W, 2:], mode="bilinear", align_corners=False
                    )
                    warp_im = im_B_transfer_rgb
                    c_b = certainty[:, None] #(certainty*0.9 + 0.1*torch.ones_like(certainty))[:,None]
                    vis_im = c_b * warp_im + (1 - c_b) * white_im
                    for b in range(B):
                        import os
                        os.makedirs(f"{path}/{model.name}/{idx}_{b}_{H}_{W}", exist_ok=True)
                        tensor_to_pil(vis_im[b], unnormalize=True).save(
                            f"{path}/{model.name}/{idx}_{b}_{H}_{W}/warp.jpg")
                        tensor_to_pil(im_A[b].cuda(), unnormalize=True).save(
                            f"{path}/{model.name}/{idx}_{b}_{H}_{W}/im_A.jpg")
                        tensor_to_pil(im_B[b].cuda(), unnormalize=True).save(
                            f"{path}/{model.name}/{idx}_{b}_{H}_{W}/im_B.jpg")

                gd_tot, pck_1_tot, pck_3_tot, pck_5_tot = (
                    gd_tot + gd.mean(),
                    pck_1_tot + pck_1,
                    pck_3_tot + pck_3,
                    pck_5_tot + pck_5,
                )
        return {
            "epe": gd_tot.item() / len(dataloader),
            "mega_pck_1": pck_1_tot.item() / len(dataloader),
            "mega_pck_3": pck_3_tot.item() / len(dataloader),
            "mega_pck_5": pck_5_tot.item() / len(dataloader),
        }
submodules/RoMa/romatch/benchmarks/megadepth_pose_estimation_benchmark.py
ADDED
@@ -0,0 +1,118 @@
import numpy as np
import torch
from romatch.utils import *
from PIL import Image
from tqdm import tqdm
import torch.nn.functional as F
import romatch
import kornia.geometry.epipolar as kepi

class MegaDepthPoseEstimationBenchmark:
    def __init__(self, data_root="data/megadepth", scene_names=None) -> None:
        if scene_names is None:
            self.scene_names = [
                "0015_0.1_0.3.npz",
                "0015_0.3_0.5.npz",
                "0022_0.1_0.3.npz",
                "0022_0.3_0.5.npz",
                "0022_0.5_0.7.npz",
            ]
        else:
            self.scene_names = scene_names
        self.scenes = [
            np.load(f"{data_root}/{scene}", allow_pickle=True)
            for scene in self.scene_names
        ]
        self.data_root = data_root

    def benchmark(self, model, model_name=None):
        with torch.no_grad():
            data_root = self.data_root
            tot_e_t, tot_e_R, tot_e_pose = [], [], []
            thresholds = [5, 10, 20]
            for scene_ind in range(len(self.scenes)):
                import os
                scene_name = os.path.splitext(self.scene_names[scene_ind])[0]
                scene = self.scenes[scene_ind]
                pairs = scene["pair_infos"]
                intrinsics = scene["intrinsics"]
                poses = scene["poses"]
                im_paths = scene["image_paths"]
                pair_inds = range(len(pairs))
                for pairind in tqdm(pair_inds):
                    idx1, idx2 = pairs[pairind][0]
                    K1 = intrinsics[idx1].copy()
                    T1 = poses[idx1].copy()
                    R1, t1 = T1[:3, :3], T1[:3, 3]
                    K2 = intrinsics[idx2].copy()
                    T2 = poses[idx2].copy()
                    R2, t2 = T2[:3, :3], T2[:3, 3]
                    R, t = compute_relative_pose(R1, t1, R2, t2)
                    T1_to_2 = np.concatenate((R, t[:, None]), axis=-1)
                    im_A_path = f"{data_root}/{im_paths[idx1]}"
                    im_B_path = f"{data_root}/{im_paths[idx2]}"
                    dense_matches, dense_certainty = model.match(
                        im_A_path, im_B_path, K1.copy(), K2.copy(), T1_to_2.copy()
                    )
                    sparse_matches, _ = model.sample(
                        dense_matches, dense_certainty, 5_000
                    )

                    im_A = Image.open(im_A_path)
                    w1, h1 = im_A.size
                    im_B = Image.open(im_B_path)
                    w2, h2 = im_B.size
                    if True:  # Note: we keep this true as it was used in DKM/RoMa papers. There is very little difference compared to setting to False.
                        scale1 = 1200 / max(w1, h1)
                        scale2 = 1200 / max(w2, h2)
                        w1, h1 = scale1 * w1, scale1 * h1
                        w2, h2 = scale2 * w2, scale2 * h2
                        K1, K2 = K1.copy(), K2.copy()
                        K1[:2] = K1[:2] * scale1
                        K2[:2] = K2[:2] * scale2

                    kpts1, kpts2 = model.to_pixel_coordinates(sparse_matches, h1, w1, h2, w2)
                    kpts1, kpts2 = kpts1.cpu().numpy(), kpts2.cpu().numpy()
                    for _ in range(5):
                        shuffling = np.random.permutation(np.arange(len(kpts1)))
                        kpts1 = kpts1[shuffling]
                        kpts2 = kpts2[shuffling]
                        try:
                            threshold = 0.5
                            norm_threshold = threshold / (np.mean(np.abs(K1[:2, :2])) + np.mean(np.abs(K2[:2, :2])))
                            R_est, t_est, mask = estimate_pose(
                                kpts1,
                                kpts2,
                                K1,
                                K2,
                                norm_threshold,
                                conf=0.99999,
                            )
                            T1_to_2_est = np.concatenate((R_est, t_est), axis=-1)  #
                            e_t, e_R = compute_pose_error(T1_to_2_est, R, t)
                            e_pose = max(e_t, e_R)
                        except Exception as e:
                            print(repr(e))
                            e_t, e_R = 90, 90
                            e_pose = max(e_t, e_R)
                        tot_e_t.append(e_t)
                        tot_e_R.append(e_R)
                        tot_e_pose.append(e_pose)
            tot_e_pose = np.array(tot_e_pose)
            auc = pose_auc(tot_e_pose, thresholds)
            acc_5 = (tot_e_pose < 5).mean()
            acc_10 = (tot_e_pose < 10).mean()
            acc_15 = (tot_e_pose < 15).mean()
            acc_20 = (tot_e_pose < 20).mean()
            map_5 = acc_5
            map_10 = np.mean([acc_5, acc_10])
            map_20 = np.mean([acc_5, acc_10, acc_15, acc_20])
            print(f"{model_name} auc: {auc}")
            return {
                "auc_5": auc[0],
                "auc_10": auc[1],
                "auc_20": auc[2],
                "map_5": map_5,
                "map_10": map_10,
                "map_20": map_20,
            }
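Editor's note: a hedged sketch of running this MegaDepth-1500 pose benchmark, mirroring test_mega1500() in the experiment scripts of this commit; the data layout under data/megadepth and the choice of matcher are assumptions.

# Hedged sketch: relative-pose evaluation on MegaDepth-1500.
import json
from romatch import roma_outdoor
from romatch.benchmarks import MegaDepthPoseEstimationBenchmark

model = roma_outdoor(device="cuda")
bench = MegaDepthPoseEstimationBenchmark("data/megadepth")
results = bench.benchmark(model, model_name="roma_outdoor")  # auc_5/10/20 and map_5/10/20
json.dump(results, open("results/mega1500_roma_outdoor.json", "w"))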