Olga committed on
Commit 5f9d349 · 0 Parent(s)

Initial commit

This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50)
  1. .gitattributes +37 -0
  2. .gitignore +159 -0
  3. LICENSE.txt +12 -0
  4. README.md +14 -0
  5. app.py +434 -0
  6. assets/examples/video_bakery.mp4 +3 -0
  7. assets/examples/video_flowers.mp4 +3 -0
  8. assets/examples/video_fruits.mp4 +3 -0
  9. assets/examples/video_plant.mp4 +3 -0
  10. assets/examples/video_salad.mp4 +3 -0
  11. assets/examples/video_tram.mp4 +3 -0
  12. assets/examples/video_tulips.mp4 +3 -0
  13. assets/video_fruits_ours_full.mp4 +3 -0
  14. configs/gs/base.yaml +51 -0
  15. configs/train.yaml +38 -0
  16. requirements.txt +32 -0
  17. source/EDGS.code-workspace +11 -0
  18. source/__init__.py +0 -0
  19. source/corr_init.py +682 -0
  20. source/corr_init_new.py +904 -0
  21. source/data_utils.py +28 -0
  22. source/losses.py +100 -0
  23. source/networks.py +52 -0
  24. source/timer.py +24 -0
  25. source/trainer.py +262 -0
  26. source/utils_aux.py +92 -0
  27. source/utils_preprocess.py +334 -0
  28. source/vggt_to_colmap.py +598 -0
  29. source/visualization.py +1072 -0
  30. submodules/RoMa/.gitignore +11 -0
  31. submodules/RoMa/LICENSE +21 -0
  32. submodules/RoMa/README.md +123 -0
  33. submodules/RoMa/data/.gitignore +2 -0
  34. submodules/RoMa/demo/demo_3D_effect.py +47 -0
  35. submodules/RoMa/demo/demo_fundamental.py +34 -0
  36. submodules/RoMa/demo/demo_match.py +50 -0
  37. submodules/RoMa/demo/demo_match_opencv_sift.py +43 -0
  38. submodules/RoMa/demo/demo_match_tiny.py +77 -0
  39. submodules/RoMa/demo/gif/.gitignore +2 -0
  40. submodules/RoMa/experiments/eval_roma_outdoor.py +57 -0
  41. submodules/RoMa/experiments/eval_tiny_roma_v1_outdoor.py +84 -0
  42. submodules/RoMa/experiments/roma_indoor.py +320 -0
  43. submodules/RoMa/experiments/train_roma_outdoor.py +307 -0
  44. submodules/RoMa/experiments/train_tiny_roma_v1_outdoor.py +498 -0
  45. submodules/RoMa/requirements.txt +14 -0
  46. submodules/RoMa/romatch/__init__.py +8 -0
  47. submodules/RoMa/romatch/benchmarks/__init__.py +6 -0
  48. submodules/RoMa/romatch/benchmarks/hpatches_sequences_homog_benchmark.py +113 -0
  49. submodules/RoMa/romatch/benchmarks/megadepth_dense_benchmark.py +106 -0
  50. submodules/RoMa/romatch/benchmarks/megadepth_pose_estimation_benchmark.py +118 -0
.gitattributes ADDED
@@ -0,0 +1,37 @@
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.whl filter=lfs diff=lfs merge=lfs -text
37
+ *.mp4 filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,159 @@
1
+ # Local notebooks used for local debug
2
+ notebooks_local/
3
+ wandb/
4
+
5
+ # Gradio files
6
+
7
+ served_files/
8
+
9
+ # hidden folders
10
+ .*/**
11
+
12
+ # The rest is taken from https://github.com/Anttwo/SuGaR
13
+ *.pt
14
+ *.pth
15
+ output*
16
+ *.slurm
17
+ *.pyc
18
+ *.ply
19
+ *.obj
20
+ sugar_tests.ipynb
21
+ sugar_sh_tests.ipynb
22
+
23
+ # To remove
24
+ frosting*
25
+ extract_shell.py
26
+ train_frosting_refined.py
27
+ train_frosting.py
28
+ run_frosting_viewer.py
29
+ slurm_a100.sh
30
+
31
+ # Byte-compiled / optimized / DLL files
32
+ __pycache__/
33
+ *.py[cod]
34
+ *$py.class
35
+
36
+ # C extensions
37
+ *.so
38
+
39
+ # Distribution / packaging
40
+ .Python
41
+ build/
42
+ develop-eggs/
43
+ dist/
44
+ downloads/
45
+ eggs/
46
+ .eggs/
47
+ lib/
48
+ lib64/
49
+ parts/
50
+ sdist/
51
+ var/
52
+ pip-wheel-metadata/
53
+ share/python-wheels/
54
+ *.egg-info/
55
+ .installed.cfg
56
+ *.egg
57
+ MANIFEST
58
+
59
+ # PyInstaller
60
+ # Usually these files are written by a python script from a template
61
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
62
+ *.manifest
63
+ *.spec
64
+
65
+ # Installer logs
66
+ pip-log.txt
67
+ pip-delete-this-directory.txt
68
+
69
+ # Unit test / coverage reports
70
+ htmlcov/
71
+ .tox/
72
+ .nox/
73
+ .coverage
74
+ .coverage.*
75
+ .cache
76
+ nosetests.xml
77
+ coverage.xml
78
+ *.cover
79
+ *.py,cover
80
+ .hypothesis/
81
+ .pytest_cache/
82
+
83
+ # Translations
84
+ *.mo
85
+ *.pot
86
+
87
+ # Django stuff:
88
+ *.log
89
+ local_settings.py
90
+ db.sqlite3
91
+ db.sqlite3-journal
92
+
93
+ # Flask stuff:
94
+ instance/
95
+ .webassets-cache
96
+
97
+ # Scrapy stuff:
98
+ .scrapy
99
+
100
+ # Sphinx documentation
101
+ docs/_build/
102
+
103
+ # PyBuilder
104
+ target/
105
+
106
+ # Jupyter Notebook
107
+ .ipynb_checkpoints
108
+
109
+ # IPython
110
+ profile_default/
111
+ ipython_config.py
112
+
113
+ # pyenv
114
+ .python-version
115
+
116
+ # pipenv
117
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
118
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
119
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
120
+ # install all needed dependencies.
121
+ #Pipfile.lock
122
+
123
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
124
+ __pypackages__/
125
+
126
+ # Celery stuff
127
+ celerybeat-schedule
128
+ celerybeat.pid
129
+
130
+ # SageMath parsed files
131
+ *.sage.py
132
+
133
+ # Environments
134
+ .env
135
+ .venv
136
+ env/
137
+ venv/
138
+ ENV/
139
+ env.bak/
140
+ venv.bak/
141
+
142
+ # Spyder project settings
143
+ .spyderproject
144
+ .spyproject
145
+
146
+ # Rope project settings
147
+ .ropeproject
148
+
149
+ # mkdocs documentation
150
+ /site
151
+
152
+ # mypy
153
+ .mypy_cache/
154
+ .dmypy.json
155
+ dmypy.json
156
+
157
+ # Pyre type checker
158
+ .pyre/
159
+ learnableearthparser/fast_sampler/_sampler.c
LICENSE.txt ADDED
@@ -0,0 +1,12 @@
1
+ Copyright 2025, Dmytro Kotovenko, Olga Grebenkova, Björn Ommer
2
+ Redistribution and use in source and binary forms, with or without modification, are permitted for non-commercial academic research and/or non-commercial personal use only provided that the following conditions are met:
3
+
4
+ 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
5
+
6
+ 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
7
+
8
+ 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
9
+
10
+ Any use of this software beyond the above specified conditions requires a separate license. Please contact the copyright holders to discuss license terms.
11
+
12
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
README.md ADDED
@@ -0,0 +1,14 @@
1
+ ---
2
+ title: EDGS
3
+ emoji: 🎥
4
+ colorFrom: pink
5
+ colorTo: blue
6
+ sdk: gradio
7
+ sdk_version: 5.25.2
8
+ app_file: app.py
9
+ pinned: false
10
+ python_version: "3.10"
11
+ license: other
12
+ ---
13
+
14
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,434 @@
1
+ import spaces
2
+ import torch
3
+ import os
4
+ import shutil
5
+ import tempfile
6
+ import argparse
7
+ import gradio as gr
8
+ import sys
9
+ import io
10
+ import subprocess
11
+ from PIL import Image
12
+ import numpy as np
13
+ from hydra import initialize, compose
14
+ import hydra
15
+ from omegaconf import OmegaConf
16
+ import time
17
+
18
+ def install_submodules():
19
+ subprocess.check_call(['pip', 'install', './submodules/RoMa'])
20
+
21
+
22
+
23
+ STATIC_FILE_SERVING_FOLDER = "./served_files"
24
+ MODEL_PATH = None
25
+ os.makedirs(STATIC_FILE_SERVING_FOLDER, exist_ok=True)
26
+
27
+ trainer = None
28
+
29
+ class Tee(io.TextIOBase):
30
+
31
+ def __init__(self, *streams):
32
+ self.streams = streams
33
+
34
+ def write(self, data):
35
+ for stream in self.streams:
36
+ stream.write(data)
37
+ return len(data)
38
+
39
+ def flush(self):
40
+ for stream in self.streams:
41
+ stream.flush()
42
+
43
+
44
+ def capture_logs(func, *args, **kwargs):
45
+ log_capture_string = io.StringIO()
46
+ tee = Tee(sys.__stdout__, log_capture_string)
47
+
48
+ with contextlib.redirect_stdout(tee):
49
+ result = func(*args, **kwargs)
50
+
51
+ return result, log_capture_string.getvalue()
52
+
53
+
54
+ @spaces.GPU(duration=350)
55
+ # Training Pipeline
56
+ def run_training_pipeline(scene_dir,
57
+ num_ref_views=16,
58
+ num_corrs_per_view=20000,
59
+ num_steps=1_000,
60
+ mode_toggle="Ours (EDGS)"):
61
+ with initialize(config_path="./configs", version_base="1.1"):
62
+ cfg = compose(config_name="train")
63
+
64
+ scene_name = os.path.basename(scene_dir)
65
+ model_output_dir = f"./outputs/{scene_name}_trained"
66
+
67
+ cfg.wandb.mode = "disabled"
68
+ cfg.gs.dataset.model_path = model_output_dir
69
+ cfg.gs.dataset.source_path = scene_dir
70
+ cfg.gs.dataset.images = "images"
71
+
72
+ cfg.gs.opt.TEST_CAM_IDX_TO_LOG = 12
73
+ cfg.train.gs_epochs = 30000
74
+
75
+ if mode_toggle=="Ours (EDGS)":
76
+ cfg.gs.opt.opacity_reset_interval = 1_000_000
77
+ cfg.train.reduce_opacity = True
78
+ cfg.train.no_densify = True
79
+ cfg.train.max_lr = True
80
+
81
+ cfg.init_wC.use = True
82
+ cfg.init_wC.matches_per_ref = num_corrs_per_view
83
+ cfg.init_wC.nns_per_ref = 1
84
+ cfg.init_wC.num_refs = num_ref_views
85
+ cfg.init_wC.add_SfM_init = False
86
+ cfg.init_wC.scaling_factor = 0.00077 * 2.
87
+
88
+ set_seed(cfg.seed)
89
+ os.makedirs(cfg.gs.dataset.model_path, exist_ok=True)
90
+
91
+ global trainer
92
+ global MODEL_PATH
93
+ generator3dgs = hydra.utils.instantiate(cfg.gs, do_train_test_split=False)
94
+ trainer = EDGSTrainer(GS=generator3dgs, training_config=cfg.gs.opt, device=cfg.device, log_wandb=cfg.wandb.mode != 'disabled')
95
+
96
+ # Disable evaluation and saving
97
+ trainer.saving_iterations = []
98
+ trainer.evaluate_iterations = []
99
+
100
+ # Initialize
101
+ trainer.timer.start()
102
+ start_time = time.time()
103
+ trainer.init_with_corr(cfg.init_wC, roma_model=roma_model)
104
+ time_for_init = time.time()-start_time
105
+
106
+ viewpoint_cams = trainer.GS.scene.getTrainCameras()
107
+ path_cameras = generate_fully_smooth_cameras_with_tsp(existing_cameras=viewpoint_cams,
108
+ n_selected=6,
109
+ n_points_per_segment=30,
110
+ closed=False)
111
+ path_cameras = path_cameras + path_cameras[::-1]
112
+
113
+ path_renderings = []
114
+ idx = 0
115
+ # Visualize after init
116
+ for _ in range(120):
117
+ with torch.no_grad():
118
+ viewpoint_cam = path_cameras[idx]
119
+ idx = (idx + 1) % len(path_cameras)
120
+ render_pkg = trainer.GS(viewpoint_cam)
121
+ image = render_pkg["render"]
122
+ image_np = np.clip(image.detach().cpu().numpy().transpose(1, 2, 0), 0, 1)
123
+ image_np = (image_np * 255).astype(np.uint8)
124
+ path_renderings.append(put_text_on_image(img=image_np,
125
+ text=f"Init stage.\nTime:{time_for_init:.3f}s. "))
126
+ path_renderings = path_renderings + [put_text_on_image(img=image_np, text=f"Start fitting.\nTime:{time_for_init:.3f}s. ")]*30
127
+
128
+ # Train and save visualizations during training.
129
+ start_time = time.time()
130
+ for _ in range(int(num_steps//10)):
131
+ with torch.no_grad():
132
+ viewpoint_cam = path_cameras[idx]
133
+ idx = (idx + 1) % len(path_cameras)
134
+ render_pkg = trainer.GS(viewpoint_cam)
135
+ image = render_pkg["render"]
136
+ image_np = np.clip(image.detach().cpu().numpy().transpose(1, 2, 0), 0, 1)
137
+ image_np = (image_np * 255).astype(np.uint8)
138
+ path_renderings.append(put_text_on_image(
139
+ img=image_np,
140
+ text=f"Fitting stage.\nTime:{time_for_init + time.time()-start_time:.3f}s. "))
141
+
142
+ cfg.train.gs_epochs = 10
143
+ trainer.train(cfg.train)
144
+ print(f"Time elapsed: {(time_for_init + time.time()-start_time):.2f}s.")
145
+ # if (cfg.init_wC.use == False) and (time_for_init + time.time()-start_time) > 60:
146
+ # break
147
+ final_time = time.time()
148
+
149
+ # Add static frame. To highlight we're done
150
+ path_renderings += [put_text_on_image(
151
+ img=image_np, text=f"Done.\nTime:{time_for_init + final_time -start_time:.3f}s. ")]*30
152
+ # Final rendering at the end.
153
+ for _ in range(len(path_cameras)):
154
+ with torch.no_grad():
155
+ viewpoint_cam = path_cameras[idx]
156
+ idx = (idx + 1) % len(path_cameras)
157
+ render_pkg = trainer.GS(viewpoint_cam)
158
+ image = render_pkg["render"]
159
+ image_np = np.clip(image.detach().cpu().numpy().transpose(1, 2, 0), 0, 1)
160
+ image_np = (image_np * 255).astype(np.uint8)
161
+ path_renderings.append(put_text_on_image(img=image_np,
162
+ text=f"Final result.\nTime:{time_for_init + final_time -start_time:.3f}s. "))
163
+
164
+ trainer.save_model()
165
+ final_video_path = os.path.join(STATIC_FILE_SERVING_FOLDER, f"{scene_name}_final.mp4")
166
+ save_numpy_frames_as_mp4(frames=path_renderings, output_path=final_video_path, fps=30, center_crop=0.85)
167
+ MODEL_PATH = cfg.gs.dataset.model_path
168
+ ply_path = os.path.join(cfg.gs.dataset.model_path, f"point_cloud/iteration_{trainer.gs_step}/point_cloud.ply")
169
+ shutil.copy(ply_path, os.path.join(STATIC_FILE_SERVING_FOLDER, "point_cloud_final.ply"))
170
+
171
+ return final_video_path, ply_path
172
+
173
+ # Gradio Interface
174
+ def gradio_interface(input_path, num_ref_views, num_corrs, num_steps):
175
+ images, scene_dir = run_full_pipeline(input_path, num_ref_views, num_corrs, max_size=1024)
176
+ shutil.copytree(scene_dir, STATIC_FILE_SERVING_FOLDER+'/scene_colmaped', dirs_exist_ok=True)
177
+ (final_video_path, ply_path), log_output = capture_logs(run_training_pipeline,
178
+ scene_dir,
179
+ num_ref_views,
180
+ num_corrs,
181
+ num_steps
182
+ )
183
+ images_rgb = [img[:, :, ::-1] for img in images]
184
+ return images_rgb, final_video_path, scene_dir, ply_path, log_output
185
+
186
+ # Dummy Render Functions
187
+ @spaces.GPU(duration=60)
188
+ def render_all_views(scene_dir):
189
+ viewpoint_cams = trainer.GS.scene.getTrainCameras()
190
+ path_cameras = generate_fully_smooth_cameras_with_tsp(existing_cameras=viewpoint_cams,
191
+ n_selected=8,
192
+ n_points_per_segment=60,
193
+ closed=False)
194
+ path_cameras = path_cameras + path_cameras[::-1]
195
+
196
+ path_renderings = []
197
+ with torch.no_grad():
198
+ for viewpoint_cam in path_cameras:
199
+ render_pkg = trainer.GS(viewpoint_cam)
200
+ image = render_pkg["render"]
201
+ image_np = np.clip(image.detach().cpu().numpy().transpose(1, 2, 0), 0, 1)
202
+ image_np = (image_np * 255).astype(np.uint8)
203
+ path_renderings.append(image_np)
204
+ save_numpy_frames_as_mp4(frames=path_renderings,
205
+ output_path=os.path.join(STATIC_FILE_SERVING_FOLDER, "render_all_views.mp4"),
206
+ fps=30,
207
+ center_crop=0.85)
208
+
209
+ return os.path.join(STATIC_FILE_SERVING_FOLDER, "render_all_views.mp4")
210
+
211
+ @spaces.GPU(duration=60)
212
+ def render_circular_path(scene_dir):
213
+ viewpoint_cams = trainer.GS.scene.getTrainCameras()
214
+ path_cameras = generate_circular_camera_path(existing_cameras=viewpoint_cams,
215
+ N=240,
216
+ radius_scale=0.65,
217
+ d=0)
218
+
219
+ path_renderings = []
220
+ with torch.no_grad():
221
+ for viewpoint_cam in path_cameras:
222
+ render_pkg = trainer.GS(viewpoint_cam)
223
+ image = render_pkg["render"]
224
+ image_np = np.clip(image.detach().cpu().numpy().transpose(1, 2, 0), 0, 1)
225
+ image_np = (image_np * 255).astype(np.uint8)
226
+ path_renderings.append(image_np)
227
+ save_numpy_frames_as_mp4(frames=path_renderings,
228
+ output_path=os.path.join(STATIC_FILE_SERVING_FOLDER, "render_circular_path.mp4"),
229
+ fps=30,
230
+ center_crop=0.85)
231
+
232
+ return os.path.join(STATIC_FILE_SERVING_FOLDER, "render_circular_path.mp4")
233
+
234
+ # Download Functions
235
+ def download_cameras():
236
+ path = os.path.join(MODEL_PATH, "cameras.json")
237
+ return f"[📥 Download Cameras.json](file={path})"
238
+
239
+ def download_model():
240
+ path = os.path.join(STATIC_FILE_SERVING_FOLDER, "point_cloud_final.ply")
241
+ return f"[📥 Download Pretrained Model (.ply)](file={path})"
242
+
243
+ # Full pipeline helpers
244
+ def run_full_pipeline(input_path, num_ref_views, num_corrs, max_size=1024):
245
+ tmpdirname = tempfile.mkdtemp()
246
+ scene_dir = os.path.join(tmpdirname, "scene")
247
+ os.makedirs(scene_dir, exist_ok=True)
248
+
249
+ selected_frames = process_input(input_path, num_ref_views, scene_dir, max_size)
250
+ run_colmap_on_scene(scene_dir)
251
+
252
+ return selected_frames, scene_dir
253
+
254
+ # Preprocess Input
255
+ def process_input(input_path, num_ref_views, output_dir, max_size=1024):
256
+ if isinstance(input_path, (str, os.PathLike)):
257
+ if os.path.isdir(input_path):
258
+ frames = []
259
+ for img_file in sorted(os.listdir(input_path)):
260
+ if img_file.lower().endswith(('jpg', 'jpeg', 'png')):
261
+ img = Image.open(os.path.join(output_dir, img_file)).convert('RGB')
262
+ img.thumbnail((1024, 1024))
263
+ frames.append(np.array(img))
264
+ else:
265
+ frames = read_video_frames(video_input=input_path, max_size=max_size)
266
+ else:
267
+ frames = read_video_frames(video_input=input_path, max_size=max_size)
268
+
269
+ frames_scores = preprocess_frames(frames)
270
+ selected_frames_indices = select_optimal_frames(scores=frames_scores, k=min(num_ref_views, len(frames)))
271
+ selected_frames = [frames[frame_idx] for frame_idx in selected_frames_indices]
272
+
273
+ save_frames_to_scene_dir(frames=selected_frames, scene_dir=output_dir)
274
+ return selected_frames
275
+
276
+ @spaces.GPU(duration=150)
277
+ def preprocess_input(input_path, num_ref_views, max_size=1024):
278
+ tmpdirname = tempfile.mkdtemp()
279
+ scene_dir = os.path.join(tmpdirname, "scene")
280
+ os.makedirs(scene_dir, exist_ok=True)
281
+ selected_frames = process_input(input_path, num_ref_views, scene_dir, max_size)
282
+ run_colmap_on_scene(scene_dir)
283
+ return selected_frames, scene_dir
284
+
285
+ def start_training(scene_dir, num_ref_views, num_corrs, num_steps):
286
+ return capture_logs(run_training_pipeline, scene_dir, num_ref_views, num_corrs, num_steps)
287
+
288
+ # Gradio App
289
+ with gr.Blocks() as demo:
290
+ with gr.Row():
291
+ with gr.Column(scale=6):
292
+ gr.Markdown("""
293
+ ## <span style='font-size: 20px;'>📄 EDGS: Eliminating Densification for Efficient Convergence of 3DGS</span>
294
+ 🔗 <a href='https://compvis.github.io/EDGS' target='_blank'>Project Page</a>
295
+ """, elem_id="header")
296
+
297
+ gr.Markdown("""
298
+ ### <span style='font-size: 22px;'>🛠️ How to Use This Demo</span>
299
+
300
+ 1. Upload a **front-facing video** or **a folder of images** of a **static** scene.
301
+ 2. Use the sliders to configure the number of reference views, correspondences, and optimization steps.
302
+ 3. First press on preprocess Input to extract frames from video(for videos) and COLMAP frames.
303
+ 4.Then click **🚀 Start Reconstruction** to actually launch the reconstruction pipeline.
304
+ 5. Watch the training visualization and explore the 3D model.
305
+ ‼️ **If you see nothing in the 3D model viewer**, try rotating or zooming — sometimes the initial camera orientation is off.
306
+
307
+
308
+ ✅ Best for scenes with small camera motion.
309
+ ❗ For full 360° or large-scale scenes, we recommend the Colab version (see project page).
310
+ """, elem_id="quickstart")
311
+ scene_dir_state = gr.State()
312
+ ply_model_state = gr.State()
313
+ with gr.Row():
314
+ with gr.Column(scale=2):
315
+ input_file = gr.File(label="Upload Video or Images",
316
+ file_types=[".mp4", ".avi", ".mov", ".png", ".jpg", ".jpeg"],
317
+ file_count="multiple")
318
+ gr.Examples(
319
+ examples = [
320
+ [["assets/examples/video_bakery.mp4"]],
321
+ [["assets/examples/video_flowers.mp4"]],
322
+ [["assets/examples/video_fruits.mp4"]],
323
+ [["assets/examples/video_plant.mp4"]],
324
+ [["assets/examples/video_salad.mp4"]],
325
+ [["assets/examples/video_tram.mp4"]],
326
+ [["assets/examples/video_tulips.mp4"]]
327
+ ],
328
+ inputs=[input_file],
329
+ label="🎞️ ALternatively, try an Example Video",
330
+ examples_per_page=4
331
+ )
332
+ ref_slider = gr.Slider(4, 32, value=16, step=1, label="Number of Reference Views")
333
+ corr_slider = gr.Slider(5000, 30000, value=20000, step=1000, label="Correspondences per Reference View")
334
+ fit_steps_slider = gr.Slider(100, 5000, value=400, step=100, label="Number of optimization steps")
335
+ preprocess_button = gr.Button("📸 Preprocess Input")
336
+ start_button = gr.Button("🚀 Start Reconstruction", interactive=False)
337
+ gallery = gr.Gallery(label="Selected Reference Views", columns=4, height=300)
338
+
339
+ with gr.Column(scale=3):
340
+ gr.Markdown("### 🏋️ Training Visualization")
341
+ video_output = gr.Video(label="Training Video", autoplay=True)
342
+ render_all_views_button = gr.Button("🎥 Render All-Views Path")
343
+ render_circular_path_button = gr.Button("🎥 Render Circular Path")
344
+ rendered_video_output = gr.Video(label="Rendered Video", autoplay=True)
345
+ with gr.Column(scale=5):
346
+ gr.Markdown("### 🌐 Final 3D Model")
347
+ model3d_viewer = gr.Model3D(label="3D Model Viewer")
348
+
349
+ gr.Markdown("### 📦 Output Files")
350
+ with gr.Row(height=50):
351
+ with gr.Column():
352
+ #gr.Markdown(value=f"[📥 Download .ply](file/point_cloud_final.ply)")
353
+ download_cameras_button = gr.Button("📥 Download Cameras.json")
354
+ download_cameras_file = gr.File(label="📄 Cameras.json")
355
+ with gr.Column():
356
+ download_model_button = gr.Button("📥 Download Pretrained Model (.ply)")
357
+ download_model_file = gr.File(label="📄 Pretrained Model (.ply)")
358
+
359
+ log_output_box = gr.Textbox(label="🖥️ Log", lines=10, interactive=False)
360
+
361
+ def on_preprocess_click(input_file, num_ref_views):
362
+ images, scene_dir = preprocess_input(input_file, num_ref_views)
363
+ return gr.update(value=[x[...,::-1] for x in images]), scene_dir, gr.update(interactive=True)
364
+
365
+ def on_start_click(scene_dir, num_ref_views, num_corrs, num_steps):
366
+ (video_path, ply_path), logs = start_training(scene_dir, num_ref_views, num_corrs, num_steps)
367
+ return video_path, ply_path, logs
368
+
369
+ preprocess_button.click(
370
+ fn=on_preprocess_click,
371
+ inputs=[input_file, ref_slider],
372
+ outputs=[gallery, scene_dir_state, start_button]
373
+ )
374
+
375
+ start_button.click(
376
+ fn=on_start_click,
377
+ inputs=[scene_dir_state, ref_slider, corr_slider, fit_steps_slider],
378
+ outputs=[video_output, model3d_viewer, log_output_box]
379
+ )
380
+
381
+ render_all_views_button.click(fn=render_all_views, inputs=[scene_dir_state], outputs=[rendered_video_output])
382
+ render_circular_path_button.click(fn=render_circular_path, inputs=[scene_dir_state], outputs=[rendered_video_output])
383
+
384
+ download_cameras_button.click(fn=lambda: os.path.join(MODEL_PATH, "cameras.json"), inputs=[], outputs=[download_cameras_file])
385
+ download_model_button.click(fn=lambda: os.path.join(STATIC_FILE_SERVING_FOLDER, "point_cloud_final.ply"), inputs=[], outputs=[download_model_file])
386
+
387
+
388
+ gr.Markdown("""
389
+ ---
390
+ ### <span style='font-size: 20px;'>📖 Detailed Overview</span>
391
+
392
+ If you uploaded a video, it will be automatically cut into a smaller number of frames (default: 16).
393
+
394
+ The model pipeline:
395
+ 1. 🧠 Runs PyCOLMAP to estimate camera intrinsics & poses (~3–7 seconds for <16 images).
396
+ 2. 🔁 Computes 2D-2D correspondences between views. More correspondences generally improve quality.
397
+ 3. 🔧 Optimizes a 3D Gaussian Splatting model for several steps.
398
+
399
+ ### 🎥 Training Visualization
400
+ You will see a visualization of the entire training process in the "Training Video" pane.
401
+
402
+ ### 🌀 Rendering & 3D Model
403
+ - Render the scene from a circular path of novel views.
404
+ - Or from camera views close to the original input.
405
+
406
+ The 3D model is shown in the right viewer. You can explore it interactively:
407
+ - On PC: WASD keys, arrow keys, and mouse clicks
408
+ - On mobile: pan and pinch to zoom
409
+
410
+ 🕒 Note: the 3D viewer takes a few extra seconds (~5s) to display after training ends.
411
+
412
+ ---
413
+ Preloaded models coming soon. (TODO)
414
+ """, elem_id="details")
415
+
416
+
417
+
418
+
419
+ if __name__ == "__main__":
420
+ install_submodules()
421
+ from source.utils_aux import set_seed
422
+ from source.utils_preprocess import read_video_frames, preprocess_frames, select_optimal_frames, save_frames_to_scene_dir, run_colmap_on_scene
423
+ from source.trainer import EDGSTrainer
424
+ from source.visualization import generate_circular_camera_path, save_numpy_frames_as_mp4, generate_fully_smooth_cameras_with_tsp, put_text_on_image
425
+ # Init RoMA model:
426
+ sys.path.append('../submodules/RoMa')
427
+ from romatch import roma_outdoor, roma_indoor
428
+
429
+ roma_model = roma_indoor(device="cpu")
430
+ roma_model = roma_model.to("cuda")
431
+ roma_model.upsample_preds = False
432
+ roma_model.symmetric = False
433
+
434
+ demo.launch(share=True)
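
The Gradio callbacks above chain the preprocessing helpers from source/utils_preprocess.py into run_training_pipeline. For reference, a minimal headless sketch of the same preprocessing steps follows (not part of this commit; it assumes the repo root as working directory, the dependencies from requirements.txt, and COLMAP being available to run_colmap_on_scene):

    # Hedged sketch: reuse the preprocessing helpers that app.py imports.
    import os
    from source.utils_preprocess import (
        read_video_frames, preprocess_frames, select_optimal_frames,
        save_frames_to_scene_dir, run_colmap_on_scene,
    )

    def preprocess_video(video_path, num_ref_views=16, scene_dir="./scene"):
        os.makedirs(scene_dir, exist_ok=True)
        # Decode and downscale frames, as process_input() does for video inputs.
        frames = read_video_frames(video_input=video_path, max_size=1024)
        # Score the frames and keep the k most informative ones as reference views.
        scores = preprocess_frames(frames)
        keep = select_optimal_frames(scores=scores, k=min(num_ref_views, len(frames)))
        selected = [frames[i] for i in keep]
        # Write the selection to disk and estimate poses/intrinsics with COLMAP.
        save_frames_to_scene_dir(frames=selected, scene_dir=scene_dir)
        run_colmap_on_scene(scene_dir)
        return selected, scene_dir

Training then proceeds as run_training_pipeline does: compose the Hydra config, instantiate cfg.gs and EDGSTrainer, and call trainer.init_with_corr(cfg.init_wC, roma_model=...) followed by trainer.train(cfg.train).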
assets/examples/video_bakery.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b65c813f2ef9637579350e145fdceed333544d933278be5613c6d49468f4eab0
3
+ size 6362238
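
The example videos are stored via Git LFS, so the diff records only the pointer (spec version, sha256 oid, size) rather than the media itself. If a raw asset is needed outside an LFS-enabled checkout, it can be fetched with huggingface_hub, which is already in requirements.txt; the sketch below is only an illustration, and "user/EDGS-demo" is a hypothetical placeholder for the actual Space id:

    # Hedged sketch: resolve an LFS-tracked asset from the Hub.
    from huggingface_hub import hf_hub_download

    local_path = hf_hub_download(
        repo_id="user/EDGS-demo",   # hypothetical placeholder, not the real Space id
        repo_type="space",
        filename="assets/examples/video_bakery.mp4",
    )
    print(local_path)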
assets/examples/video_flowers.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c81a7c28e0d59bad38d5a45f6bfa83b80990d1fb78a06f82137e8b57ec38e62b
3
+ size 6466943
assets/examples/video_fruits.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ac3b937e155d1965a314b478e940f6fd93e90371d3bf7d62d3225d840fe8e126
3
+ size 3356915
assets/examples/video_plant.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e67a40e62de9aacf0941d8cf33a2dd08d256fe60d0fc58f60426c251f9d8abd8
3
+ size 13023885
assets/examples/video_salad.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:367d86071a201c124383f0a076ce644b7f86b9c2fbce6aa595c4989ebf259bfb
3
+ size 8774427
assets/examples/video_tram.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:298c297155d3f52edcffcb6b6f9910c992b98ecfb93cfaf8fb64fb340aba1dae
3
+ size 4697915
assets/examples/video_tulips.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c19942319b8f7e33bb03cb4a39b11797fe38572bf7157af37471b7c8573fb495
3
+ size 7298210
assets/video_fruits_ours_full.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9c5b113566d3a083b81360b549ed89f70d5e81739f83e182518f6906811311a2
3
+ size 14839197
configs/gs/base.yaml ADDED
@@ -0,0 +1,51 @@
1
+ _target_: source.networks.Warper3DGS
2
+
3
+ verbose: True
4
+ viewpoint_stack: !!null
5
+ sh_degree: 3
6
+
7
+ opt:
8
+ iterations: 30000
9
+ position_lr_init: 0.00016
10
+ position_lr_final: 1.6e-06
11
+ position_lr_delay_mult: 0.01
12
+ position_lr_max_steps: 30000
13
+ feature_lr: 0.0025
14
+ opacity_lr: 0.025
15
+ scaling_lr: 0.005
16
+ rotation_lr: 0.001
17
+ percent_dense: 0.01
18
+ lambda_dssim: 0.2
19
+ densification_interval: 100
20
+ opacity_reset_interval: 30000
21
+ densify_from_iter: 500
22
+ densify_until_iter: 15000
23
+ densify_grad_threshold: 0.0002
24
+ random_background: false
25
+ save_iterations: [3000, 7000, 15000, 30000]
26
+ batch_size: 64
27
+ exposure_lr_init: 0.01
28
+ exposure_lr_final: 0.0001
29
+ exposure_lr_delay_steps: 0
30
+ exposure_lr_delay_mult: 0.0
31
+
32
+ TRAIN_CAM_IDX_TO_LOG: 50
33
+ TEST_CAM_IDX_TO_LOG: 10
34
+
35
+ pipe:
36
+ convert_SHs_python: False
37
+ compute_cov3D_python: False
38
+ debug: False
39
+ antialiasing: False
40
+
41
+ dataset:
42
+ densify_until_iter: 15000
43
+ source_path: '' #path to dataset
44
+ model_path: '' #path to logs
45
+ images: images
46
+ resolution: -1
47
+ white_background: false
48
+ data_device: cuda
49
+ eval: false
50
+ depths: ""
51
+ train_test_exp: False
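
The opt block mirrors the standard 3DGS optimizer settings; densify_from_iter, densify_until_iter and densification_interval define the densification window, which app.py effectively disables for EDGS runs (train.no_densify = True, opacity_reset_interval pushed to 1,000,000). The helper below only illustrates how such keys are conventionally interpreted in 3DGS-style training loops; it is an assumption for clarity, not code from this repo:

    # Illustrative assumption: the usual 3DGS densification schedule implied by these keys.
    def densify_this_iteration(it,
                               densify_from_iter=500,
                               densify_until_iter=15000,
                               densification_interval=100):
        # Densify every `densification_interval` steps inside the [from, until) window.
        return densify_from_iter <= it < densify_until_iter and it % densification_interval == 0

    assert densify_this_iteration(600) is True       # inside the window, on the interval
    assert densify_this_iteration(20000) is False    # past densify_until_iter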
configs/train.yaml ADDED
@@ -0,0 +1,38 @@
1
+ defaults:
2
+ - gs: base
3
+ - _self_
4
+
5
+ seed: 228
6
+
7
+ wandb:
8
+ mode: "online" # "disabled" for no logging
9
+ entity: "3dcorrespondence"
10
+ project: "Adv3DGS"
11
+ group: null
12
+ name: null
13
+ tag: "debug"
14
+
15
+ train:
16
+ gs_epochs: 0 # number of 3dgs iterations
17
+ reduce_opacity: True
18
+ no_densify: False # if True, the model will not be densified
19
+ max_lr: True
20
+
21
+ load:
22
+ gs: null #path to 3dgs checkpoint
23
+ gs_step: null #number of iterations, e.g. 7000
24
+
25
+ device: "cuda:0"
26
+ verbose: true
27
+
28
+ init_wC:
29
+ use: True # use EDGS
30
+ matches_per_ref: 15_000 # number of matches per reference
31
+ num_refs: 180 # number of reference images
32
+ nns_per_ref: 3 # number of nearest neighbors per reference
33
+ scaling_factor: 0.001
34
+ proj_err_tolerance: 0.01
35
+ roma_model: "outdoors" # you can change this to "indoors" or "outdoors"
36
+ add_SfM_init : False
37
+
38
+
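
app.py consumes this file through Hydra (initialize/compose) and then mutates cfg fields in code. The same changes can be passed as compose() overrides; a minimal sketch, assuming the script is run from the repo root (the override values are illustrative):

    # Hedged sketch: compose configs/train.yaml and override EDGS-specific fields,
    # mirroring what run_training_pipeline() does in app.py.
    from hydra import initialize, compose

    with initialize(config_path="./configs", version_base="1.1"):
        cfg = compose(
            config_name="train",
            overrides=[
                "wandb.mode=disabled",
                "init_wC.matches_per_ref=20000",
                "init_wC.num_refs=16",
                "init_wC.roma_model=indoors",
            ],
        )

    print(cfg.init_wC.use, cfg.init_wC.matches_per_ref)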
requirements.txt ADDED
@@ -0,0 +1,32 @@
1
+ --extra-index-url https://download.pytorch.org/whl/cu124
2
+ torch
3
+ torchvision
4
+ torchaudio
5
+
6
+
7
+ # Required libraries from pip
8
+ Pillow
9
+ huggingface_hub
10
+ einops
11
+ safetensors
12
+ sympy==1.13.1
13
+ wandb
14
+ hydra-core
15
+ tqdm
16
+ torchmetrics
17
+ lpips
18
+ matplotlib
19
+ rich
20
+ plyfile
21
+ imageio
22
+ imageio-ffmpeg
23
+ numpy==1.26.4 # Match conda-installed version
24
+ opencv-python
25
+ pycolmap
26
+ moviepy
27
+ plotly
28
+ scikit-learn
29
+ ffmpeg
30
+
31
+ https://huggingface.co/spaces/magistrkoljan/test/resolve/main/wheels/diff_gaussian_rasterization-0.0.0-cp310-cp310-linux_x86_64.whl?download=true
32
+ https://huggingface.co/spaces/magistrkoljan/test/resolve/main/wheels/simple_knn-0.0.0-cp310-cp310-linux_x86_64.whl?download=true
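
The final two entries install prebuilt CUDA 12.4 / Python 3.10 wheels for the Gaussian-splatting rasterizer and the simple-knn ops. A quick, hedged environment check (the import names follow common 3DGS usage and are inferred from the wheel filenames, not documented by this commit):

    # Hedged environment check for the GPU stack pulled in by requirements.txt.
    import torch

    assert torch.cuda.is_available(), "EDGS training expects a CUDA device (cfg.device = 'cuda:0')"
    print("torch", torch.__version__, "| CUDA", torch.version.cuda)

    # Import names assumed from the wheel filenames / common 3DGS usage.
    from diff_gaussian_rasterization import GaussianRasterizer
    from simple_knn._C import distCUDA2
    print("3DGS CUDA extensions imported OK")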
source/EDGS.code-workspace ADDED
@@ -0,0 +1,11 @@
1
+ {
2
+ "folders": [
3
+ {
4
+ "path": ".."
5
+ },
6
+ {
7
+ "path": "../../../../.."
8
+ }
9
+ ],
10
+ "settings": {}
11
+ }
source/__init__.py ADDED
File without changes
source/corr_init.py ADDED
@@ -0,0 +1,682 @@
1
+ import sys
2
+ sys.path.append('../')
3
+ sys.path.append("../submodules")
4
+ sys.path.append('../submodules/RoMa')
5
+
6
+ from matplotlib import pyplot as plt
7
+ from PIL import Image
8
+ import torch
9
+ import numpy as np
10
+
11
+ #from tqdm import tqdm_notebook as tqdm
12
+ from tqdm import tqdm
13
+ from scipy.cluster.vq import kmeans, vq
14
+ from scipy.spatial.distance import cdist
15
+
16
+ import torch.nn.functional as F
17
+ from romatch import roma_outdoor, roma_indoor
18
+ from utils.sh_utils import RGB2SH
19
+ from romatch.utils import get_tuple_transform_ops
20
+
21
+
22
+ def pairwise_distances(matrix):
23
+ """
24
+ Computes the pairwise Euclidean distances between all vectors in the input matrix.
25
+
26
+ Args:
27
+ matrix (torch.Tensor): Input matrix of shape [N, D], where N is the number of vectors and D is the dimensionality.
28
+
29
+ Returns:
30
+ torch.Tensor: Pairwise distance matrix of shape [N, N].
31
+ """
32
+ # Compute squared pairwise distances
33
+ squared_diff = torch.cdist(matrix, matrix, p=2)
34
+ return squared_diff
35
+
36
+
37
+ def k_closest_vectors(matrix, k):
38
+ """
39
+ Finds the k-closest vectors for each vector in the input matrix based on Euclidean distance.
40
+
41
+ Args:
42
+ matrix (torch.Tensor): Input matrix of shape [N, D], where N is the number of vectors and D is the dimensionality.
43
+ k (int): Number of closest vectors to return for each vector.
44
+
45
+ Returns:
46
+ torch.Tensor: Indices of the k-closest vectors for each vector, excluding the vector itself.
47
+ """
48
+ # Compute pairwise distances
49
+ distances = pairwise_distances(matrix)
50
+
51
+ # For each vector, sort distances and get the indices of the k-closest vectors (excluding itself)
52
+ # Set diagonal distances to infinity to exclude the vector itself from the nearest neighbors
53
+ distances.fill_diagonal_(float('inf'))
54
+
55
+ # Get the indices of the k smallest distances (k-closest vectors)
56
+ _, indices = torch.topk(distances, k, largest=False, dim=1)
57
+
58
+ return indices
59
+
60
+
61
+ def select_cameras_kmeans(cameras, K):
62
+ """
63
+ Selects K cameras from a set using K-means clustering.
64
+
65
+ Args:
66
+ cameras: NumPy array of shape (N, 16), representing N cameras with their 4x4 homogeneous matrices flattened.
67
+ K: Number of clusters (cameras to select).
68
+
69
+ Returns:
70
+ selected_indices: List of indices of the cameras closest to the cluster centers.
71
+ """
72
+ # Ensure input is a NumPy array
73
+ if not isinstance(cameras, np.ndarray):
74
+ cameras = np.asarray(cameras)
75
+
76
+ if cameras.shape[1] != 16:
77
+ raise ValueError("Each camera must have 16 values corresponding to a flattened 4x4 matrix.")
78
+
79
+ # Perform K-means clustering
80
+ cluster_centers, _ = kmeans(cameras, K)
81
+
82
+ # Assign each camera to a cluster and find distances to cluster centers
83
+ cluster_assignments, _ = vq(cameras, cluster_centers)
84
+
85
+ # Find the camera nearest to each cluster center
86
+ selected_indices = []
87
+ for k in range(K):
88
+ cluster_members = cameras[cluster_assignments == k]
89
+ distances = cdist([cluster_centers[k]], cluster_members)[0]
90
+ nearest_camera_idx = np.where(cluster_assignments == k)[0][np.argmin(distances)]
91
+ selected_indices.append(nearest_camera_idx)
92
+
93
+ return selected_indices
94
+
95
+
96
+ def compute_warp_and_confidence(viewpoint_cam1, viewpoint_cam2, roma_model, device="cuda", verbose=False, output_dict={}):
97
+ """
98
+ Computes the warp and confidence between two viewpoint cameras using the roma_model.
99
+
100
+ Args:
101
+ viewpoint_cam1: Source viewpoint camera.
102
+ viewpoint_cam2: Target viewpoint camera.
103
+ roma_model: Pre-trained Roma model for correspondence matching.
104
+ device: Device to run the computation on.
105
+ verbose: If True, displays the images.
106
+
107
+ Returns:
108
+ certainty: Confidence tensor.
109
+ warp: Warp tensor.
110
+ imB: Processed image B as numpy array.
111
+ """
112
+ # Prepare images
113
+ imA = viewpoint_cam1.original_image.detach().cpu().numpy().transpose(1, 2, 0)
114
+ imB = viewpoint_cam2.original_image.detach().cpu().numpy().transpose(1, 2, 0)
115
+ imA = Image.fromarray(np.clip(imA * 255, 0, 255).astype(np.uint8))
116
+ imB = Image.fromarray(np.clip(imB * 255, 0, 255).astype(np.uint8))
117
+
118
+ if verbose:
119
+ fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(16, 8))
120
+ cax1 = ax[0].imshow(imA)
121
+ ax[0].set_title("Image 1")
122
+ cax2 = ax[1].imshow(imB)
123
+ ax[1].set_title("Image 2")
124
+ fig.colorbar(cax1, ax=ax[0])
125
+ fig.colorbar(cax2, ax=ax[1])
126
+
127
+ for axis in ax:
128
+ axis.axis('off')
129
+ # Save the figure into the dictionary
130
+ output_dict[f'image_pair'] = fig
131
+
132
+ # Transform images
133
+ ws, hs = roma_model.w_resized, roma_model.h_resized
134
+ test_transform = get_tuple_transform_ops(resize=(hs, ws), normalize=True)
135
+ im_A, im_B = test_transform((imA, imB))
136
+ batch = {"im_A": im_A[None].to(device), "im_B": im_B[None].to(device)}
137
+
138
+ # Forward pass through Roma model
139
+ corresps = roma_model.forward(batch) if not roma_model.symmetric else roma_model.forward_symmetric(batch)
140
+ finest_scale = 1
141
+ hs, ws = roma_model.upsample_res if roma_model.upsample_preds else (hs, ws)
142
+
143
+ # Process certainty and warp
144
+ certainty = corresps[finest_scale]["certainty"]
145
+ im_A_to_im_B = corresps[finest_scale]["flow"]
146
+ if roma_model.attenuate_cert:
147
+ low_res_certainty = F.interpolate(
148
+ corresps[16]["certainty"], size=(hs, ws), align_corners=False, mode="bilinear"
149
+ )
150
+ certainty -= 0.5 * low_res_certainty * (low_res_certainty < 0)
151
+
152
+ # Upsample predictions if needed
153
+ if roma_model.upsample_preds:
154
+ im_A_to_im_B = F.interpolate(
155
+ im_A_to_im_B, size=(hs, ws), align_corners=False, mode="bilinear"
156
+ )
157
+ certainty = F.interpolate(
158
+ certainty, size=(hs, ws), align_corners=False, mode="bilinear"
159
+ )
160
+
161
+ # Convert predictions to final format
162
+ im_A_to_im_B = im_A_to_im_B.permute(0, 2, 3, 1)
163
+ im_A_coords = torch.stack(torch.meshgrid(
164
+ torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device=device),
165
+ torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device=device),
166
+ indexing='ij'
167
+ ), dim=0).permute(1, 2, 0).unsqueeze(0).expand(im_A_to_im_B.size(0), -1, -1, -1)
168
+
169
+ warp = torch.cat((im_A_coords, im_A_to_im_B), dim=-1)
170
+ certainty = certainty.sigmoid()
171
+
172
+ return certainty[0, 0], warp[0], np.array(imB)
173
+
174
+
175
+ def resize_batch(tensors_3d, tensors_4d, target_shape):
176
+ """
177
+ Resizes a batch of tensors with shapes [B, H, W] and [B, H, W, 4] to the target spatial dimensions.
178
+
179
+ Args:
180
+ tensors_3d: Tensor of shape [B, H, W].
181
+ tensors_4d: Tensor of shape [B, H, W, 4].
182
+ target_shape: Tuple (target_H, target_W) specifying the target spatial dimensions.
183
+
184
+ Returns:
185
+ resized_tensors_3d: Tensor of shape [B, target_H, target_W].
186
+ resized_tensors_4d: Tensor of shape [B, target_H, target_W, 4].
187
+ """
188
+ target_H, target_W = target_shape
189
+
190
+ # Resize [B, H, W] tensor
191
+ resized_tensors_3d = F.interpolate(
192
+ tensors_3d.unsqueeze(1), size=(target_H, target_W), mode="bilinear", align_corners=False
193
+ ).squeeze(1)
194
+
195
+ # Resize [B, H, W, 4] tensor
196
+ B, _, _, C = tensors_4d.shape
197
+ resized_tensors_4d = F.interpolate(
198
+ tensors_4d.permute(0, 3, 1, 2), size=(target_H, target_W), mode="bilinear", align_corners=False
199
+ ).permute(0, 2, 3, 1)
200
+
201
+ return resized_tensors_3d, resized_tensors_4d
202
+
203
+
204
+ def aggregate_confidences_and_warps(viewpoint_stack, closest_indices, roma_model, source_idx, verbose=False, output_dict={}):
205
+ """
206
+ Aggregates confidences and warps by iterating over the nearest neighbors of the source viewpoint.
207
+
208
+ Args:
209
+ viewpoint_stack: Stack of viewpoint cameras.
210
+ closest_indices: Indices of the nearest neighbors for each viewpoint.
211
+ roma_model: Pre-trained Roma model.
212
+ source_idx: Index of the source viewpoint.
213
+ verbose: If True, displays intermediate results.
214
+
215
+ Returns:
216
+ certainties_max: Aggregated maximum confidences.
217
+ warps_max: Aggregated warps corresponding to maximum confidences.
218
+ certainties_max_idcs: Pixel-wise index of the image from which the best match was taken.
219
+ imB_compound: List of the neighboring images.
220
+ """
221
+ certainties_all, warps_all, imB_compound = [], [], []
222
+
223
+ for nn in tqdm(closest_indices[source_idx]):
224
+
225
+ viewpoint_cam1 = viewpoint_stack[source_idx]
226
+ viewpoint_cam2 = viewpoint_stack[nn]
227
+
228
+ certainty, warp, imB = compute_warp_and_confidence(viewpoint_cam1, viewpoint_cam2, roma_model, verbose=verbose, output_dict=output_dict)
229
+ certainties_all.append(certainty)
230
+ warps_all.append(warp)
231
+ imB_compound.append(imB)
232
+
233
+ certainties_all = torch.stack(certainties_all, dim=0)
234
+ target_shape = imB_compound[0].shape[:2]
235
+ if verbose:
236
+ print("certainties_all.shape:", certainties_all.shape)
237
+ print("torch.stack(warps_all, dim=0).shape:", torch.stack(warps_all, dim=0).shape)
238
+ print("target_shape:", target_shape)
239
+
240
+ certainties_all_resized, warps_all_resized = resize_batch(certainties_all,
241
+ torch.stack(warps_all, dim=0),
242
+ target_shape
243
+ )
244
+
245
+ if verbose:
246
+ print("warps_all_resized.shape:", warps_all_resized.shape)
247
+ for n, cert in enumerate(certainties_all):
248
+ fig, ax = plt.subplots()
249
+ cax = ax.imshow(cert.cpu().numpy(), cmap='viridis')
250
+ fig.colorbar(cax, ax=ax)
251
+ ax.set_title("Pixel-wise Confidence")
252
+ output_dict[f'certainty_{n}'] = fig
253
+
254
+ for n, warp in enumerate(warps_all):
255
+ fig, ax = plt.subplots()
256
+ cax = ax.imshow(warp.cpu().numpy()[:, :, :3], cmap='viridis')
257
+ fig.colorbar(cax, ax=ax)
258
+ ax.set_title("Pixel-wise warp")
259
+ output_dict[f'warp_resized_{n}'] = fig
260
+
261
+ for n, cert in enumerate(certainties_all_resized):
262
+ fig, ax = plt.subplots()
263
+ cax = ax.imshow(cert.cpu().numpy(), cmap='viridis')
264
+ fig.colorbar(cax, ax=ax)
265
+ ax.set_title("Pixel-wise Confidence resized")
266
+ output_dict[f'certainty_resized_{n}'] = fig
267
+
268
+ for n, warp in enumerate(warps_all_resized):
269
+ fig, ax = plt.subplots()
270
+ cax = ax.imshow(warp.cpu().numpy()[:, :, :3], cmap='viridis')
271
+ fig.colorbar(cax, ax=ax)
272
+ ax.set_title("Pixel-wise warp resized")
273
+ output_dict[f'warp_resized_{n}'] = fig
274
+
275
+ certainties_max, certainties_max_idcs = torch.max(certainties_all_resized, dim=0)
276
+ H, W = certainties_max.shape
277
+
278
+ warps_max = warps_all_resized[certainties_max_idcs, torch.arange(H).unsqueeze(1), torch.arange(W)]
279
+
280
+ imA = viewpoint_cam1.original_image.detach().cpu().numpy().transpose(1, 2, 0)
281
+ imA = np.clip(imA * 255, 0, 255).astype(np.uint8)
282
+
283
+ return certainties_max, warps_max, certainties_max_idcs, imA, imB_compound, certainties_all_resized, warps_all_resized
284
+
285
+
286
+
287
+ def extract_keypoints_and_colors(imA, imB_compound, certainties_max, certainties_max_idcs, matches, roma_model,
288
+ verbose=False, output_dict={}):
289
+ """
290
+ Extracts keypoints and corresponding colors from the source image (imA) and multiple target images (imB_compound).
291
+
292
+ Args:
293
+ imA: Source image as a NumPy array (H_A, W_A, C).
294
+ imB_compound: List of target images as NumPy arrays [(H_B, W_B, C), ...].
295
+ certainties_max: Tensor of pixel-wise maximum confidences.
296
+ certainties_max_idcs: Tensor of pixel-wise indices for the best matches.
297
+ matches: Matches in normalized coordinates.
298
+ roma_model: Roma model instance for keypoint operations.
299
+ verbose: whether to show intermediate outputs and visualize results.
300
+
301
+ Returns:
302
+ kptsA_np: Keypoints in imA in normalized coordinates.
303
+ kptsB_np: Keypoints in imB in normalized coordinates.
304
+ kptsA_color: Colors of keypoints in imA.
305
+ kptsB_color: Colors of keypoints in imB based on certainties_max_idcs.
306
+ """
307
+ H_A, W_A, _ = imA.shape
308
+ H, W = certainties_max.shape
309
+
310
+ # Convert matches to pixel coordinates
311
+ kptsA, kptsB = roma_model.to_pixel_coordinates(
312
+ matches, W_A, H_A, H, W # W, H
313
+ )
314
+
315
+ kptsA_np = kptsA.detach().cpu().numpy()
316
+ kptsB_np = kptsB.detach().cpu().numpy()
317
+ kptsA_np = kptsA_np[:, [1, 0]]
318
+
319
+ if verbose:
320
+ fig, ax = plt.subplots(figsize=(12, 6))
321
+ cax = ax.imshow(imA)
322
+ ax.set_title("Reference image, imA")
323
+ output_dict[f'reference_image'] = fig
324
+
325
+ fig, ax = plt.subplots(figsize=(12, 6))
326
+ cax = ax.imshow(imB_compound[0])
327
+ ax.set_title("Image to compare to image, imB_compound")
328
+ output_dict[f'imB_compound'] = fig
329
+
330
+ fig, ax = plt.subplots(figsize=(12, 6))
331
+ cax = ax.imshow(np.flipud(imA))
332
+ cax = ax.scatter(kptsA_np[:, 0], H_A - kptsA_np[:, 1], s=.03)
333
+ ax.set_title("Keypoints in imA")
334
+ ax.set_xlim(0, W_A)
335
+ ax.set_ylim(0, H_A)
336
+ output_dict[f'kptsA'] = fig
337
+
338
+ fig, ax = plt.subplots(figsize=(12, 6))
339
+ cax = ax.imshow(np.flipud(imB_compound[0]))
340
+ cax = ax.scatter(kptsB_np[:, 0], H_A - kptsB_np[:, 1], s=.03)
341
+ ax.set_title("Keypoints in imB")
342
+ ax.set_xlim(0, W_A)
343
+ ax.set_ylim(0, H_A)
344
+ output_dict[f'kptsB'] = fig
345
+
346
+ # Keypoints are in (row, column) format, so the first value is always in range [0; height] and the second in range [0; width]
347
+
348
+ kptsA_np = kptsA.detach().cpu().numpy()
349
+ kptsB_np = kptsB.detach().cpu().numpy()
350
+
351
+ # Extract colors for keypoints in imA (vectorized)
352
+ # New experimental version
353
+ kptsA_x = np.round(kptsA_np[:, 0] / 1.).astype(int)
354
+ kptsA_y = np.round(kptsA_np[:, 1] / 1.).astype(int)
355
+ kptsA_color = imA[np.clip(kptsA_x, 0, H - 1), np.clip(kptsA_y, 0, W - 1)]
356
+
357
+ # Create a composite image from imB_compound
358
+ imB_compound_np = np.stack(imB_compound, axis=0)
359
+ H_B, W_B, _ = imB_compound[0].shape
360
+
361
+ # Extract colors for keypoints in imB using certainties_max_idcs
362
+ imB_np = imB_compound_np[
363
+ certainties_max_idcs.detach().cpu().numpy(),
364
+ np.arange(H).reshape(-1, 1),
365
+ np.arange(W)
366
+ ]
367
+
368
+ if verbose:
369
+ print("imB_np.shape:", imB_np.shape)
370
+ print("imB_np:", imB_np)
371
+ fig, ax = plt.subplots(figsize=(12, 6))
372
+ cax = ax.imshow(np.flipud(imB_np))
373
+ cax = ax.scatter(kptsB_np[:, 0], H_A - kptsB_np[:, 1], s=.03)
374
+ ax.set_title("np.flipud(imB_np[0]")
375
+ ax.set_xlim(0, W_A)
376
+ ax.set_ylim(0, H_A)
377
+ output_dict[f'np.flipud(imB_np[0]'] = fig
378
+
379
+
380
+ kptsB_x = np.round(kptsB_np[:, 0]).astype(int)
381
+ kptsB_y = np.round(kptsB_np[:, 1]).astype(int)
382
+
383
+ certainties_max_idcs_np = certainties_max_idcs.detach().cpu().numpy()
384
+ kptsB_proj_matrices_idx = certainties_max_idcs_np[np.clip(kptsA_x, 0, H - 1), np.clip(kptsA_y, 0, W - 1)]
385
+ kptsB_color = imB_compound_np[kptsB_proj_matrices_idx, np.clip(kptsB_y, 0, H - 1), np.clip(kptsB_x, 0, W - 1)]
386
+
387
+ # Normalize keypoints in both images
388
+ kptsA_np[:, 0] = kptsA_np[:, 0] / H * 2.0 - 1.0
389
+ kptsA_np[:, 1] = kptsA_np[:, 1] / W * 2.0 - 1.0
390
+ kptsB_np[:, 0] = kptsB_np[:, 0] / W_B * 2.0 - 1.0
391
+ kptsB_np[:, 1] = kptsB_np[:, 1] / H_B * 2.0 - 1.0
392
+
393
+ return kptsA_np[:, [1, 0]], kptsB_np, kptsB_proj_matrices_idx, kptsA_color, kptsB_color
394
+
395
+ def prepare_tensor(input_array, device):
396
+ """
397
+ Converts an input array to a torch tensor, clones it, and detaches it for safe computation.
398
+ Args:
399
+ input_array (array-like): The input array to convert.
400
+ device (str or torch.device): The device to move the tensor to.
401
+ Returns:
402
+ torch.Tensor: A detached tensor clone of the input array on the specified device.
403
+ """
404
+ if not isinstance(input_array, torch.Tensor):
405
+ return torch.tensor(input_array, dtype=torch.float32).to(device).clone().detach()
406
+ return input_array.clone().detach().to(device).to(torch.float32)
407
+
408
+ def triangulate_points(P1, P2, k1_x, k1_y, k2_x, k2_y, device="cuda"):
409
+ """
410
+ Solves for a batch of 3D points given batches of projection matrices and corresponding image points.
411
+
412
+ Parameters:
413
+ - P1, P2: Tensors of projection matrices of size (batch_size, 4, 4) or (4, 4)
414
+ - k1_x, k1_y: Tensors of shape (batch_size,)
415
+ - k2_x, k2_y: Tensors of shape (batch_size,)
416
+
417
+ Returns:
418
+ - X: A tensor containing the 3D homogeneous coordinates, shape (batch_size, 4)
419
+ """
420
+ EPS = 1e-4
421
+ # Ensure inputs are tensors
422
+
423
+ P1 = prepare_tensor(P1, device)
424
+ P2 = prepare_tensor(P2, device)
425
+ k1_x = prepare_tensor(k1_x, device)
426
+ k1_y = prepare_tensor(k1_y, device)
427
+ k2_x = prepare_tensor(k2_x, device)
428
+ k2_y = prepare_tensor(k2_y, device)
429
+ batch_size = k1_x.shape[0]
430
+
431
+ # Expand P1 and P2 if they are not batched
432
+ if P1.ndim == 2:
433
+ P1 = P1.unsqueeze(0).expand(batch_size, -1, -1)
434
+ if P2.ndim == 2:
435
+ P2 = P2.unsqueeze(0).expand(batch_size, -1, -1)
436
+
437
+ # Extract columns from P1 and P2
438
+ P1_0 = P1[:, :, 0] # Shape: (batch_size, 4)
439
+ P1_1 = P1[:, :, 1]
440
+ P1_2 = P1[:, :, 2]
441
+
442
+ P2_0 = P2[:, :, 0]
443
+ P2_1 = P2[:, :, 1]
444
+ P2_2 = P2[:, :, 2]
445
+
446
+ # Reshape kx and ky to (batch_size, 1)
447
+ k1_x = k1_x.view(-1, 1)
448
+ k1_y = k1_y.view(-1, 1)
449
+ k2_x = k2_x.view(-1, 1)
450
+ k2_y = k2_y.view(-1, 1)
451
+
452
+ # Construct the equations for each batch
453
+ # For camera 1
454
+ A1 = P1_0 - k1_x * P1_2 # Shape: (batch_size, 4)
455
+ A2 = P1_1 - k1_y * P1_2
456
+ # For camera 2
457
+ A3 = P2_0 - k2_x * P2_2
458
+ A4 = P2_1 - k2_y * P2_2
459
+
460
+ # Stack the equations
461
+ A = torch.stack([A1, A2, A3, A4], dim=1) # Shape: (batch_size, 4, 4)
462
+
463
+ # Right-hand side (constants)
464
+ b = -A[:, :, 3] # Shape: (batch_size, 4)
465
+ A_reduced = A[:, :, :3] # Coefficients of x, y, z
466
+
467
+ # Solve using torch.linalg.lstsq (supports batching)
468
+ X_xyz = torch.linalg.lstsq(A_reduced, b.unsqueeze(2)).solution.squeeze(2) # Shape: (batch_size, 3)
469
+
470
+ # Append 1 to get homogeneous coordinates
471
+ ones = torch.ones((batch_size, 1), dtype=torch.float32, device=X_xyz.device)
472
+ X = torch.cat([X_xyz, ones], dim=1) # Shape: (batch_size, 4)
473
+
474
+ # Now compute the errors of projections.
475
+ seeked_splats_proj1 = (X.unsqueeze(1) @ P1).squeeze(1)
476
+ seeked_splats_proj1 = seeked_splats_proj1 / (EPS + seeked_splats_proj1[:, [3]])
477
+ seeked_splats_proj2 = (X.unsqueeze(1) @ P2).squeeze(1)
478
+ seeked_splats_proj2 = seeked_splats_proj2 / (EPS + seeked_splats_proj2[:, [3]])
479
+ proj1_target = torch.concat([k1_x, k1_y], dim=1)
480
+ proj2_target = torch.concat([k2_x, k2_y], dim=1)
481
+ errors_proj1 = torch.abs(seeked_splats_proj1[:, :2] - proj1_target).sum(1).detach().cpu().numpy()
482
+ errors_proj2 = torch.abs(seeked_splats_proj2[:, :2] - proj2_target).sum(1).detach().cpu().numpy()
483
+
484
+ return X, errors_proj1, errors_proj2
485
+
486
+
487
+
488
+ def select_best_keypoints(
489
+ NNs_triangulated_points, NNs_errors_proj1, NNs_errors_proj2, device="cuda"):
490
+ """
491
+ From the candidate 3D points triangulated with each nearest neighbor, selects for every keypoint the one with the lowest reprojection error.
492
+
493
+ Args:
494
+ NNs_triangulated_points: torch tensor with keypoints coordinates (num_nns, num_points, dim). dim can be arbitrary,
495
+ usually 3 or 4(for homogeneous representation).
496
+ NNs_errors_proj1: numpy array with projection error of the estimated keypoint on the reference frame (num_nns, num_points).
497
+ NNs_errors_proj2: numpy array with projection error of the estimated keypoint on the neighbor frame (num_nns, num_points).
498
+ Returns:
499
+ selected_keypoints: keypoints with the best score.
500
+ """
501
+
502
+ NNs_errors_proj = np.maximum(NNs_errors_proj1, NNs_errors_proj2)
503
+
504
+ # Convert indices to PyTorch tensor
505
+ indices = torch.from_numpy(np.argmin(NNs_errors_proj, axis=0)).long().to(device)
506
+
507
+ # Create index tensor for the second dimension
508
+ n_indices = torch.arange(NNs_triangulated_points.shape[1]).long().to(device)
509
+
510
+ # Use advanced indexing to select elements
511
+ NNs_triangulated_points_selected = NNs_triangulated_points[indices, n_indices, :] # Shape: [N, k]
512
+
513
+ return NNs_triangulated_points_selected, np.min(NNs_errors_proj, axis=0)
514
+
515
+
516
+
517
+ def init_gaussians_with_corr(gaussians, scene, cfg, device, verbose = False, roma_model=None):
518
+ """
519
+ For the given gaussians and scene, instantiates a RoMa model (switch to the indoor variant if necessary) and processes the scene's
520
+ training frames to extract correspondences, which are used to initialize new gaussians.
521
+ Args:
522
+ gaussians: object gaussians of the class GaussianModel that we need to enrich with gaussians.
523
+ scene: object of the Scene class.
524
+ cfg: configuration (the init_wC section).
525
+ Returns:
526
+ gaussians: the GaussianModel object, enriched in place with the new gaussians.
527
+
528
+ """
529
+ if roma_model is None:
530
+ if cfg.roma_model == "indoors":
531
+ roma_model = roma_indoor(device=device)
532
+ else:
533
+ roma_model = roma_outdoor(device=device)
534
+ roma_model.upsample_preds = False
535
+ roma_model.symmetric = False
536
+ M = cfg.matches_per_ref
537
+ upper_thresh = roma_model.sample_thresh
538
+ scaling_factor = cfg.scaling_factor
539
+ expansion_factor = 1
540
+ keypoint_fit_error_tolerance = cfg.proj_err_tolerance
541
+ visualizations = {}
542
+ viewpoint_stack = scene.getTrainCameras().copy()
543
+ NUM_REFERENCE_FRAMES = min(cfg.num_refs, len(viewpoint_stack))
544
+ NUM_NNS_PER_REFERENCE = min(cfg.nns_per_ref, len(viewpoint_stack))
545
+ # Select cameras using K-means
546
+ viewpoint_cam_all = torch.stack([x.world_view_transform.flatten() for x in viewpoint_stack], axis=0)
547
+
548
+ selected_indices = select_cameras_kmeans(cameras=viewpoint_cam_all.detach().cpu().numpy(), K=NUM_REFERENCE_FRAMES)
549
+ selected_indices = sorted(selected_indices)
550
+
551
+
552
+ # Find the k-closest vectors for each vector
553
+ viewpoint_cam_all = torch.stack([x.world_view_transform.flatten() for x in viewpoint_stack], axis=0)
554
+ closest_indices = k_closest_vectors(viewpoint_cam_all, NUM_NNS_PER_REFERENCE)
555
+ if verbose: print("Indices of k-closest vectors for each vector:\n", closest_indices)
556
+
557
+ closest_indices_selected = closest_indices[:, :].detach().cpu().numpy()
558
+
559
+ all_new_xyz = []
560
+ all_new_features_dc = []
561
+ all_new_features_rest = []
562
+ all_new_opacities = []
563
+ all_new_scaling = []
564
+ all_new_rotation = []
565
+
566
+ # Run roma_model.match once to warm up / lazily initialize the model
567
+ with torch.no_grad():
568
+ viewpoint_cam1 = viewpoint_stack[0]
569
+ viewpoint_cam2 = viewpoint_stack[1]
570
+ imA = viewpoint_cam1.original_image.detach().cpu().numpy().transpose(1, 2, 0)
571
+ imB = viewpoint_cam2.original_image.detach().cpu().numpy().transpose(1, 2, 0)
572
+ imA = Image.fromarray(np.clip(imA * 255, 0, 255).astype(np.uint8))
573
+ imB = Image.fromarray(np.clip(imB * 255, 0, 255).astype(np.uint8))
574
+ warp, certainty_warp = roma_model.match(imA, imB, device=device)
575
+ print("Once run full roma_model.match warp.shape:", warp.shape)
576
+ print("Once run full roma_model.match certainty_warp.shape:", certainty_warp.shape)
577
+ del warp, certainty_warp
578
+ torch.cuda.empty_cache()
579
+
580
+ for source_idx in tqdm(sorted(selected_indices)):
581
+ # 1. Compute keypoints and warping for all the neighboring views
582
+ with torch.no_grad():
583
+ # Call the aggregation function to get imA and imB_compound
584
+ certainties_max, warps_max, certainties_max_idcs, imA, imB_compound, certainties_all, warps_all = aggregate_confidences_and_warps(
585
+ viewpoint_stack=viewpoint_stack,
586
+ closest_indices=closest_indices_selected,
587
+ roma_model=roma_model,
588
+ source_idx=source_idx,
589
+ verbose=verbose, output_dict=visualizations
590
+ )
591
+
592
+
593
+ # Triangulate keypoints
594
+ with torch.no_grad():
595
+ matches = warps_max
596
+ certainty = certainties_max
597
+ certainty = certainty.clone()
598
+ certainty[certainty > upper_thresh] = 1
599
+ matches, certainty = (
600
+ matches.reshape(-1, 4),
601
+ certainty.reshape(-1),
602
+ )
603
+
604
+ # Sample matches with probability proportional to their certainty; values above the
605
+ # sampling threshold are treated as equally reliable.
606
+ good_samples = torch.multinomial(certainty,
607
+ num_samples=min(expansion_factor * M, len(certainty)),
608
+ replacement=False)
609
+
610
+
611
+ reference_image_dict = {
612
+ "ref_image": imA,
613
+ "NNs_images": imB_compound,
614
+ "certainties_all": certainties_all,
615
+ "warps_all": warps_all,
616
+ "triangulated_points": [],
617
+ "triangulated_points_errors_proj1": [],
618
+ "triangulated_points_errors_proj2": []
619
+
620
+ }
621
+ with torch.no_grad():
622
+ for NN_idx in tqdm(range(len(warps_all))):
623
+ matches_NN = warps_all[NN_idx].reshape(-1, 4)[good_samples]
624
+
625
+ # Extract keypoints and colors
626
+ kptsA_np, kptsB_np, kptsB_proj_matrices_idcs, kptsA_color, kptsB_color = extract_keypoints_and_colors(
627
+ imA, imB_compound, certainties_max, certainties_max_idcs, matches_NN, roma_model
628
+ )
629
+
630
+ proj_matrices_A = viewpoint_stack[source_idx].full_proj_transform
631
+ proj_matrices_B = viewpoint_stack[closest_indices_selected[source_idx, NN_idx]].full_proj_transform
632
+ triangulated_points, triangulated_points_errors_proj1, triangulated_points_errors_proj2 = triangulate_points(
633
+ P1=torch.stack([proj_matrices_A] * M, axis=0),
634
+ P2=torch.stack([proj_matrices_B] * M, axis=0),
635
+ k1_x=kptsA_np[:M, 0], k1_y=kptsA_np[:M, 1],
636
+ k2_x=kptsB_np[:M, 0], k2_y=kptsB_np[:M, 1])
637
+
638
+ reference_image_dict["triangulated_points"].append(triangulated_points)
639
+ reference_image_dict["triangulated_points_errors_proj1"].append(triangulated_points_errors_proj1)
640
+ reference_image_dict["triangulated_points_errors_proj2"].append(triangulated_points_errors_proj2)
641
+
642
+ with torch.no_grad():
643
+ NNs_triangulated_points_selected, NNs_triangulated_points_selected_proj_errors = select_best_keypoints(
644
+ NNs_triangulated_points=torch.stack(reference_image_dict["triangulated_points"], dim=0),
645
+ NNs_errors_proj1=np.stack(reference_image_dict["triangulated_points_errors_proj1"], axis=0),
646
+ NNs_errors_proj2=np.stack(reference_image_dict["triangulated_points_errors_proj2"], axis=0))
647
+
648
+ # 4. Save as gaussians
649
+ viewpoint_cam1 = viewpoint_stack[source_idx]
650
+ N = len(NNs_triangulated_points_selected)
651
+ with torch.no_grad():
652
+ new_xyz = NNs_triangulated_points_selected[:, :-1]
653
+ all_new_xyz.append(new_xyz) # seeked_splats
654
+ all_new_features_dc.append(RGB2SH(torch.tensor(kptsA_color.astype(np.float32) / 255.)).unsqueeze(1))
655
+ all_new_features_rest.append(torch.stack([gaussians._features_rest[-1].clone().detach() * 0.] * N, dim=0))
656
+ # new version that sets points with large error invisible
657
+ # TODO: remove those points instead. However, keeping them does not noticeably affect performance.
658
+ mask_bad_points = torch.tensor(
659
+ NNs_triangulated_points_selected_proj_errors > keypoint_fit_error_tolerance,
660
+ dtype=torch.float32).unsqueeze(1).to(device)
661
+ all_new_opacities.append(torch.stack([gaussians._opacity[-1].clone().detach()] * N, dim=0) * 0. - mask_bad_points * (1e1))
662
+
663
+ dist_points_to_cam1 = torch.linalg.norm(viewpoint_cam1.camera_center.clone().detach() - new_xyz,
664
+ dim=1, ord=2)
665
+ #all_new_scaling.append(torch.log(((dist_points_to_cam1) / 1. * scaling_factor).unsqueeze(1).repeat(1, 3)))
666
+ all_new_scaling.append(gaussians.scaling_inverse_activation((dist_points_to_cam1 * scaling_factor).unsqueeze(1).repeat(1, 3)))
667
+ all_new_rotation.append(torch.stack([gaussians._rotation[-1].clone().detach()] * N, dim=0))
668
+
669
+ all_new_xyz = torch.cat(all_new_xyz, dim=0)
670
+ all_new_features_dc = torch.cat(all_new_features_dc, dim=0)
671
+ new_tmp_radii = torch.zeros(all_new_xyz.shape[0])
672
+ prune_mask = torch.ones(all_new_xyz.shape[0], dtype=torch.bool)
673
+
674
+ gaussians.densification_postfix(all_new_xyz[prune_mask].to(device),
675
+ all_new_features_dc[prune_mask].to(device),
676
+ torch.cat(all_new_features_rest, dim=0)[prune_mask].to(device),
677
+ torch.cat(all_new_opacities, dim=0)[prune_mask].to(device),
678
+ torch.cat(all_new_scaling, dim=0)[prune_mask].to(device),
679
+ torch.cat(all_new_rotation, dim=0)[prune_mask].to(device),
680
+ new_tmp_radii[prune_mask].to(device))
681
+
682
+ return viewpoint_stack, closest_indices_selected, visualizations
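The opacity handling near the end of the loop above pushes the raw opacity logit of badly fitting points down by 10, which makes them effectively transparent after the sigmoid activation. In isolation (hypothetical errors and tolerance standing in for cfg.proj_err_tolerance):

import torch

errors = torch.tensor([0.001, 0.2, 0.005, 0.05, 0.0, 0.3])   # per-point reprojection errors
tolerance = 0.01                                             # keypoint_fit_error_tolerance stand-in

mask_bad = (errors > tolerance).float().unsqueeze(1)         # 1.0 where the fit is poor
opacity_logit = torch.zeros(6, 1) - mask_bad * 1e1           # good points: 0, bad points: -10
print(torch.sigmoid(opacity_logit).squeeze(1))               # ~0.5 for good points, ~5e-5 for bad ones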
source/corr_init_new.py ADDED
@@ -0,0 +1,904 @@
1
+ import sys
2
+ sys.path.append('../')
3
+ sys.path.append("../submodules")
4
+ sys.path.append('../submodules/RoMa')
5
+
6
+ from matplotlib import pyplot as plt
7
+ from PIL import Image
8
+ import torch
9
+ import numpy as np
10
+
11
+ #from tqdm import tqdm_notebook as tqdm
12
+ from tqdm import tqdm
13
+ from scipy.cluster.vq import kmeans, vq
14
+ from scipy.spatial.distance import cdist
15
+
16
+ import torch.nn.functional as F
17
+ from romatch import roma_outdoor, roma_indoor
18
+ from utils.sh_utils import RGB2SH
19
+ from romatch.utils import get_tuple_transform_ops
20
+
21
+
22
+ def pairwise_distances(matrix):
23
+ """
24
+ Computes the pairwise Euclidean distances between all vectors in the input matrix.
25
+
26
+ Args:
27
+ matrix (torch.Tensor): Input matrix of shape [N, D], where N is the number of vectors and D is the dimensionality.
28
+
29
+ Returns:
30
+ torch.Tensor: Pairwise distance matrix of shape [N, N].
31
+ """
32
+ # Compute squared pairwise distances
33
+ squared_diff = torch.cdist(matrix, matrix, p=2)
34
+ return squared_diff
35
+
36
+
37
+ def k_closest_vectors(matrix, k):
38
+ """
39
+ Finds the k-closest vectors for each vector in the input matrix based on Euclidean distance.
40
+
41
+ Args:
42
+ matrix (torch.Tensor): Input matrix of shape [N, D], where N is the number of vectors and D is the dimensionality.
43
+ k (int): Number of closest vectors to return for each vector.
44
+
45
+ Returns:
46
+ torch.Tensor: Indices of the k-closest vectors for each vector, excluding the vector itself.
47
+ """
48
+ # Compute pairwise distances
49
+ distances = pairwise_distances(matrix)
50
+
51
+ # For each vector, sort distances and get the indices of the k-closest vectors (excluding itself)
52
+ # Set diagonal distances to infinity to exclude the vector itself from the nearest neighbors
53
+ distances.fill_diagonal_(float('inf'))
54
+
55
+ # Get the indices of the k smallest distances (k-closest vectors)
56
+ _, indices = torch.topk(distances, k, largest=False, dim=1)
57
+
58
+ return indices
59
+
60
+
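A compact illustration of pairwise_distances and k_closest_vectors on random flattened camera matrices (the neighbor structure is meaningless here; the point is the tensor shapes and the self-exclusion via the diagonal):

import torch

cams = torch.randn(10, 16)                               # 10 flattened 4x4 camera matrices
d = torch.cdist(cams, cams, p=2)                         # (10, 10) Euclidean distances
d.fill_diagonal_(float('inf'))                           # a camera is never its own neighbor
_, nn_idx = torch.topk(d, k=3, largest=False, dim=1)     # 3 nearest cameras per camera
print(nn_idx.shape)                                      # torch.Size([10, 3])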
61
+ def select_cameras_kmeans(cameras, K):
62
+ """
63
+ Selects K cameras from a set using K-means clustering.
64
+
65
+ Args:
66
+ cameras: NumPy array of shape (N, 16), representing N cameras with their 4x4 homogeneous matrices flattened.
67
+ K: Number of clusters (cameras to select).
68
+
69
+ Returns:
70
+ selected_indices: List of indices of the cameras closest to the cluster centers.
71
+ """
72
+ # Ensure input is a NumPy array
73
+ if not isinstance(cameras, np.ndarray):
74
+ cameras = np.asarray(cameras)
75
+
76
+ if cameras.shape[1] != 16:
77
+ raise ValueError("Each camera must have 16 values corresponding to a flattened 4x4 matrix.")
78
+
79
+ # Perform K-means clustering
80
+ cluster_centers, _ = kmeans(cameras, K)
81
+
82
+ # Assign each camera to a cluster and find distances to cluster centers
83
+ cluster_assignments, _ = vq(cameras, cluster_centers)
84
+
85
+ # Find the camera nearest to each cluster center
86
+ selected_indices = []
87
+ for k in range(K):
88
+ cluster_members = cameras[cluster_assignments == k]
89
+ distances = cdist([cluster_centers[k]], cluster_members)[0]
90
+ nearest_camera_idx = np.where(cluster_assignments == k)[0][np.argmin(distances)]
91
+ selected_indices.append(nearest_camera_idx)
92
+
93
+ return selected_indices
94
+
95
+
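select_cameras_kmeans builds on SciPy's classic kmeans/vq pair; the core calls look like this on random data (shapes only, the clustering itself is meaningless):

import numpy as np
from scipy.cluster.vq import kmeans, vq
from scipy.spatial.distance import cdist

cams = np.random.rand(50, 16)                            # 50 flattened 4x4 camera matrices
centers, _ = kmeans(cams, 5)                             # cluster centroids, typically (5, 16)
assignments, _ = vq(cams, centers)                       # nearest centroid per camera, (50,)
d0 = cdist([centers[0]], cams[assignments == 0])[0]      # distances of cluster-0 members to its centroid
print(centers.shape, assignments.shape, d0.shape)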
96
+ def compute_warp_and_confidence(viewpoint_cam1, viewpoint_cam2, roma_model, device="cuda", verbose=False, output_dict={}):
97
+ """
98
+ Computes the warp and confidence between two viewpoint cameras using the roma_model.
99
+
100
+ Args:
101
+ viewpoint_cam1: Source viewpoint camera.
102
+ viewpoint_cam2: Target viewpoint camera.
103
+ roma_model: Pre-trained Roma model for correspondence matching.
104
+ device: Device to run the computation on.
105
+ verbose: If True, displays the images.
106
+
107
+ Returns:
108
+ certainty: Confidence tensor.
109
+ warp: Warp tensor.
110
+ imB: Processed image B as numpy array.
111
+ """
112
+ # Prepare images
113
+ imA = viewpoint_cam1.original_image.detach().cpu().numpy().transpose(1, 2, 0)
114
+ imB = viewpoint_cam2.original_image.detach().cpu().numpy().transpose(1, 2, 0)
115
+ imA = Image.fromarray(np.clip(imA * 255, 0, 255).astype(np.uint8))
116
+ imB = Image.fromarray(np.clip(imB * 255, 0, 255).astype(np.uint8))
117
+
118
+ if verbose:
119
+ fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(16, 8))
120
+ cax1 = ax[0].imshow(imA)
121
+ ax[0].set_title("Image 1")
122
+ cax2 = ax[1].imshow(imB)
123
+ ax[1].set_title("Image 2")
124
+ fig.colorbar(cax1, ax=ax[0])
125
+ fig.colorbar(cax2, ax=ax[1])
126
+
127
+ for axis in ax:
128
+ axis.axis('off')
129
+ # Save the figure into the dictionary
130
+ output_dict[f'image_pair'] = fig
131
+
132
+ # Transform images
133
+ ws, hs = roma_model.w_resized, roma_model.h_resized
134
+ test_transform = get_tuple_transform_ops(resize=(hs, ws), normalize=True)
135
+ im_A, im_B = test_transform((imA, imB))
136
+ batch = {"im_A": im_A[None].to(device), "im_B": im_B[None].to(device)}
137
+
138
+ # Forward pass through Roma model
139
+ corresps = roma_model.forward(batch) if not roma_model.symmetric else roma_model.forward_symmetric(batch)
140
+ finest_scale = 1
141
+ hs, ws = roma_model.upsample_res if roma_model.upsample_preds else (hs, ws)
142
+
143
+ # Process certainty and warp
144
+ certainty = corresps[finest_scale]["certainty"]
145
+ im_A_to_im_B = corresps[finest_scale]["flow"]
146
+ if roma_model.attenuate_cert:
147
+ low_res_certainty = F.interpolate(
148
+ corresps[16]["certainty"], size=(hs, ws), align_corners=False, mode="bilinear"
149
+ )
150
+ certainty -= 0.5 * low_res_certainty * (low_res_certainty < 0)
151
+
152
+ # Upsample predictions if needed
153
+ if roma_model.upsample_preds:
154
+ im_A_to_im_B = F.interpolate(
155
+ im_A_to_im_B, size=(hs, ws), align_corners=False, mode="bilinear"
156
+ )
157
+ certainty = F.interpolate(
158
+ certainty, size=(hs, ws), align_corners=False, mode="bilinear"
159
+ )
160
+
161
+ # Convert predictions to final format
162
+ im_A_to_im_B = im_A_to_im_B.permute(0, 2, 3, 1)
163
+ im_A_coords = torch.stack(torch.meshgrid(
164
+ torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device=device),
165
+ torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device=device),
166
+ indexing='ij'
167
+ ), dim=0).permute(1, 2, 0).unsqueeze(0).expand(im_A_to_im_B.size(0), -1, -1, -1)
168
+
169
+ warp = torch.cat((im_A_coords, im_A_to_im_B), dim=-1)
170
+ certainty = certainty.sigmoid()
171
+
172
+ return certainty[0, 0], warp[0], np.array(imB)
173
+
174
+
175
+ def resize_batch(tensors_3d, tensors_4d, target_shape):
176
+ """
177
+ Resizes a batch of tensors with shapes [B, H, W] and [B, H, W, 4] to the target spatial dimensions.
178
+
179
+ Args:
180
+ tensors_3d: Tensor of shape [B, H, W].
181
+ tensors_4d: Tensor of shape [B, H, W, 4].
182
+ target_shape: Tuple (target_H, target_W) specifying the target spatial dimensions.
183
+
184
+ Returns:
185
+ resized_tensors_3d: Tensor of shape [B, target_H, target_W].
186
+ resized_tensors_4d: Tensor of shape [B, target_H, target_W, 4].
187
+ """
188
+ target_H, target_W = target_shape
189
+
190
+ # Resize [B, H, W] tensor
191
+ resized_tensors_3d = F.interpolate(
192
+ tensors_3d.unsqueeze(1), size=(target_H, target_W), mode="bilinear", align_corners=False
193
+ ).squeeze(1)
194
+
195
+ # Resize [B, H, W, 4] tensor
196
+ B, _, _, C = tensors_4d.shape
197
+ resized_tensors_4d = F.interpolate(
198
+ tensors_4d.permute(0, 3, 1, 2), size=(target_H, target_W), mode="bilinear", align_corners=False
199
+ ).permute(0, 2, 3, 1)
200
+
201
+ return resized_tensors_3d, resized_tensors_4d
202
+
203
+
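The permute/interpolate choreography in resize_batch is easy to get wrong, so here is a tiny shape check with dummy tensors (arbitrary sizes) mirroring what the function does:

import torch
import torch.nn.functional as F

conf = torch.rand(4, 96, 128)                            # [B, H, W] certainties
warp = torch.rand(4, 96, 128, 4)                         # [B, H, W, 4] warps

conf_r = F.interpolate(conf.unsqueeze(1), size=(192, 256),
                       mode="bilinear", align_corners=False).squeeze(1)
warp_r = F.interpolate(warp.permute(0, 3, 1, 2), size=(192, 256),
                       mode="bilinear", align_corners=False).permute(0, 2, 3, 1)
print(conf_r.shape, warp_r.shape)                        # torch.Size([4, 192, 256]) torch.Size([4, 192, 256, 4])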
204
+ def aggregate_confidences_and_warps(viewpoint_stack, closest_indices, roma_model, source_idx, verbose=False, output_dict={}):
205
+ """
206
+ Aggregates confidences and warps by iterating over the nearest neighbors of the source viewpoint.
207
+
208
+ Args:
209
+ viewpoint_stack: Stack of viewpoint cameras.
210
+ closest_indices: Indices of the nearest neighbors for each viewpoint.
211
+ roma_model: Pre-trained Roma model.
212
+ source_idx: Index of the source viewpoint.
213
+ verbose: If True, displays intermediate results.
214
+
215
+ Returns:
216
+ certainties_max: Aggregated maximum confidences.
217
+ warps_max: Aggregated warps corresponding to maximum confidences.
218
+ certainties_max_idcs: Pixel-wise index of the image from which we taken the best matching.
219
+ imB_compound: List of the neighboring images.
220
+ """
221
+ certainties_all, warps_all, imB_compound = [], [], []
222
+
223
+ for nn in tqdm(closest_indices[source_idx]):
224
+
225
+ viewpoint_cam1 = viewpoint_stack[source_idx]
226
+ viewpoint_cam2 = viewpoint_stack[nn]
227
+
228
+ certainty, warp, imB = compute_warp_and_confidence(viewpoint_cam1, viewpoint_cam2, roma_model, verbose=verbose, output_dict=output_dict)
229
+ certainties_all.append(certainty)
230
+ warps_all.append(warp)
231
+ imB_compound.append(imB)
232
+
233
+ certainties_all = torch.stack(certainties_all, dim=0)
234
+ target_shape = imB_compound[0].shape[:2]
235
+ if verbose:
236
+ print("certainties_all.shape:", certainties_all.shape)
237
+ print("torch.stack(warps_all, dim=0).shape:", torch.stack(warps_all, dim=0).shape)
238
+ print("target_shape:", target_shape)
239
+
240
+ certainties_all_resized, warps_all_resized = resize_batch(certainties_all,
241
+ torch.stack(warps_all, dim=0),
242
+ target_shape
243
+ )
244
+
245
+ if verbose:
246
+ print("warps_all_resized.shape:", warps_all_resized.shape)
247
+ for n, cert in enumerate(certainties_all):
248
+ fig, ax = plt.subplots()
249
+ cax = ax.imshow(cert.cpu().numpy(), cmap='viridis')
250
+ fig.colorbar(cax, ax=ax)
251
+ ax.set_title("Pixel-wise Confidence")
252
+ output_dict[f'certainty_{n}'] = fig
253
+
254
+ for n, warp in enumerate(warps_all):
255
+ fig, ax = plt.subplots()
256
+ cax = ax.imshow(warp.cpu().numpy()[:, :, :3], cmap='viridis')
257
+ fig.colorbar(cax, ax=ax)
258
+ ax.set_title("Pixel-wise warp")
259
+ output_dict[f'warp_resized_{n}'] = fig
260
+
261
+ for n, cert in enumerate(certainties_all_resized):
262
+ fig, ax = plt.subplots()
263
+ cax = ax.imshow(cert.cpu().numpy(), cmap='viridis')
264
+ fig.colorbar(cax, ax=ax)
265
+ ax.set_title("Pixel-wise Confidence resized")
266
+ output_dict[f'certainty_resized_{n}'] = fig
267
+
268
+ for n, warp in enumerate(warps_all_resized):
269
+ fig, ax = plt.subplots()
270
+ cax = ax.imshow(warp.cpu().numpy()[:, :, :3], cmap='viridis')
271
+ fig.colorbar(cax, ax=ax)
272
+ ax.set_title("Pixel-wise warp resized")
273
+ output_dict[f'warp_resized_{n}'] = fig
274
+
275
+ certainties_max, certainties_max_idcs = torch.max(certainties_all_resized, dim=0)
276
+ H, W = certainties_max.shape
277
+
278
+ warps_max = warps_all_resized[certainties_max_idcs, torch.arange(H).unsqueeze(1), torch.arange(W)]
279
+
280
+
281
+ return certainties_max, warps_max, certainties_max_idcs, imB_compound, certainties_all_resized, warps_all_resized
282
+
283
+
284
+
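The gather at the end of aggregate_confidences_and_warps takes, per pixel, the warp of the most confident neighbor. A small sanity check of that advanced-indexing pattern with dummy tensors:

import torch

V, H, W = 3, 4, 5                                        # views, height, width
conf = torch.rand(V, H, W)
warps = torch.rand(V, H, W, 4)

best_conf, best_idx = torch.max(conf, dim=0)             # (H, W): winning view per pixel
best_warp = warps[best_idx,
                  torch.arange(H).unsqueeze(1),
                  torch.arange(W)]                       # (H, W, 4)

ref = torch.stack([torch.stack([warps[best_idx[i, j], i, j] for j in range(W)])
                   for i in range(H)])                   # explicit loop for comparison
print(torch.equal(best_warp, ref))                       # True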
285
+ def extract_keypoints_and_colors(imA, imB_compound, certainties_max, certainties_max_idcs, matches, roma_model,
286
+ verbose=False, output_dict={}):
287
+ """
288
+ Extracts keypoints and corresponding colors from the source image (imA) and multiple target images (imB_compound).
289
+
290
+ Args:
291
+ imA: Source image as a NumPy array (H_A, W_A, C).
292
+ imB_compound: List of target images as NumPy arrays [(H_B, W_B, C), ...].
293
+ certainties_max: Tensor of pixel-wise maximum confidences.
294
+ certainties_max_idcs: Tensor of pixel-wise indices for the best matches.
295
+ matches: Matches in normalized coordinates.
296
+ roma_model: Roma model instance for keypoint operations.
297
+ verbose: whether to show intermediate outputs and visualize results
298
+
299
+ Returns:
300
+ kptsA_np: Keypoints in imA in normalized coordinates.
301
+ kptsB_np: Keypoints in imB in normalized coordinates.
302
+ kptsA_color: Colors of keypoints in imA.
303
+ kptsB_color: Colors of keypoints in imB based on certainties_max_idcs.
304
+ """
305
+ H_A, W_A, _ = imA.shape
306
+ H, W = certainties_max.shape
307
+
308
+ # Convert matches to pixel coordinates
309
+ kptsA, kptsB = roma_model.to_pixel_coordinates(
310
+ matches, W_A, H_A, H, W # W, H
311
+ )
312
+
313
+ kptsA_np = kptsA.detach().cpu().numpy()
314
+ kptsB_np = kptsB.detach().cpu().numpy()
315
+ kptsA_np = kptsA_np[:, [1, 0]]
316
+
317
+ if verbose:
318
+ fig, ax = plt.subplots(figsize=(12, 6))
319
+ cax = ax.imshow(imA)
320
+ ax.set_title("Reference image, imA")
321
+ output_dict[f'reference_image'] = fig
322
+
323
+ fig, ax = plt.subplots(figsize=(12, 6))
324
+ cax = ax.imshow(imB_compound[0])
325
+ ax.set_title("Image to compare to image, imB_compound")
326
+ output_dict[f'imB_compound'] = fig
327
+
328
+ fig, ax = plt.subplots(figsize=(12, 6))
329
+ cax = ax.imshow(np.flipud(imA))
330
+ cax = ax.scatter(kptsA_np[:, 0], H_A - kptsA_np[:, 1], s=.03)
331
+ ax.set_title("Keypoints in imA")
332
+ ax.set_xlim(0, W_A)
333
+ ax.set_ylim(0, H_A)
334
+ output_dict[f'kptsA'] = fig
335
+
336
+ fig, ax = plt.subplots(figsize=(12, 6))
337
+ cax = ax.imshow(np.flipud(imB_compound[0]))
338
+ cax = ax.scatter(kptsB_np[:, 0], H_A - kptsB_np[:, 1], s=.03)
339
+ ax.set_title("Keypoints in imB")
340
+ ax.set_xlim(0, W_A)
341
+ ax.set_ylim(0, H_A)
342
+ output_dict[f'kptsB'] = fig
343
+
344
+ # Keypoints are in (row, column) format, so the first value is always in range [0, height] and the second in range [0, width]
345
+
346
+ kptsA_np = kptsA.detach().cpu().numpy()
347
+ kptsB_np = kptsB.detach().cpu().numpy()
348
+
349
+ # Extract colors for keypoints in imA (vectorized)
350
+ # New experimental version
351
+ kptsA_x = np.round(kptsA_np[:, 0] / 1.).astype(int)
352
+ kptsA_y = np.round(kptsA_np[:, 1] / 1.).astype(int)
353
+ kptsA_color = imA[np.clip(kptsA_x, 0, H - 1), np.clip(kptsA_y, 0, W - 1)]
354
+
355
+ # Create a composite image from imB_compound
356
+ imB_compound_np = np.stack(imB_compound, axis=0)
357
+ H_B, W_B, _ = imB_compound[0].shape
358
+
359
+ # Extract colors for keypoints in imB using certainties_max_idcs
360
+ imB_np = imB_compound_np[
361
+ certainties_max_idcs.detach().cpu().numpy(),
362
+ np.arange(H).reshape(-1, 1),
363
+ np.arange(W)
364
+ ]
365
+
366
+ if verbose:
367
+ print("imB_np.shape:", imB_np.shape)
368
+ print("imB_np:", imB_np)
369
+ fig, ax = plt.subplots(figsize=(12, 6))
370
+ cax = ax.imshow(np.flipud(imB_np))
371
+ cax = ax.scatter(kptsB_np[:, 0], H_A - kptsB_np[:, 1], s=.03)
372
+ ax.set_title("np.flipud(imB_np[0]")
373
+ ax.set_xlim(0, W_A)
374
+ ax.set_ylim(0, H_A)
375
+ output_dict[f'np.flipud(imB_np[0]'] = fig
376
+
377
+
378
+ kptsB_x = np.round(kptsB_np[:, 0]).astype(int)
379
+ kptsB_y = np.round(kptsB_np[:, 1]).astype(int)
380
+
381
+ certainties_max_idcs_np = certainties_max_idcs.detach().cpu().numpy()
382
+ kptsB_proj_matrices_idx = certainties_max_idcs_np[np.clip(kptsA_x, 0, H - 1), np.clip(kptsA_y, 0, W - 1)]
383
+ kptsB_color = imB_compound_np[kptsB_proj_matrices_idx, np.clip(kptsB_y, 0, H - 1), np.clip(kptsB_x, 0, W - 1)]
384
+
385
+ # Normalize keypoints in both images
386
+ kptsA_np[:, 0] = kptsA_np[:, 0] / H * 2.0 - 1.0
387
+ kptsA_np[:, 1] = kptsA_np[:, 1] / W * 2.0 - 1.0
388
+ kptsB_np[:, 0] = kptsB_np[:, 0] / W_B * 2.0 - 1.0
389
+ kptsB_np[:, 1] = kptsB_np[:, 1] / H_B * 2.0 - 1.0
390
+
391
+ return kptsA_np[:, [1, 0]], kptsB_np, kptsB_proj_matrices_idx, kptsA_color, kptsB_color
392
+
393
+ def prepare_tensor(input_array, device):
394
+ """
395
+ Converts an input array to a torch tensor, clones it, and detaches it for safe computation.
396
+ Args:
397
+ input_array (array-like): The input array to convert.
398
+ device (str or torch.device): The device to move the tensor to.
399
+ Returns:
400
+ torch.Tensor: A detached tensor clone of the input array on the specified device.
401
+ """
402
+ if not isinstance(input_array, torch.Tensor):
403
+ return torch.tensor(input_array, dtype=torch.float32).to(device).clone().detach()
404
+ return input_array.clone().detach().to(device).to(torch.float32)
405
+
406
+ def triangulate_points(P1, P2, k1_x, k1_y, k2_x, k2_y, device="cuda"):
407
+ """
408
+ Solves for a batch of 3D points given batches of projection matrices and corresponding image points.
409
+
410
+ Parameters:
411
+ - P1, P2: Tensors of projection matrices of size (batch_size, 4, 4) or (4, 4)
412
+ - k1_x, k1_y: Tensors of shape (batch_size,)
413
+ - k2_x, k2_y: Tensors of shape (batch_size,)
414
+
415
+ Returns:
416
+ - X: A tensor containing the 3D homogeneous coordinates, shape (batch_size, 4)
417
+ """
418
+ EPS = 1e-4
419
+ # Ensure inputs are tensors
420
+
421
+ P1 = prepare_tensor(P1, device)
422
+ P2 = prepare_tensor(P2, device)
423
+ k1_x = prepare_tensor(k1_x, device)
424
+ k1_y = prepare_tensor(k1_y, device)
425
+ k2_x = prepare_tensor(k2_x, device)
426
+ k2_y = prepare_tensor(k2_y, device)
427
+ batch_size = k1_x.shape[0]
428
+
429
+ # Expand P1 and P2 if they are not batched
430
+ if P1.ndim == 2:
431
+ P1 = P1.unsqueeze(0).expand(batch_size, -1, -1)
432
+ if P2.ndim == 2:
433
+ P2 = P2.unsqueeze(0).expand(batch_size, -1, -1)
434
+
435
+ # Extract columns from P1 and P2
436
+ P1_0 = P1[:, :, 0] # Shape: (batch_size, 4)
437
+ P1_1 = P1[:, :, 1]
438
+ P1_2 = P1[:, :, 2]
439
+
440
+ P2_0 = P2[:, :, 0]
441
+ P2_1 = P2[:, :, 1]
442
+ P2_2 = P2[:, :, 2]
443
+
444
+ # Reshape kx and ky to (batch_size, 1)
445
+ k1_x = k1_x.view(-1, 1)
446
+ k1_y = k1_y.view(-1, 1)
447
+ k2_x = k2_x.view(-1, 1)
448
+ k2_y = k2_y.view(-1, 1)
449
+
450
+ # Construct the equations for each batch
451
+ # For camera 1
452
+ A1 = P1_0 - k1_x * P1_2 # Shape: (batch_size, 4)
453
+ A2 = P1_1 - k1_y * P1_2
454
+ # For camera 2
455
+ A3 = P2_0 - k2_x * P2_2
456
+ A4 = P2_1 - k2_y * P2_2
457
+
458
+ # Stack the equations
459
+ A = torch.stack([A1, A2, A3, A4], dim=1) # Shape: (batch_size, 4, 4)
460
+
461
+ # Right-hand side (constants)
462
+ b = -A[:, :, 3] # Shape: (batch_size, 4)
463
+ A_reduced = A[:, :, :3] # Coefficients of x, y, z
464
+
465
+ # Solve using torch.linalg.lstsq (supports batching)
466
+ X_xyz = torch.linalg.lstsq(A_reduced, b.unsqueeze(2)).solution.squeeze(2) # Shape: (batch_size, 3)
467
+
468
+ # Append 1 to get homogeneous coordinates
469
+ ones = torch.ones((batch_size, 1), dtype=torch.float32, device=X_xyz.device)
470
+ X = torch.cat([X_xyz, ones], dim=1) # Shape: (batch_size, 4)
471
+
472
+ # Now compute the errors of projections.
473
+ seeked_splats_proj1 = (X.unsqueeze(1) @ P1).squeeze(1)
474
+ seeked_splats_proj1 = seeked_splats_proj1 / (EPS + seeked_splats_proj1[:, [3]])
475
+ seeked_splats_proj2 = (X.unsqueeze(1) @ P2).squeeze(1)
476
+ seeked_splats_proj2 = seeked_splats_proj2 / (EPS + seeked_splats_proj2[:, [3]])
477
+ proj1_target = torch.concat([k1_x, k1_y], dim=1)
478
+ proj2_target = torch.concat([k2_x, k2_y], dim=1)
479
+ errors_proj1 = torch.abs(seeked_splats_proj1[:, :2] - proj1_target).sum(1).detach().cpu().numpy()
480
+ errors_proj2 = torch.abs(seeked_splats_proj2[:, :2] - proj2_target).sum(1).detach().cpu().numpy()
481
+
482
+ return X, errors_proj1, errors_proj2
483
+
484
+
485
+
486
+ def select_best_keypoints(
487
+ NNs_triangulated_points, NNs_errors_proj1, NNs_errors_proj2, device="cuda"):
488
+ """
489
+ From all the points fitted to keypoints and corresponding colors from the source image (imA) and multiple target images (imB_compound).
490
+
491
+ Args:
492
+ NNs_triangulated_points: torch tensor with keypoints coordinates (num_nns, num_points, dim). dim can be arbitrary,
493
+ usually 3 or 4(for homogeneous representation).
494
+ NNs_errors_proj1: numpy array with projection error of the estimated keypoint on the reference frame (num_nns, num_points).
495
+ NNs_errors_proj2: numpy array with projection error of the estimated keypoint on the neighbor frame (num_nns, num_points).
496
+ Returns:
497
+ selected_keypoints: keypoints with the best score.
498
+ """
499
+
500
+ NNs_errors_proj = np.maximum(NNs_errors_proj1, NNs_errors_proj2)
501
+
502
+ # Convert indices to PyTorch tensor
503
+ indices = torch.from_numpy(np.argmin(NNs_errors_proj, axis=0)).long().to(device)
504
+
505
+ # Create index tensor for the second dimension
506
+ n_indices = torch.arange(NNs_triangulated_points.shape[1]).long().to(device)
507
+
508
+ # Use advanced indexing to select elements
509
+ NNs_triangulated_points_selected = NNs_triangulated_points[indices, n_indices, :] # Shape: [N, k]
510
+
511
+ return NNs_triangulated_points_selected, np.min(NNs_errors_proj, axis=0)
512
+
513
+
514
+
515
+ import time
516
+ from collections import defaultdict
517
+ from tqdm import tqdm
518
+
519
+ # def init_gaussians_with_corr_profiled(gaussians, scene, cfg, device, verbose=False, roma_model=None):
520
+ # timings = defaultdict(list) # To accumulate timings
521
+
522
+ # if roma_model is None:
523
+ # if cfg.roma_model == "indoors":
524
+ # roma_model = roma_indoor(device=device)
525
+ # else:
526
+ # roma_model = roma_outdoor(device=device)
527
+ # roma_model.upsample_preds = False
528
+ # roma_model.symmetric = False
529
+
530
+ # M = cfg.matches_per_ref
531
+ # upper_thresh = roma_model.sample_thresh
532
+ # scaling_factor = cfg.scaling_factor
533
+ # expansion_factor = 1
534
+ # keypoint_fit_error_tolerance = cfg.proj_err_tolerance
535
+ # visualizations = {}
536
+ # viewpoint_stack = scene.getTrainCameras().copy()
537
+ # NUM_REFERENCE_FRAMES = min(cfg.num_refs, len(viewpoint_stack))
538
+ # NUM_NNS_PER_REFERENCE = min(cfg.nns_per_ref, len(viewpoint_stack))
539
+
540
+ # viewpoint_cam_all = torch.stack([x.world_view_transform.flatten() for x in viewpoint_stack], axis=0)
541
+
542
+ # selected_indices = select_cameras_kmeans(cameras=viewpoint_cam_all.detach().cpu().numpy(), K=NUM_REFERENCE_FRAMES)
543
+ # selected_indices = sorted(selected_indices)
544
+
545
+ # viewpoint_cam_all = torch.stack([x.world_view_transform.flatten() for x in viewpoint_stack], axis=0)
546
+ # closest_indices = k_closest_vectors(viewpoint_cam_all, NUM_NNS_PER_REFERENCE)
547
+ # closest_indices_selected = closest_indices[:, :].detach().cpu().numpy()
548
+
549
+ # all_new_xyz = []
550
+ # all_new_features_dc = []
551
+ # all_new_features_rest = []
552
+ # all_new_opacities = []
553
+ # all_new_scaling = []
554
+ # all_new_rotation = []
555
+
556
+ # # Dummy first pass to initialize model
557
+ # with torch.no_grad():
558
+ # viewpoint_cam1 = viewpoint_stack[0]
559
+ # viewpoint_cam2 = viewpoint_stack[1]
560
+ # imA = viewpoint_cam1.original_image.detach().cpu().numpy().transpose(1, 2, 0)
561
+ # imB = viewpoint_cam2.original_image.detach().cpu().numpy().transpose(1, 2, 0)
562
+ # imA = Image.fromarray(np.clip(imA * 255, 0, 255).astype(np.uint8))
563
+ # imB = Image.fromarray(np.clip(imB * 255, 0, 255).astype(np.uint8))
564
+ # warp, certainty_warp = roma_model.match(imA, imB, device=device)
565
+ # del warp, certainty_warp
566
+ # torch.cuda.empty_cache()
567
+
568
+ # # Main Loop over source_idx
569
+ # for source_idx in tqdm(sorted(selected_indices), desc="Profiling source frames"):
570
+
571
+ # # =================== Step 1: Aggregate Confidences and Warps ===================
572
+ # start = time.time()
573
+ # viewpoint_cam1 = viewpoint_stack[source_idx]
574
+ # viewpoint_cam2 = viewpoint_stack[closest_indices_selected[source_idx,0]]
575
+ # imA = viewpoint_cam1.original_image.detach().cpu().numpy().transpose(1, 2, 0)
576
+ # imB = viewpoint_cam2.original_image.detach().cpu().numpy().transpose(1, 2, 0)
577
+ # imA = Image.fromarray(np.clip(imA * 255, 0, 255).astype(np.uint8))
578
+ # imB = Image.fromarray(np.clip(imB * 255, 0, 255).astype(np.uint8))
579
+ # warp, certainty_warp = roma_model.match(imA, imB, device=device)
580
+
581
+ # certainties_max, warps_max, certainties_max_idcs, imB_compound, certainties_all, warps_all = aggregate_confidences_and_warps(
582
+ # viewpoint_stack=viewpoint_stack,
583
+ # closest_indices=closest_indices_selected,
584
+ # roma_model=roma_model,
585
+ # source_idx=source_idx,
586
+ # verbose=verbose,
587
+ # output_dict=visualizations
588
+ # )
589
+
590
+ # certainties_max = certainty_warp
591
+ # with torch.no_grad():
592
+ # warps_all = warps.unsqueeze(0)
593
+
594
+ # timings['aggregation_warp_certainty'].append(time.time() - start)
595
+
596
+ # # =================== Step 2: Good Samples Selection ===================
597
+ # start = time.time()
598
+ # certainty = certainties_max.reshape(-1).clone()
599
+ # certainty[certainty > upper_thresh] = 1
600
+ # good_samples = torch.multinomial(certainty, num_samples=min(expansion_factor * M, len(certainty)), replacement=False)
601
+ # timings['good_samples_selection'].append(time.time() - start)
602
+
603
+ # # =================== Step 3: Triangulate Keypoints for Each NN ===================
604
+ # reference_image_dict = {
605
+ # "triangulated_points": [],
606
+ # "triangulated_points_errors_proj1": [],
607
+ # "triangulated_points_errors_proj2": []
608
+ # }
609
+
610
+ # start = time.time()
611
+ # for NN_idx in range(len(warps_all)):
612
+ # matches_NN = warps_all[NN_idx].reshape(-1, 4)[good_samples]
613
+
614
+ # # Extract keypoints and colors
615
+ # kptsA_np, kptsB_np, kptsB_proj_matrices_idcs, kptsA_color, kptsB_color = extract_keypoints_and_colors(
616
+ # imA, imB_compound, certainties_max, certainties_max_idcs, matches_NN, roma_model
617
+ # )
618
+
619
+ # proj_matrices_A = viewpoint_stack[source_idx].full_proj_transform
620
+ # proj_matrices_B = viewpoint_stack[closest_indices_selected[source_idx, NN_idx]].full_proj_transform
621
+ # triangulated_points, triangulated_points_errors_proj1, triangulated_points_errors_proj2 = triangulate_points(
622
+ # P1=torch.stack([proj_matrices_A] * M, axis=0),
623
+ # P2=torch.stack([proj_matrices_B] * M, axis=0),
624
+ # k1_x=kptsA_np[:M, 0], k1_y=kptsA_np[:M, 1],
625
+ # k2_x=kptsB_np[:M, 0], k2_y=kptsB_np[:M, 1])
626
+
627
+ # reference_image_dict["triangulated_points"].append(triangulated_points)
628
+ # reference_image_dict["triangulated_points_errors_proj1"].append(triangulated_points_errors_proj1)
629
+ # reference_image_dict["triangulated_points_errors_proj2"].append(triangulated_points_errors_proj2)
630
+ # timings['triangulation_per_NN'].append(time.time() - start)
631
+
632
+ # # =================== Step 4: Select Best Triangulated Points ===================
633
+ # start = time.time()
634
+ # NNs_triangulated_points_selected, NNs_triangulated_points_selected_proj_errors = select_best_keypoints(
635
+ # NNs_triangulated_points=torch.stack(reference_image_dict["triangulated_points"], dim=0),
636
+ # NNs_errors_proj1=np.stack(reference_image_dict["triangulated_points_errors_proj1"], axis=0),
637
+ # NNs_errors_proj2=np.stack(reference_image_dict["triangulated_points_errors_proj2"], axis=0))
638
+ # timings['select_best_keypoints'].append(time.time() - start)
639
+
640
+ # # =================== Step 5: Create New Gaussians ===================
641
+ # start = time.time()
642
+ # viewpoint_cam1 = viewpoint_stack[source_idx]
643
+ # N = len(NNs_triangulated_points_selected)
644
+ # new_xyz = NNs_triangulated_points_selected[:, :-1]
645
+ # all_new_xyz.append(new_xyz)
646
+ # all_new_features_dc.append(RGB2SH(torch.tensor(kptsA_color.astype(np.float32) / 255.)).unsqueeze(1))
647
+ # all_new_features_rest.append(torch.stack([gaussians._features_rest[-1].clone().detach() * 0.] * N, dim=0))
648
+
649
+ # mask_bad_points = torch.tensor(
650
+ # NNs_triangulated_points_selected_proj_errors > keypoint_fit_error_tolerance,
651
+ # dtype=torch.float32).unsqueeze(1).to(device)
652
+
653
+ # all_new_opacities.append(torch.stack([gaussians._opacity[-1].clone().detach()] * N, dim=0) * 0. - mask_bad_points * (1e1))
654
+
655
+ # dist_points_to_cam1 = torch.linalg.norm(viewpoint_cam1.camera_center.clone().detach() - new_xyz, dim=1, ord=2)
656
+ # all_new_scaling.append(gaussians.scaling_inverse_activation((dist_points_to_cam1 * scaling_factor).unsqueeze(1).repeat(1, 3)))
657
+ # all_new_rotation.append(torch.stack([gaussians._rotation[-1].clone().detach()] * N, dim=0))
658
+ # timings['save_gaussians'].append(time.time() - start)
659
+
660
+ # # =================== Final Densification Postfix ===================
661
+ # start = time.time()
662
+ # all_new_xyz = torch.cat(all_new_xyz, dim=0)
663
+ # all_new_features_dc = torch.cat(all_new_features_dc, dim=0)
664
+ # new_tmp_radii = torch.zeros(all_new_xyz.shape[0])
665
+ # prune_mask = torch.ones(all_new_xyz.shape[0], dtype=torch.bool)
666
+
667
+ # gaussians.densification_postfix(
668
+ # all_new_xyz[prune_mask].to(device),
669
+ # all_new_features_dc[prune_mask].to(device),
670
+ # torch.cat(all_new_features_rest, dim=0)[prune_mask].to(device),
671
+ # torch.cat(all_new_opacities, dim=0)[prune_mask].to(device),
672
+ # torch.cat(all_new_scaling, dim=0)[prune_mask].to(device),
673
+ # torch.cat(all_new_rotation, dim=0)[prune_mask].to(device),
674
+ # new_tmp_radii[prune_mask].to(device)
675
+ # )
676
+ # timings['final_densification_postfix'].append(time.time() - start)
677
+
678
+ # # =================== Print Profiling Results ===================
679
+ # print("\n=== Profiling Summary (average per frame) ===")
680
+ # for key, times in timings.items():
681
+ # print(f"{key:35s}: {sum(times) / len(times):.4f} sec (total {sum(times):.2f} sec)")
682
+
683
+ # return viewpoint_stack, closest_indices_selected, visualizations
684
+
685
+
686
+
687
+ def extract_keypoints_and_colors_single(imA, imB, matches, roma_model, verbose=False, output_dict={}):
688
+ """
689
+ Extracts keypoints and corresponding colors from a source image (imA) and a single target image (imB).
690
+
691
+ Args:
692
+ imA: Source image as a NumPy array (H_A, W_A, C).
693
+ imB: Target image as a NumPy array (H_B, W_B, C).
694
+ matches: Matches in normalized coordinates (torch.Tensor).
695
+ roma_model: Roma model instance for keypoint operations.
696
+ verbose: If True, outputs intermediate visualizations.
697
+ Returns:
698
+ kptsA_np: Keypoints in imA (normalized).
699
+ kptsB_np: Keypoints in imB (normalized).
700
+ kptsA_color: Colors of keypoints in imA.
701
+ kptsB_color: Colors of keypoints in imB.
702
+ """
703
+ H_A, W_A, _ = imA.shape
704
+ H_B, W_B, _ = imB.shape
705
+
706
+ # Convert matches to pixel coordinates
707
+ # Matches format: (B, 4) = (x1_norm, y1_norm, x2_norm, y2_norm)
708
+ kptsA = matches[:, :2] # [N, 2]
709
+ kptsB = matches[:, 2:] # [N, 2]
710
+
711
+ # Scale normalized coordinates [-1,1] to pixel coordinates
712
+ kptsA_pix = torch.zeros_like(kptsA)
713
+ kptsB_pix = torch.zeros_like(kptsB)
714
+
715
+ # Important! [Normalized to pixel space]
716
+ kptsA_pix[:, 0] = (kptsA[:, 0] + 1) * (W_A - 1) / 2
717
+ kptsA_pix[:, 1] = (kptsA[:, 1] + 1) * (H_A - 1) / 2
718
+
719
+ kptsB_pix[:, 0] = (kptsB[:, 0] + 1) * (W_B - 1) / 2
720
+ kptsB_pix[:, 1] = (kptsB[:, 1] + 1) * (H_B - 1) / 2
721
+
722
+ kptsA_np = kptsA_pix.detach().cpu().numpy()
723
+ kptsB_np = kptsB_pix.detach().cpu().numpy()
724
+
725
+ # Extract colors
726
+ kptsA_x = np.round(kptsA_np[:, 0]).astype(int)
727
+ kptsA_y = np.round(kptsA_np[:, 1]).astype(int)
728
+ kptsB_x = np.round(kptsB_np[:, 0]).astype(int)
729
+ kptsB_y = np.round(kptsB_np[:, 1]).astype(int)
730
+
731
+ kptsA_color = imA[np.clip(kptsA_y, 0, H_A-1), np.clip(kptsA_x, 0, W_A-1)]
732
+ kptsB_color = imB[np.clip(kptsB_y, 0, H_B-1), np.clip(kptsB_x, 0, W_B-1)]
733
+
734
+ # Normalize keypoints into [-1, 1] for downstream triangulation
735
+ kptsA_np_norm = np.zeros_like(kptsA_np)
736
+ kptsB_np_norm = np.zeros_like(kptsB_np)
737
+
738
+ kptsA_np_norm[:, 0] = kptsA_np[:, 0] / (W_A - 1) * 2.0 - 1.0
739
+ kptsA_np_norm[:, 1] = kptsA_np[:, 1] / (H_A - 1) * 2.0 - 1.0
740
+
741
+ kptsB_np_norm[:, 0] = kptsB_np[:, 0] / (W_B - 1) * 2.0 - 1.0
742
+ kptsB_np_norm[:, 1] = kptsB_np[:, 1] / (H_B - 1) * 2.0 - 1.0
743
+
744
+ return kptsA_np_norm, kptsB_np_norm, kptsA_color, kptsB_color
745
+
746
+
747
+
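extract_keypoints_and_colors_single maps matches from the [-1, 1] normalized range to pixel coordinates and back; the round trip is exact, as this small check shows (arbitrary image size):

import numpy as np

H, W = 480, 640
k_norm = np.array([[-1.0, -1.0], [0.0, 0.0], [1.0, 1.0]])   # two corners and the centre, (x, y) in [-1, 1]

k_pix = np.empty_like(k_norm)                            # normalized -> pixel, as in the function above
k_pix[:, 0] = (k_norm[:, 0] + 1) * (W - 1) / 2
k_pix[:, 1] = (k_norm[:, 1] + 1) * (H - 1) / 2

k_back = np.empty_like(k_pix)                            # pixel -> normalized
k_back[:, 0] = k_pix[:, 0] / (W - 1) * 2.0 - 1.0
k_back[:, 1] = k_pix[:, 1] / (H - 1) * 2.0 - 1.0

print(k_pix)                                             # [[  0.    0. ] [319.5 239.5] [639.  479. ]]
print(np.allclose(k_back, k_norm))                       # True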
748
+ def init_gaussians_with_corr_profiled(gaussians, scene, cfg, device, verbose=False, roma_model=None):
749
+ timings = defaultdict(list)
750
+
751
+ if roma_model is None:
752
+ if cfg.roma_model == "indoors":
753
+ roma_model = roma_indoor(device=device)
754
+ else:
755
+ roma_model = roma_outdoor(device=device)
756
+ roma_model.upsample_preds = False
757
+ roma_model.symmetric = False
758
+
759
+ M = cfg.matches_per_ref
760
+ upper_thresh = roma_model.sample_thresh
761
+ scaling_factor = cfg.scaling_factor
762
+ expansion_factor = 1
763
+ keypoint_fit_error_tolerance = cfg.proj_err_tolerance
764
+ visualizations = {}
765
+ viewpoint_stack = scene.getTrainCameras().copy()
766
+ NUM_REFERENCE_FRAMES = min(cfg.num_refs, len(viewpoint_stack))
767
+ NUM_NNS_PER_REFERENCE = 1 # Only ONE neighbor now!
768
+
769
+ viewpoint_cam_all = torch.stack([x.world_view_transform.flatten() for x in viewpoint_stack], axis=0)
770
+
771
+ selected_indices = select_cameras_kmeans(cameras=viewpoint_cam_all.detach().cpu().numpy(), K=NUM_REFERENCE_FRAMES)
772
+ selected_indices = sorted(selected_indices)
773
+
774
+ viewpoint_cam_all = torch.stack([x.world_view_transform.flatten() for x in viewpoint_stack], axis=0)
775
+ closest_indices = k_closest_vectors(viewpoint_cam_all, NUM_NNS_PER_REFERENCE)
776
+ closest_indices_selected = closest_indices[:, :].detach().cpu().numpy()
777
+
778
+ all_new_xyz = []
779
+ all_new_features_dc = []
780
+ all_new_features_rest = []
781
+ all_new_opacities = []
782
+ all_new_scaling = []
783
+ all_new_rotation = []
784
+
785
+ # Dummy first pass to initialize model
786
+ with torch.no_grad():
787
+ viewpoint_cam1 = viewpoint_stack[0]
788
+ viewpoint_cam2 = viewpoint_stack[1]
789
+ imA = viewpoint_cam1.original_image.detach().cpu().numpy().transpose(1, 2, 0)
790
+ imB = viewpoint_cam2.original_image.detach().cpu().numpy().transpose(1, 2, 0)
791
+ imA = Image.fromarray(np.clip(imA * 255, 0, 255).astype(np.uint8))
792
+ imB = Image.fromarray(np.clip(imB * 255, 0, 255).astype(np.uint8))
793
+ warp, certainty_warp = roma_model.match(imA, imB, device=device)
794
+ del warp, certainty_warp
795
+ torch.cuda.empty_cache()
796
+
797
+ # Main Loop over source_idx
798
+ for source_idx in tqdm(sorted(selected_indices), desc="Profiling source frames"):
799
+
800
+ # =================== Step 1: Compute Warp and Certainty ===================
801
+ start = time.time()
802
+ viewpoint_cam1 = viewpoint_stack[source_idx]
803
+ NNs = closest_indices_selected.shape[1]
804
+ viewpoint_cam2 = viewpoint_stack[closest_indices_selected[source_idx, np.random.randint(NNs)]]
805
+ imA = viewpoint_cam1.original_image.detach().cpu().numpy().transpose(1, 2, 0)
806
+ imB = viewpoint_cam2.original_image.detach().cpu().numpy().transpose(1, 2, 0)
807
+ imA = Image.fromarray(np.clip(imA * 255, 0, 255).astype(np.uint8))
808
+ imB = Image.fromarray(np.clip(imB * 255, 0, 255).astype(np.uint8))
809
+ warp, certainty_warp = roma_model.match(imA, imB, device=device)
810
+
811
+ certainties_max = certainty_warp # New manual sampling
812
+ timings['aggregation_warp_certainty'].append(time.time() - start)
813
+
814
+ # =================== Step 2: Good Samples Selection ===================
815
+ start = time.time()
816
+ certainty = certainties_max.reshape(-1).clone()
817
+ certainty[certainty > upper_thresh] = 1
818
+ good_samples = torch.multinomial(certainty, num_samples=min(expansion_factor * M, len(certainty)), replacement=False)
819
+ timings['good_samples_selection'].append(time.time() - start)
820
+
821
+ # =================== Step 3: Triangulate Keypoints ===================
822
+ reference_image_dict = {
823
+ "triangulated_points": [],
824
+ "triangulated_points_errors_proj1": [],
825
+ "triangulated_points_errors_proj2": []
826
+ }
827
+
828
+ start = time.time()
829
+ matches_NN = warp.reshape(-1, 4)[good_samples]
830
+
831
+ # Convert matches to pixel coordinates
832
+ kptsA_np, kptsB_np, kptsA_color, kptsB_color = extract_keypoints_and_colors_single(
833
+ np.array(imA).astype(np.uint8),
834
+ np.array(imB).astype(np.uint8),
835
+ matches_NN,
836
+ roma_model
837
+ )
838
+
839
+ proj_matrices_A = viewpoint_stack[source_idx].full_proj_transform
840
+ proj_matrices_B = viewpoint_stack[closest_indices_selected[source_idx, 0]].full_proj_transform
841
+
842
+ triangulated_points, triangulated_points_errors_proj1, triangulated_points_errors_proj2 = triangulate_points(
843
+ P1=torch.stack([proj_matrices_A] * M, axis=0),
844
+ P2=torch.stack([proj_matrices_B] * M, axis=0),
845
+ k1_x=kptsA_np[:M, 0], k1_y=kptsA_np[:M, 1],
846
+ k2_x=kptsB_np[:M, 0], k2_y=kptsB_np[:M, 1])
847
+
848
+ reference_image_dict["triangulated_points"].append(triangulated_points)
849
+ reference_image_dict["triangulated_points_errors_proj1"].append(triangulated_points_errors_proj1)
850
+ reference_image_dict["triangulated_points_errors_proj2"].append(triangulated_points_errors_proj2)
851
+ timings['triangulation_per_NN'].append(time.time() - start)
852
+
853
+ # =================== Step 4: Select Best Triangulated Points ===================
854
+ start = time.time()
855
+ NNs_triangulated_points_selected, NNs_triangulated_points_selected_proj_errors = select_best_keypoints(
856
+ NNs_triangulated_points=torch.stack(reference_image_dict["triangulated_points"], dim=0),
857
+ NNs_errors_proj1=np.stack(reference_image_dict["triangulated_points_errors_proj1"], axis=0),
858
+ NNs_errors_proj2=np.stack(reference_image_dict["triangulated_points_errors_proj2"], axis=0))
859
+ timings['select_best_keypoints'].append(time.time() - start)
860
+
861
+ # =================== Step 5: Create New Gaussians ===================
862
+ start = time.time()
863
+ viewpoint_cam1 = viewpoint_stack[source_idx]
864
+ N = len(NNs_triangulated_points_selected)
865
+ new_xyz = NNs_triangulated_points_selected[:, :-1]
866
+ all_new_xyz.append(new_xyz)
867
+ all_new_features_dc.append(RGB2SH(torch.tensor(kptsA_color.astype(np.float32) / 255.)).unsqueeze(1))
868
+ all_new_features_rest.append(torch.stack([gaussians._features_rest[-1].clone().detach() * 0.] * N, dim=0))
869
+
870
+ mask_bad_points = torch.tensor(
871
+ NNs_triangulated_points_selected_proj_errors > keypoint_fit_error_tolerance,
872
+ dtype=torch.float32).unsqueeze(1).to(device)
873
+
874
+ all_new_opacities.append(torch.stack([gaussians._opacity[-1].clone().detach()] * N, dim=0) * 0. - mask_bad_points * (1e1))
875
+
876
+ dist_points_to_cam1 = torch.linalg.norm(viewpoint_cam1.camera_center.clone().detach() - new_xyz, dim=1, ord=2)
877
+ all_new_scaling.append(gaussians.scaling_inverse_activation((dist_points_to_cam1 * scaling_factor).unsqueeze(1).repeat(1, 3)))
878
+ all_new_rotation.append(torch.stack([gaussians._rotation[-1].clone().detach()] * N, dim=0))
879
+ timings['save_gaussians'].append(time.time() - start)
880
+
881
+ # =================== Final Densification Postfix ===================
882
+ start = time.time()
883
+ all_new_xyz = torch.cat(all_new_xyz, dim=0)
884
+ all_new_features_dc = torch.cat(all_new_features_dc, dim=0)
885
+ new_tmp_radii = torch.zeros(all_new_xyz.shape[0])
886
+ prune_mask = torch.ones(all_new_xyz.shape[0], dtype=torch.bool)
887
+
888
+ gaussians.densification_postfix(
889
+ all_new_xyz[prune_mask].to(device),
890
+ all_new_features_dc[prune_mask].to(device),
891
+ torch.cat(all_new_features_rest, dim=0)[prune_mask].to(device),
892
+ torch.cat(all_new_opacities, dim=0)[prune_mask].to(device),
893
+ torch.cat(all_new_scaling, dim=0)[prune_mask].to(device),
894
+ torch.cat(all_new_rotation, dim=0)[prune_mask].to(device),
895
+ new_tmp_radii[prune_mask].to(device)
896
+ )
897
+ timings['final_densification_postfix'].append(time.time() - start)
898
+
899
+ # =================== Print Profiling Results ===================
900
+ print("\n=== Profiling Summary (average per frame) ===")
901
+ for key, times in timings.items():
902
+ print(f"{key:35s}: {sum(times) / len(times):.4f} sec (total {sum(times):.2f} sec)")
903
+
904
+ return viewpoint_stack, closest_indices_selected, visualizations
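The profiling idiom used throughout init_gaussians_with_corr_profiled, collecting per-step wall-clock times in a defaultdict and averaging them at the end, in isolation:

import time
from collections import defaultdict

timings = defaultdict(list)
for _ in range(3):                                       # stand-in for the loop over reference frames
    start = time.time()
    sum(i * i for i in range(100_000))                   # stand-in for a pipeline step
    timings['dummy_step'].append(time.time() - start)

print("\n=== Profiling Summary (average per frame) ===")
for key, times in timings.items():
    print(f"{key:35s}: {sum(times) / len(times):.4f} sec (total {sum(times):.2f} sec)")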
source/data_utils.py ADDED
@@ -0,0 +1,28 @@
1
+ def scene_cameras_train_test_split(scene, verbose=False):
2
+ """
3
+ Iterate over the resolutions in the scene. For each resolution, check whether it has test_cameras;
4
+ if it does not, move every 8th camera from the train set to the test set. This follows the
5
+ evaluation protocol suggested by Kerbl et al. in the seminal work on 3DGS. All changes to the input
6
+ object scene are inplace changes.
7
+ :param scene: Scene Class object from the gaussian-splatting.scene module
8
+ :param verbose: Print initial and final stage of the function
9
+ :return: None
10
+
11
+ """
12
+ if verbose: print("Preparing train and test sets split...")
13
+ for resolution in scene.train_cameras.keys():
14
+ if len(scene.test_cameras[resolution]) == 0:
15
+ if verbose:
16
+ print(f"Found no test_cameras for resolution {resolution}. Move every 8th camera out ouf total "+\
17
+ f"{len(scene.train_cameras[resolution])} train cameras to the test set now")
18
+ N = len(scene.train_cameras[resolution])
19
+ scene.test_cameras[resolution] = [scene.train_cameras[resolution][idx] for idx in range(0, N)
20
+ if idx % 8 == 0]
21
+ scene.train_cameras[resolution] = [scene.train_cameras[resolution][idx] for idx in range(0, N)
22
+ if idx % 8 != 0]
23
+ if verbose:
24
+ print(f"Done. Now train and test sets contain each {len(scene.train_cameras[resolution])} and " + \
25
+ f"{len(scene.test_cameras[resolution])} cameras respectively.")
26
+
27
+
28
+ return
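The split rule itself is simple; applied to a plain list of 20 stand-in cameras it keeps every 8th element for testing:

cams = list(range(20))                                   # stand-in for the train cameras of one resolution
test = [cams[idx] for idx in range(len(cams)) if idx % 8 == 0]
train = [cams[idx] for idx in range(len(cams)) if idx % 8 != 0]
print(test)                                              # [0, 8, 16]
print(len(train), len(test))                             # 17 3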
source/losses.py ADDED
@@ -0,0 +1,100 @@
1
+ # Code is copied from the gaussian-splatting/utils/loss_utils.py
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from torch.autograd import Variable
6
+ from math import exp
7
+
8
+ def l1_loss(network_output, gt, mean=True):
9
+ return torch.abs((network_output - gt)).mean() if mean else torch.abs((network_output - gt))
10
+
11
+ def l2_loss(network_output, gt):
12
+ return ((network_output - gt) ** 2).mean()
13
+
14
+ def gaussian(window_size, sigma):
15
+ gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])
16
+ return gauss / gauss.sum()
17
+
18
+ def create_window(window_size, channel):
19
+ _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
20
+ _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
21
+ window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
22
+ return window
23
+
24
+ def ssim(img1, img2, window_size=11, size_average=True, mask = None):
25
+ channel = img1.size(-3)
26
+ window = create_window(window_size, channel)
27
+
28
+ if img1.is_cuda:
29
+ window = window.cuda(img1.get_device())
30
+ window = window.type_as(img1)
31
+
32
+ return _ssim(img1, img2, window, window_size, channel, size_average, mask)
33
+
34
+ def _ssim(img1, img2, window, window_size, channel, size_average=True, mask = None):
35
+ mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
36
+ mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
37
+
38
+ mu1_sq = mu1.pow(2)
39
+ mu2_sq = mu2.pow(2)
40
+ mu1_mu2 = mu1 * mu2
41
+
42
+ sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
43
+ sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
44
+ sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2
45
+
46
+ C1 = 0.01 ** 2
47
+ C2 = 0.03 ** 2
48
+
49
+ ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
50
+
51
+ if mask is not None:
52
+ ssim_map = ssim_map * mask
53
+
54
+ if size_average:
55
+ return ssim_map.mean()
56
+ else:
57
+ return ssim_map.mean(1).mean(1).mean(1)
58
+
59
+
60
+ def mse(img1, img2):
61
+ return (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True)
62
+
63
+ def psnr(img1, img2):
64
+ """
65
+ Computes the Peak Signal-to-Noise Ratio (PSNR) between two single images. NOT BATCHED!
66
+ Args:
67
+ img1 (torch.Tensor): The first image tensor, with pixel values scaled between 0 and 1.
68
+ Shape should be (channels, height, width).
69
+ img2 (torch.Tensor): The second image tensor with the same shape as img1, used for comparison.
70
+
71
+ Returns:
72
+ torch.Tensor: A scalar tensor containing the PSNR value in decibels (dB).
73
+ """
74
+ mse = (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True)
75
+ return 20 * torch.log10(1.0 / torch.sqrt(mse))
76
+
77
+
78
+ def tv_loss(image):
79
+ """
80
+ Computes the total variation (TV) loss for an image of shape [3, H, W].
81
+
82
+ Args:
83
+ image (torch.Tensor): Input image of shape [3, H, W]
84
+
85
+ Returns:
86
+ torch.Tensor: Scalar value representing the total variation loss.
87
+ """
88
+ # Ensure the image has the correct dimensions
89
+ assert image.ndim == 3 and image.shape[0] == 3, "Input must be of shape [3, H, W]"
90
+
91
+ # Compute the difference between adjacent pixels in the x-direction (width)
92
+ diff_x = torch.abs(image[:, :, 1:] - image[:, :, :-1])
93
+
94
+ # Compute the difference between adjacent pixels in the y-direction (height)
95
+ diff_y = torch.abs(image[:, 1:, :] - image[:, :-1, :])
96
+
97
+ # Sum the total variation in both directions
98
+ tv_loss_value = torch.mean(diff_x) + torch.mean(diff_y)
99
+
100
+ return tv_loss_value
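A quick sanity check of the helpers in this file on random images (the exact numbers vary with the noise draw; the import path follows this repository's layout):

import torch
from source.losses import l1_loss, ssim, psnr, tv_loss

img_a = torch.rand(3, 64, 64)
img_b = (img_a + 0.05 * torch.randn_like(img_a)).clamp(0, 1)

print(l1_loss(img_a, img_b))                             # roughly 0.04 for 5% Gaussian noise
print(ssim(img_a.unsqueeze(0), img_b.unsqueeze(0)))      # close to 1 for near-identical images
print(psnr(img_a, img_b))                                # per-channel PSNR, roughly 26-27 dB
print(tv_loss(img_a))                                    # about 0.67 for i.i.d. uniform noise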
source/networks.py ADDED
@@ -0,0 +1,52 @@
1
+ import torch
2
+
3
+ import sys
4
+ sys.path.append('./submodules/gaussian-splatting/')
5
+
6
+ from random import randint
7
+ from scene import Scene, GaussianModel
8
+ from gaussian_renderer import render
9
+ from source.data_utils import scene_cameras_train_test_split
10
+
11
+ class Warper3DGS(torch.nn.Module):
12
+ def __init__(self, sh_degree, opt, pipe, dataset, viewpoint_stack, verbose,
13
+ do_train_test_split=True):
14
+ super(Warper3DGS, self).__init__()
15
+ """
16
+ Init Warper using all the objects necessary for rendering gaussian splats.
17
+ Here we merely link class attributes to the objects instantiated outside the class.
18
+ """
19
+ print("ready!!!7")
20
+ self.gaussians = GaussianModel(sh_degree)
21
+ print("ready!!!8")
22
+ self.gaussians.tmp_radii = torch.zeros((self.gaussians.get_xyz.shape[0]), device="cuda")
23
+ self.render = render
24
+ self.gs_config_opt = opt
25
+ bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0]
26
+ self.bg = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
27
+ self.pipe = pipe
28
+ print("ready!!!")
29
+ self.scene = Scene(dataset, self.gaussians, shuffle=False)
30
+ print("ready2")
31
+ if do_train_test_split:
32
+ scene_cameras_train_test_split(self.scene, verbose=verbose)
33
+
34
+ self.gaussians.training_setup(opt)
35
+ self.viewpoint_stack = viewpoint_stack
36
+ if not self.viewpoint_stack:
37
+ self.viewpoint_stack = self.scene.getTrainCameras().copy()
38
+
39
+ def forward(self, viewpoint_cam=None):
40
+ """
41
+ For a provided camera viewpoint_cam we render gaussians from this viewpoint.
42
+ If no camera provided then we use the self.viewpoint_stack (list of cameras).
43
+ If the latter is empty we reinitialize it using the self.scene object.
44
+ """
45
+ if not viewpoint_cam:
46
+ if not self.viewpoint_stack:
47
+ self.viewpoint_stack = self.scene.getTrainCameras().copy()
48
+ viewpoint_cam = self.viewpoint_stack[randint(0, len(self.viewpoint_stack) - 1)]
49
+
50
+ render_pkg = self.render(viewpoint_cam, self.gaussians, self.pipe, self.bg)
51
+ return render_pkg
52
+
source/timer.py ADDED
@@ -0,0 +1,24 @@
1
+ import time
2
+ class Timer:
3
+ def __init__(self):
4
+ self.start_time = None
5
+ self.elapsed = 0
6
+ self.paused = False
7
+
8
+ def start(self):
9
+ if self.start_time is None:
10
+ self.start_time = time.time()
11
+ elif self.paused:
12
+ self.start_time = time.time() - self.elapsed
13
+ self.paused = False
14
+
15
+ def pause(self):
16
+ if not self.paused:
17
+ self.elapsed = time.time() - self.start_time
18
+ self.paused = True
19
+
20
+ def get_elapsed_time(self):
21
+ if self.paused:
22
+ return self.elapsed
23
+ else:
24
+ return time.time() - self.start_time
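
A minimal usage sketch of the Timer above (standard library plus this module only): time spent while paused is not counted.

import time
from source.timer import Timer

timer = Timer()
timer.start()
time.sleep(0.2)   # counted
timer.pause()
time.sleep(0.2)   # not counted while paused
timer.start()
time.sleep(0.1)   # counted again
print(f"elapsed (~0.3 s): {timer.get_elapsed_time():.2f}")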
source/trainer.py ADDED
@@ -0,0 +1,262 @@
1
+ import torch
2
+ from random import randint
3
+ from tqdm.rich import trange
4
+ from tqdm import tqdm as tqdm
5
+ from source.networks import Warper3DGS
6
+ import wandb
7
+ import sys
8
+
9
+ sys.path.append('./submodules/gaussian-splatting/')
10
+ import lpips
11
+ from source.losses import ssim, l1_loss, psnr
12
+ from rich.console import Console
13
+ from rich.theme import Theme
14
+
15
+ custom_theme = Theme({
16
+ "info": "dim cyan",
17
+ "warning": "magenta",
18
+ "danger": "bold red"
19
+ })
20
+
21
+ #from source.corr_init import init_gaussians_with_corr
22
+ from source.corr_init_new import init_gaussians_with_corr_profiled as init_gaussians_with_corr
23
+ from source.utils_aux import log_samples
24
+
25
+ from source.timer import Timer
26
+
27
+ class EDGSTrainer:
28
+ def __init__(self,
29
+ GS: Warper3DGS,
30
+ training_config,
31
+ dataset_white_background=False,
32
+ device=torch.device('cuda'),
33
+ log_wandb=True,
34
+ ):
35
+ self.GS = GS
36
+ self.scene = GS.scene
37
+ self.viewpoint_stack = GS.viewpoint_stack
38
+ self.gaussians = GS.gaussians
39
+
40
+ self.training_config = training_config
41
+ self.GS_optimizer = GS.gaussians.optimizer
42
+ self.dataset_white_background = dataset_white_background
43
+
44
+ self.training_step = 1
45
+ self.gs_step = 0
46
+ self.CONSOLE = Console(width=120, theme=custom_theme)
47
+ self.saving_iterations = training_config.save_iterations
48
+ self.evaluate_iterations = None
49
+ self.batch_size = training_config.batch_size
50
+ self.ema_loss_for_log = 0.0
51
+
52
+ # Logs in the format {step:{"loss1":loss1_value, "loss2":loss2_value}}
53
+ self.logs_losses = {}
54
+ self.lpips = lpips.LPIPS(net='vgg').to(device)
55
+ self.device = device
56
+ self.timer = Timer()
57
+ self.log_wandb = log_wandb
58
+
59
+ def load_checkpoints(self, load_cfg):
60
+ # Load 3DGS checkpoint
61
+ if load_cfg.gs:
62
+ self.GS.gaussians.restore(
63
+ torch.load(f"{load_cfg.gs}/chkpnt{load_cfg.gs_step}.pth")[0],
64
+ self.training_config)
65
+ self.GS_optimizer = self.GS.gaussians.optimizer
66
+ self.CONSOLE.print(f"3DGS loaded from checkpoint for iteration {load_cfg.gs_step}",
67
+ style="info")
68
+ self.training_step += load_cfg.gs_step
69
+ self.gs_step += load_cfg.gs_step
70
+
71
+ def train(self, train_cfg):
72
+ # 3DGS training
73
+ self.CONSOLE.print("Train 3DGS for {} iterations".format(train_cfg.gs_epochs), style="info")
74
+ with trange(self.training_step, self.training_step + train_cfg.gs_epochs, desc="[green]Train gaussians") as progress_bar:
75
+ for self.training_step in progress_bar:
76
+ radii = self.train_step_gs(max_lr=train_cfg.max_lr, no_densify=train_cfg.no_densify)
77
+ with torch.no_grad():
78
+ if train_cfg.no_densify:
79
+ self.prune(radii)
80
+ else:
81
+ self.densify_and_prune(radii)
82
+ if train_cfg.reduce_opacity:
83
+ # Slightly reduce opacity every few steps:
84
+ if self.gs_step < self.training_config.densify_until_iter and self.gs_step % 10 == 0:
85
+ opacities_new = torch.log(torch.exp(self.GS.gaussians._opacity.data) * 0.99)
86
+ self.GS.gaussians._opacity.data = opacities_new
87
+ self.timer.pause()
88
+ # Progress bar
89
+ if self.training_step % 10 == 0:
90
+ progress_bar.set_postfix({"[red]Loss": f"{self.ema_loss_for_log:.{7}f}"}, refresh=True)
91
+ # Log and save
92
+ if self.training_step in self.saving_iterations:
93
+ self.save_model()
94
+ if self.evaluate_iterations is not None:
95
+ if self.training_step in self.evaluate_iterations:
96
+ self.evaluate()
97
+ else:
98
+ if (self.training_step <= 3000 and self.training_step % 500 == 0) or \
99
+ (self.training_step > 3000 and self.training_step % 1000 == 228) :
100
+ self.evaluate()
101
+
102
+ self.timer.start()
103
+
104
+
105
+ def evaluate(self):
106
+ torch.cuda.empty_cache()
107
+ log_gen_images, log_real_images = [], []
108
+ validation_configs = ({'name': 'test', 'cameras': self.scene.getTestCameras(), 'cam_idx': self.training_config.TEST_CAM_IDX_TO_LOG},
109
+ {'name': 'train',
110
+ 'cameras': [self.scene.getTrainCameras()[idx % len(self.scene.getTrainCameras())] for idx in
111
+ range(0, 150, 5)], 'cam_idx': 10})
112
+ if self.log_wandb:
113
+ wandb.log({f"Number of Gaussians": len(self.GS.gaussians._xyz)}, step=self.training_step)
114
+ for config in validation_configs:
115
+ if config['cameras'] and len(config['cameras']) > 0:
116
+ l1_test = 0.0
117
+ psnr_test = 0.0
118
+ ssim_test = 0.0
119
+ lpips_splat_test = 0.0
120
+ for idx, viewpoint in enumerate(config['cameras']):
121
+ image = torch.clamp(self.GS(viewpoint)["render"], 0.0, 1.0)
122
+ gt_image = torch.clamp(viewpoint.original_image.to(self.device), 0.0, 1.0)
123
+ l1_test += l1_loss(image, gt_image).double()
124
+ psnr_test += psnr(image.unsqueeze(0), gt_image.unsqueeze(0)).double()
125
+ ssim_test += ssim(image, gt_image).double()
126
+ lpips_splat_test += self.lpips(image, gt_image).detach().double()
127
+ if idx in [config['cam_idx']]:
128
+ log_gen_images.append(image)
129
+ log_real_images.append(gt_image)
130
+ psnr_test /= len(config['cameras'])
131
+ l1_test /= len(config['cameras'])
132
+ ssim_test /= len(config['cameras'])
133
+ lpips_splat_test /= len(config['cameras'])
134
+ if self.log_wandb:
135
+ wandb.log({f"{config['name']}/L1": l1_test.item(), f"{config['name']}/PSNR": psnr_test.item(), \
136
+ f"{config['name']}/SSIM": ssim_test.item(), f"{config['name']}/LPIPS_splat": lpips_splat_test.item()}, step = self.training_step)
137
+ self.CONSOLE.print("\n[ITER {}], #{} gaussians, Evaluating {}: L1={:.6f}, PSNR={:.6f}, SSIM={:.6f}, LPIPS_splat={:.6f} ".format(
138
+ self.training_step, len(self.GS.gaussians._xyz), config['name'], l1_test.item(), psnr_test.item(), ssim_test.item(), lpips_splat_test.item()), style="info")
139
+ if self.log_wandb:
140
+ with torch.no_grad():
141
+ log_samples(torch.stack((log_real_images[0],log_gen_images[0])) , [], self.training_step, caption="Real and Generated Samples")
142
+ wandb.log({"time": self.timer.get_elapsed_time()}, step=self.training_step)
143
+ torch.cuda.empty_cache()
144
+
145
+ def train_step_gs(self, max_lr = False, no_densify = False):
146
+ self.gs_step += 1
147
+ if max_lr:
148
+ self.GS.gaussians.update_learning_rate(max(self.gs_step, 8_000))
149
+ else:
150
+ self.GS.gaussians.update_learning_rate(self.gs_step)
151
+ # Every 1000 its we increase the levels of SH up to a maximum degree
152
+ if self.gs_step % 1000 == 0:
153
+ self.GS.gaussians.oneupSHdegree()
154
+
155
+ # Pick a random Camera
156
+ if not self.viewpoint_stack:
157
+ self.viewpoint_stack = self.scene.getTrainCameras().copy()
158
+ viewpoint_cam = self.viewpoint_stack.pop(randint(0, len(self.viewpoint_stack) - 1))
159
+
160
+ render_pkg = self.GS(viewpoint_cam=viewpoint_cam)
161
+ image = render_pkg["render"]
162
+ # Loss
163
+ gt_image = viewpoint_cam.original_image.to(self.device)
164
+ L1_loss = l1_loss(image, gt_image)
165
+
166
+ ssim_loss = (1.0 - ssim(image, gt_image))
167
+ loss = (1.0 - self.training_config.lambda_dssim) * L1_loss + \
168
+ self.training_config.lambda_dssim * ssim_loss
169
+ self.timer.pause()
170
+ self.logs_losses[self.training_step] = {"loss": loss.item(),
171
+ "L1_loss": L1_loss.item(),
172
+ "ssim_loss": ssim_loss.item()}
173
+
174
+ if self.log_wandb:
175
+ for k, v in self.logs_losses[self.training_step].items():
176
+ wandb.log({f"train/{k}": v}, step=self.training_step)
177
+ self.ema_loss_for_log = 0.4 * self.logs_losses[self.training_step]["loss"] + 0.6 * self.ema_loss_for_log
178
+ self.timer.start()
179
+ self.GS_optimizer.zero_grad(set_to_none=True)
180
+ loss.backward()
181
+ with torch.no_grad():
182
+ if self.gs_step < self.training_config.densify_until_iter and not no_densify:
183
+ self.GS.gaussians.max_radii2D[render_pkg["visibility_filter"]] = torch.max(
184
+ self.GS.gaussians.max_radii2D[render_pkg["visibility_filter"]],
185
+ render_pkg["radii"][render_pkg["visibility_filter"]])
186
+ self.GS.gaussians.add_densification_stats(render_pkg["viewspace_points"],
187
+ render_pkg["visibility_filter"])
188
+
189
+ # Optimizer step
190
+ self.GS_optimizer.step()
191
+ self.GS_optimizer.zero_grad(set_to_none=True)
192
+ return render_pkg["radii"]
193
+
194
+ def densify_and_prune(self, radii = None):
195
+ # Densification or pruning
196
+ if self.gs_step < self.training_config.densify_until_iter:
197
+ if (self.gs_step > self.training_config.densify_from_iter) and \
198
+ (self.gs_step % self.training_config.densification_interval == 0):
199
+ size_threshold = 20 if self.gs_step > self.training_config.opacity_reset_interval else None
200
+ self.GS.gaussians.densify_and_prune(self.training_config.densify_grad_threshold,
201
+ 0.005,
202
+ self.GS.scene.cameras_extent,
203
+ size_threshold, radii)
204
+ if self.gs_step % self.training_config.opacity_reset_interval == 0 or (
205
+ self.dataset_white_background and self.gs_step == self.training_config.densify_from_iter):
206
+ self.GS.gaussians.reset_opacity()
207
+
208
+
209
+
210
+ def save_model(self):
211
+ print("\n[ITER {}] Saving Gaussians".format(self.gs_step))
212
+ self.scene.save(self.gs_step)
213
+ print("\n[ITER {}] Saving Checkpoint".format(self.gs_step))
214
+ torch.save((self.GS.gaussians.capture(), self.gs_step),
215
+ self.scene.model_path + "/chkpnt" + str(self.gs_step) + ".pth")
216
+
217
+
218
+ def init_with_corr(self, cfg, verbose=False, roma_model=None):
219
+ """
220
+ Initializes Gaussians from correspondence matches. Also removes the SfM init points unless cfg.add_SfM_init is set.
221
+ Args:
222
+ cfg: the init_wC section of the configuration. See configs/train.yaml.
223
+ verbose: whether to print intermediate results. Useful for debugging.
224
+ roma_model: optionally, a pre-initialized RoMa model can be passed here to avoid
225
+ re-initializing it on every call.
226
+ """
227
+ if not cfg.use:
228
+ return None
229
+ N_splats_at_init = len(self.GS.gaussians._xyz)
230
+ print("N_splats_at_init:", N_splats_at_init)
231
+ camera_set, selected_indices, visualization_dict = init_gaussians_with_corr(
232
+ self.GS.gaussians,
233
+ self.scene,
234
+ cfg,
235
+ self.device,
236
+ verbose=verbose,
237
+ roma_model=roma_model)
238
+
239
+ # Remove SfM points and leave only matchings inits
240
+ if not cfg.add_SfM_init:
241
+ with torch.no_grad():
242
+ N_splats_after_init = len(self.GS.gaussians._xyz)
243
+ print("N_splats_after_init:", N_splats_after_init)
244
+ self.gaussians.tmp_radii = torch.zeros(self.gaussians._xyz.shape[0]).to(self.device)
245
+ mask = torch.concat([torch.ones(N_splats_at_init, dtype=torch.bool),
246
+ torch.zeros(N_splats_after_init-N_splats_at_init, dtype=torch.bool)],
247
+ axis=0)
248
+ self.GS.gaussians.prune_points(mask)
249
+ with torch.no_grad():
250
+ gaussians = self.gaussians
251
+ gaussians._scaling = gaussians.scaling_inverse_activation(gaussians.scaling_activation(gaussians._scaling)*0.5)
252
+ return visualization_dict
253
+
254
+
255
+ def prune(self, radii, min_opacity=0.005):
256
+ self.GS.gaussians.tmp_radii = radii
257
+ if self.gs_step < self.training_config.densify_until_iter:
258
+ prune_mask = (self.GS.gaussians.get_opacity < min_opacity).squeeze()
259
+ self.GS.gaussians.prune_points(prune_mask)
260
+ torch.cuda.empty_cache()
261
+ self.GS.gaussians.tmp_radii = None
262
+
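
The reduce_opacity branch in train() above rescales the raw _opacity parameter inside an exp/log pair. The small check below (torch only; the random tensor stands in for GS.gaussians._opacity.data) makes explicit that each application is a constant downward shift of the raw parameter by log(0.99), i.e. a gentle, repeated nudge toward lower opacity:

import math
import torch

x = torch.randn(5)                        # stand-in for GS.gaussians._opacity.data
updated = torch.log(torch.exp(x) * 0.99)  # the update applied every 10 gs steps
assert torch.allclose(updated, x + math.log(0.99))  # equivalent to subtracting ~0.01 from the raw value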
source/utils_aux.py ADDED
@@ -0,0 +1,92 @@
1
+ # Perlin noise code taken from https://gist.github.com/adefossez/0646dbe9ed4005480a2407c62aac8869
2
+ from types import SimpleNamespace
3
+ import random
4
+ import numpy as np
5
+ import torch
6
+ import torchvision
7
+ import wandb
8
+ import random
9
+ import torchvision.transforms as T
10
+ import torchvision.transforms.functional as F
11
+ import torch
12
+ from PIL import Image
13
+
14
+ def parse_dict_to_namespace(dict_nested):
15
+ """Turns nested dictionary into nested namespaces"""
16
+ if type(dict_nested) != dict and type(dict_nested) != list: return dict_nested
17
+ x = SimpleNamespace()
18
+ for key, val in dict_nested.items():
19
+ if type(val) == dict:
20
+ setattr(x, key, parse_dict_to_namespace(val))
21
+ elif type(val) == list:
22
+ setattr(x, key, [parse_dict_to_namespace(v) for v in val])
23
+ else:
24
+ setattr(x, key, val)
25
+ return x
26
+
27
+ def set_seed(seed=42, cuda=True):
28
+ random.seed(seed)
29
+ np.random.seed(seed)
30
+ torch.manual_seed(seed)
31
+ if cuda:
32
+ torch.cuda.manual_seed_all(seed)
33
+
34
+
35
+
36
+ def log_samples(samples, scores, iteration, caption="Real Samples"):
37
+ # Create a grid of images
38
+ grid = torchvision.utils.make_grid(samples)
39
+
40
+ # Log the images and scores to wandb
41
+ wandb.log({
42
+ f"{caption}_images": [wandb.Image(grid, caption=f"{caption}: {scores}")],
43
+ }, step = iteration)
44
+
45
+
46
+
47
+ def pairwise_distances(matrix):
48
+ """
49
+ Computes the pairwise Euclidean distances between all vectors in the input matrix.
50
+
51
+ Args:
52
+ matrix (torch.Tensor): Input matrix of shape [N, D], where N is the number of vectors and D is the dimensionality.
53
+
54
+ Returns:
55
+ torch.Tensor: Pairwise distance matrix of shape [N, N].
56
+ """
57
+ # Compute pairwise Euclidean distances (torch.cdist with p=2 returns distances, not squared values)
58
+ distances = torch.cdist(matrix, matrix, p=2)
59
+ return distances
60
+
61
+ def k_closest_vectors(matrix, k):
62
+ """
63
+ Finds the k-closest vectors for each vector in the input matrix based on Euclidean distance.
64
+
65
+ Args:
66
+ matrix (torch.Tensor): Input matrix of shape [N, D], where N is the number of vectors and D is the dimensionality.
67
+ k (int): Number of closest vectors to return for each vector.
68
+
69
+ Returns:
70
+ torch.Tensor: Indices of the k-closest vectors for each vector, excluding the vector itself.
71
+ """
72
+ # Compute pairwise distances
73
+ distances = pairwise_distances(matrix)
74
+
75
+ # For each vector, sort distances and get the indices of the k-closest vectors (excluding itself)
76
+ # Set diagonal distances to infinity to exclude the vector itself from the nearest neighbors
77
+ distances.fill_diagonal_(float('inf'))
78
+
79
+ # Get the indices of the k smallest distances (k-closest vectors)
80
+ _, indices = torch.topk(distances, k, largest=False, dim=1)
81
+
82
+ return indices
83
+
84
+ def process_image(image_tensor):
85
+ image_np = image_tensor.detach().cpu().numpy().transpose(1, 2, 0)
86
+ return Image.fromarray(np.clip(image_np * 255, 0, 255).astype(np.uint8))
87
+
88
+
89
+ def normalize_keypoints(kpts_np, width, height):
90
+ kpts_np[:, 0] = kpts_np[:, 0] / width * 2. - 1.
91
+ kpts_np[:, 1] = kpts_np[:, 1] / height * 2. - 1.
92
+ return kpts_np
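
For instance, k_closest_vectors can be checked on random data (torch only; the point count and k are arbitrary):

import torch
from source.utils_aux import set_seed, k_closest_vectors

set_seed(42, cuda=False)
points = torch.rand(10, 2)                     # 10 random 2-D points
neighbours = k_closest_vectors(points, k=3)    # (10, 3) indices, excluding the point itself
print(neighbours[0])                           # the 3 nearest neighbours of point 0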
source/utils_preprocess.py ADDED
@@ -0,0 +1,334 @@
1
+ # This file contains functions for video or image collection preprocessing.
2
+ # For videos we run the preprocessing and select the k sharpest frames.
3
+ # Afterwards the scene is constructed.
4
+ import cv2
5
+ import numpy as np
6
+ from tqdm import tqdm
7
+ import pycolmap
8
+ import os
9
+ import time
10
+ import tempfile
11
+ from moviepy import VideoFileClip
12
+ from matplotlib import pyplot as plt
13
+ from PIL import Image
14
+ import cv2
15
+ from tqdm import tqdm
16
+
17
+ WORKDIR = "../outputs/"
18
+
19
+
20
+ def get_rotation_moviepy(video_path):
21
+ clip = VideoFileClip(video_path)
22
+ rotation = 0
23
+
24
+ try:
25
+ displaymatrix = clip.reader.infos['inputs'][0]['streams'][2]['metadata'].get('displaymatrix', '')
26
+ if 'rotation of' in displaymatrix:
27
+ angle = float(displaymatrix.strip().split('rotation of')[-1].split('degrees')[0])
28
+ rotation = int(angle) % 360
29
+
30
+ except Exception as e:
31
+ print(f"No displaymatrix rotation found: {e}")
32
+
33
+ clip.reader.close()
34
+ #if clip.audio:
35
+ # clip.audio.reader.close_proc()
36
+
37
+ return rotation
38
+
39
+ def resize_max_side(frame, max_size):
40
+ h, w = frame.shape[:2]
41
+ scale = max_size / max(h, w)
42
+ if scale < 1:
43
+ frame = cv2.resize(frame, (int(w * scale), int(h * scale)))
44
+ return frame
45
+
46
+ def read_video_frames(video_input, k=1, max_size=1024):
47
+ """
48
+ Extracts every k-th frame from a video or list of images, resizes to max size, and returns frames as list.
49
+
50
+ Parameters:
51
+ video_input (str, file-like, or list): Path to video file, file-like object, or list of image files.
52
+ k (int): Interval for frame extraction (every k-th frame).
53
+ max_size (int): Maximum size for width or height after resizing.
54
+
55
+ Returns:
56
+ frames (list): List of resized frames (numpy arrays).
57
+ """
58
+ # Handle list of image files (not single video in a list)
59
+ if isinstance(video_input, list):
60
+ # If it's a single video in a list, treat it as video
61
+ if len(video_input) == 1 and video_input[0].name.endswith(('.mp4', '.avi', '.mov')):
62
+ video_input = video_input[0] # unwrap single video file
63
+ else:
64
+ # Treat as list of images
65
+ frames = []
66
+ for img_file in video_input:
67
+ img = Image.open(img_file.name).convert("RGB")
68
+ img.thumbnail((max_size, max_size))
69
+ frames.append(np.array(img)[...,::-1])
70
+ return frames
71
+
72
+ # Handle file-like or path
73
+ if hasattr(video_input, 'name'):
74
+ video_path = video_input.name
75
+ elif isinstance(video_input, (str, os.PathLike)):
76
+ video_path = str(video_input)
77
+ else:
78
+ raise ValueError("Unsupported video input type. Must be a filepath, file-like object, or list of images.")
79
+
80
+
81
+ cap = cv2.VideoCapture(video_path)
82
+ if not cap.isOpened():
83
+ raise ValueError(f"Error: Could not open video {video_path}.")
84
+
85
+ total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
86
+ frame_count = 0
87
+ frames = []
88
+
89
+ with tqdm(total=total_frames // k, desc="Processing Video", unit="frame") as pbar:
90
+ while True:
91
+ ret, frame = cap.read()
92
+ if not ret:
93
+ break
94
+ if frame_count % k == 0:
95
+ frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
96
+ h, w = frame.shape[:2]
97
+ scale = max(h, w) / max_size
98
+ if scale > 1:
99
+ frame = cv2.resize(frame, (int(w / scale), int(h / scale)))
100
+ frames.append(frame[...,[2,1,0]])
101
+ pbar.update(1)
102
+ frame_count += 1
103
+
104
+ cap.release()
105
+ return frames
106
+
107
+ def resize_max_side(frame, max_size):
108
+ """
109
+ Resizes the frame so that its largest side equals max_size, maintaining aspect ratio.
110
+ """
111
+ height, width = frame.shape[:2]
112
+ max_dim = max(height, width)
113
+
114
+ if max_dim <= max_size:
115
+ return frame # No need to resize
116
+
117
+ scale = max_size / max_dim
118
+ new_width = int(width * scale)
119
+ new_height = int(height * scale)
120
+
121
+ resized_frame = cv2.resize(frame, (new_width, new_height), interpolation=cv2.INTER_AREA)
122
+ return resized_frame
123
+
124
+
125
+
126
+ def variance_of_laplacian(image):
127
+ # compute the Laplacian of the image and then return the focus
128
+ # measure, which is simply the variance of the Laplacian
129
+ return cv2.Laplacian(image, cv2.CV_64F).var()
130
+
131
+ def process_all_frames(IMG_FOLDER = '/scratch/datasets/hq_data/night2_all_frames',
132
+ to_visualize=False,
133
+ save_images=True):
134
+ dict_scores = {}
135
+ for idx, img_name in tqdm(enumerate(sorted([x for x in os.listdir(IMG_FOLDER) if '.png' in x]))):
136
+
137
+ img = cv2.imread(os.path.join(IMG_FOLDER, img_name))#[250:, 100:]
138
+ gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
139
+ fm = variance_of_laplacian(gray) + \
140
+ variance_of_laplacian(cv2.resize(gray, (0,0), fx=0.75, fy=0.75)) + \
141
+ variance_of_laplacian(cv2.resize(gray, (0,0), fx=0.5, fy=0.5)) + \
142
+ variance_of_laplacian(cv2.resize(gray, (0,0), fx=0.25, fy=0.25))
143
+ if to_visualize:
144
+ plt.figure()
145
+ plt.title(f"Laplacian score: {fm:.2f}")
146
+ plt.imshow(img[..., [2,1,0]])
147
+ plt.show()
148
+ dict_scores[idx] = {"idx" : idx,
149
+ "img_name" : img_name,
150
+ "score" : fm}
151
+ if save_images:
152
+ dict_scores[idx]["img"] = img
153
+
154
+ return dict_scores
155
+
156
+ def select_optimal_frames(scores, k):
157
+ """
158
+ Selects a minimal subset of frames while ensuring no gaps exceed k.
159
+
160
+ Args:
161
+ scores (list of float): List of scores where index represents frame number.
162
+ k (int): Maximum allowed gap between selected frames.
163
+
164
+ Returns:
165
+ list of int: Indices of selected frames.
166
+ """
167
+ n = len(scores)
168
+ selected = [0, n-1]
169
+ i = 0 # Start at the first frame
170
+
171
+ while i < n:
172
+ # Find the best frame to select within the next k frames
173
+ best_idx = max(range(i, min(i + k + 1, n)), key=lambda x: scores[x], default=None)
174
+
175
+ if best_idx is None:
176
+ break # No more frames left
177
+
178
+ selected.append(best_idx)
179
+ i = best_idx + k + 1 # Move forward, ensuring gaps stay within k
180
+
181
+ return sorted(selected)
182
+
183
+
184
+ def variance_of_laplacian(image):
185
+ """
186
+ Compute the variance of Laplacian as a focus measure.
187
+ """
188
+ return cv2.Laplacian(image, cv2.CV_64F).var()
189
+
190
+ def preprocess_frames(frames, verbose=False):
191
+ """
192
+ Compute sharpness scores for a list of frames using multi-scale Laplacian variance.
193
+
194
+ Args:
195
+ frames (list of np.ndarray): List of frames (BGR images).
196
+ verbose (bool): If True, print scores.
197
+
198
+ Returns:
199
+ list of float: Sharpness scores for each frame.
200
+ """
201
+ scores = []
202
+
203
+ for idx, frame in enumerate(tqdm(frames, desc="Scoring frames")):
204
+ gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
205
+
206
+ fm = (
207
+ variance_of_laplacian(gray) +
208
+ variance_of_laplacian(cv2.resize(gray, (0, 0), fx=0.75, fy=0.75)) +
209
+ variance_of_laplacian(cv2.resize(gray, (0, 0), fx=0.5, fy=0.5)) +
210
+ variance_of_laplacian(cv2.resize(gray, (0, 0), fx=0.25, fy=0.25))
211
+ )
212
+
213
+ if verbose:
214
+ print(f"Frame {idx}: Sharpness Score = {fm:.2f}")
215
+
216
+ scores.append(fm)
217
+
218
+ return scores
219
+
220
+ def select_optimal_frames(scores, k):
221
+ """
222
+ Selects k frames by splitting into k segments and picking the sharpest frame from each.
223
+
224
+ Args:
225
+ scores (list of float): List of sharpness scores.
226
+ k (int): Number of frames to select.
227
+
228
+ Returns:
229
+ list of int: Indices of selected frames.
230
+ """
231
+ n = len(scores)
232
+ selected_indices = []
233
+ segment_size = n // k
234
+
235
+ for i in range(k):
236
+ start = i * segment_size
237
+ end = (i + 1) * segment_size if i < k - 1 else n # Last chunk may be larger
238
+ segment_scores = scores[start:end]
239
+
240
+ if len(segment_scores) == 0:
241
+ continue # Safety check if some segment is empty
242
+
243
+ best_in_segment = start + np.argmax(segment_scores)
244
+ selected_indices.append(best_in_segment)
245
+
246
+ return sorted(selected_indices)
247
+
248
+ def save_frames_to_scene_dir(frames, scene_dir):
249
+ """
250
+ Saves a list of frames into the target scene directory under 'images/' subfolder.
251
+
252
+ Args:
253
+ frames (list of np.ndarray): List of frames (BGR images) to save.
254
+ scene_dir (str): Target path where 'images/' subfolder will be created.
255
+ """
256
+ images_dir = os.path.join(scene_dir, "images")
257
+ os.makedirs(images_dir, exist_ok=True)
258
+
259
+ for idx, frame in enumerate(frames):
260
+ filename = os.path.join(images_dir, f"{idx:08d}.png") # 00000000.png, 00000001.png, etc.
261
+ cv2.imwrite(filename, frame)
262
+
263
+ print(f"Saved {len(frames)} frames to {images_dir}")
264
+
265
+
266
+ def run_colmap_on_scene(scene_dir):
267
+ """
268
+ Runs feature extraction, matching, and mapping on all images inside scene_dir/images using pycolmap.
269
+
270
+ Args:
271
+ scene_dir (str): Path to scene directory containing 'images' folder.
272
+
273
+ TODO: if the function fails to match all the frames, either increase the image size,
274
+ increase the number of features, or remove those frames from the folder scene_dir/images.
275
+ """
276
+ start_time = time.time()
277
+ print(f"Running COLMAP pipeline on all images inside {scene_dir}")
278
+
279
+ # Setup paths
280
+ database_path = os.path.join(scene_dir, "database.db")
281
+ sparse_path = os.path.join(scene_dir, "sparse")
282
+ image_dir = os.path.join(scene_dir, "images")
283
+
284
+ # Make sure output directories exist
285
+ os.makedirs(sparse_path, exist_ok=True)
286
+
287
+ # Step 1: Feature Extraction
288
+ pycolmap.extract_features(
289
+ database_path,
290
+ image_dir,
291
+ sift_options={
292
+ "max_num_features": 512 * 2,
293
+ "max_image_size": 512 * 1,
294
+ }
295
+ )
296
+ print(f"Finished feature extraction in {(time.time() - start_time):.2f}s.")
297
+
298
+ # Step 2: Feature Matching
299
+ pycolmap.match_exhaustive(database_path)
300
+ print(f"Finished feature matching in {(time.time() - start_time):.2f}s.")
301
+
302
+ # Step 3: Mapping
303
+ pipeline_options = pycolmap.IncrementalPipelineOptions()
304
+ pipeline_options.min_num_matches = 15
305
+ pipeline_options.multiple_models = True
306
+ pipeline_options.max_num_models = 50
307
+ pipeline_options.max_model_overlap = 20
308
+ pipeline_options.min_model_size = 10
309
+ pipeline_options.extract_colors = True
310
+ pipeline_options.num_threads = 8
311
+ pipeline_options.mapper.init_min_num_inliers = 30
312
+ pipeline_options.mapper.init_max_error = 8.0
313
+ pipeline_options.mapper.init_min_tri_angle = 5.0
314
+
315
+ reconstruction = pycolmap.incremental_mapping(
316
+ database_path=database_path,
317
+ image_path=image_dir,
318
+ output_path=sparse_path,
319
+ options=pipeline_options,
320
+ )
321
+ print(f"Finished incremental mapping in {(time.time() - start_time):.2f}s.")
322
+
323
+ # Step 4: Post-process Cameras to SIMPLE_PINHOLE
324
+ recon_path = os.path.join(sparse_path, "0")
325
+ reconstruction = pycolmap.Reconstruction(recon_path)
326
+
327
+ for cam in reconstruction.cameras.values():
328
+ cam.model = 'SIMPLE_PINHOLE'
329
+ cam.params = cam.params[:3] # Keep only [f, cx, cy]
330
+
331
+ reconstruction.write(recon_path)
332
+
333
+ print(f"Total pipeline time: {(time.time() - start_time):.2f}s.")
334
+
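
Taken together, the helpers above form a small video-to-COLMAP pipeline. A hedged end-to-end sketch follows; "my_video.mp4" and "outputs/scene" are placeholder paths, not files shipped with the repository:

from source.utils_preprocess import (read_video_frames, preprocess_frames,
                                     select_optimal_frames, save_frames_to_scene_dir,
                                     run_colmap_on_scene)

frames = read_video_frames("my_video.mp4", k=3, max_size=1024)   # every 3rd frame, longest side <= 1024 px
scores = preprocess_frames(frames)                               # multi-scale Laplacian sharpness per frame
keep = select_optimal_frames(scores, k=16)                       # sharpest frame in each of 16 segments
save_frames_to_scene_dir([frames[i] for i in keep], "outputs/scene")
run_colmap_on_scene("outputs/scene")                             # writes outputs/scene/sparse/0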
source/vggt_to_colmap.py ADDED
@@ -0,0 +1,598 @@
1
+ # Reuse code taken from the implementation of atakan-topaloglu:
2
+ # https://github.com/atakan-topaloglu/vggt/blob/main/vggt_to_colmap.py
3
+
4
+ import os
5
+ import argparse
6
+ import numpy as np
7
+ import torch
8
+ import glob
9
+ import struct
10
+ from scipy.spatial.transform import Rotation
11
+ import sys
12
+ from PIL import Image
13
+ import cv2
14
+ import requests
15
+ import tempfile
16
+
17
+ sys.path.append("submodules/vggt/")
18
+
19
+ from vggt.models.vggt import VGGT
20
+ from vggt.utils.load_fn import load_and_preprocess_images
21
+ from vggt.utils.pose_enc import pose_encoding_to_extri_intri
22
+ from vggt.utils.geometry import unproject_depth_map_to_point_map
23
+
24
+ def load_model(device=None):
25
+ """Load and initialize the VGGT model."""
26
+ if device is None:
27
+ device = "cuda" if torch.cuda.is_available() else "cpu"
28
+ print(f"Using device: {device}")
29
+
30
+ model = VGGT.from_pretrained("facebook/VGGT-1B")
31
+
32
+ # model = VGGT()
33
+ # _URL = "https://huggingface.co/facebook/VGGT-1B/resolve/main/model.pt"
34
+ # model.load_state_dict(torch.hub.load_state_dict_from_url(_URL))
35
+
36
+ model.eval()
37
+ model = model.to(device)
38
+ return model, device
39
+
40
+ def process_images(image_dir, model, device):
41
+ """Process images with VGGT and return predictions."""
42
+ image_names = glob.glob(os.path.join(image_dir, "*"))
43
+ image_names = sorted([f for f in image_names if f.lower().endswith(('.png', '.jpg', '.jpeg'))])
44
+ print(f"Found {len(image_names)} images")
45
+
46
+ if len(image_names) == 0:
47
+ raise ValueError(f"No images found in {image_dir}")
48
+
49
+ original_images = []
50
+ for img_path in image_names:
51
+ img = Image.open(img_path).convert('RGB')
52
+ original_images.append(np.array(img))
53
+
54
+ images = load_and_preprocess_images(image_names).to(device)
55
+ print(f"Preprocessed images shape: {images.shape}")
56
+
57
+ print("Running inference...")
58
+ dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16
59
+
60
+ with torch.no_grad():
61
+ with torch.cuda.amp.autocast(dtype=dtype):
62
+ predictions = model(images)
63
+
64
+ print("Converting pose encoding to camera parameters...")
65
+ extrinsic, intrinsic = pose_encoding_to_extri_intri(predictions["pose_enc"], images.shape[-2:])
66
+ predictions["extrinsic"] = extrinsic
67
+ predictions["intrinsic"] = intrinsic
68
+
69
+ for key in predictions.keys():
70
+ if isinstance(predictions[key], torch.Tensor):
71
+ predictions[key] = predictions[key].cpu().numpy().squeeze(0) # remove batch dimension
72
+
73
+ print("Computing 3D points from depth maps...")
74
+ depth_map = predictions["depth"] # (S, H, W, 1)
75
+ world_points = unproject_depth_map_to_point_map(depth_map, predictions["extrinsic"], predictions["intrinsic"])
76
+ predictions["world_points_from_depth"] = world_points
77
+
78
+ predictions["original_images"] = original_images
79
+
80
+ S, H, W = world_points.shape[:3]
81
+ normalized_images = np.zeros((S, H, W, 3), dtype=np.float32)
82
+
83
+ for i, img in enumerate(original_images):
84
+ resized_img = cv2.resize(img, (W, H))
85
+ normalized_images[i] = resized_img / 255.0
86
+
87
+ predictions["images"] = normalized_images
88
+
89
+ return predictions, image_names
90
+
91
+ def extrinsic_to_colmap_format(extrinsics):
92
+ """Convert extrinsic matrices to COLMAP format (quaternion + translation)."""
93
+ num_cameras = extrinsics.shape[0]
94
+ quaternions = []
95
+ translations = []
96
+
97
+ for i in range(num_cameras):
98
+ # VGGT's extrinsic is camera-to-world (R|t) format
99
+ R = extrinsics[i, :3, :3]
100
+ t = extrinsics[i, :3, 3]
101
+
102
+ # Convert rotation matrix to quaternion
103
+ # COLMAP quaternion format is [qw, qx, qy, qz]
104
+ rot = Rotation.from_matrix(R)
105
+ quat = rot.as_quat() # scipy returns [x, y, z, w]
106
+ quat = np.array([quat[3], quat[0], quat[1], quat[2]]) # Convert to [w, x, y, z]
107
+
108
+ quaternions.append(quat)
109
+ translations.append(t)
110
+
111
+ return np.array(quaternions), np.array(translations)
112
+
113
+ def download_file_from_url(url, filename):
114
+ """Downloads a file from a URL, handling redirects."""
115
+ try:
116
+ response = requests.get(url, allow_redirects=False)
117
+ response.raise_for_status()
118
+
119
+ if response.status_code == 302:
120
+ redirect_url = response.headers["Location"]
121
+ response = requests.get(redirect_url, stream=True)
122
+ response.raise_for_status()
123
+ else:
124
+ response = requests.get(url, stream=True)
125
+ response.raise_for_status()
126
+
127
+ with open(filename, "wb") as f:
128
+ for chunk in response.iter_content(chunk_size=8192):
129
+ f.write(chunk)
130
+ print(f"Downloaded {filename} successfully.")
131
+ return True
132
+
133
+ except requests.exceptions.RequestException as e:
134
+ print(f"Error downloading file: {e}")
135
+ return False
136
+
137
+ def segment_sky(image_path, onnx_session, mask_filename=None):
138
+ """
139
+ Segments sky from an image using an ONNX model.
140
+ """
141
+ image = cv2.imread(image_path)
142
+
143
+ result_map = run_skyseg(onnx_session, [320, 320], image)
144
+ result_map_original = cv2.resize(result_map, (image.shape[1], image.shape[0]))
145
+
146
+ # Fix: Invert the mask so that 255 = non-sky, 0 = sky
147
+ # The model outputs low values for sky, high values for non-sky
148
+ output_mask = np.zeros_like(result_map_original)
149
+ output_mask[result_map_original < 32] = 255 # Use threshold of 32
150
+
151
+ if mask_filename is not None:
152
+ os.makedirs(os.path.dirname(mask_filename), exist_ok=True)
153
+ cv2.imwrite(mask_filename, output_mask)
154
+
155
+ return output_mask
156
+
157
+ def run_skyseg(onnx_session, input_size, image):
158
+ """
159
+ Runs sky segmentation inference using ONNX model.
160
+ """
161
+ import copy
162
+
163
+ temp_image = copy.deepcopy(image)
164
+ resize_image = cv2.resize(temp_image, dsize=(input_size[0], input_size[1]))
165
+ x = cv2.cvtColor(resize_image, cv2.COLOR_BGR2RGB)
166
+ x = np.array(x, dtype=np.float32)
167
+ mean = [0.485, 0.456, 0.406]
168
+ std = [0.229, 0.224, 0.225]
169
+ x = (x / 255 - mean) / std
170
+ x = x.transpose(2, 0, 1)
171
+ x = x.reshape(-1, 3, input_size[0], input_size[1]).astype("float32")
172
+
173
+ input_name = onnx_session.get_inputs()[0].name
174
+ output_name = onnx_session.get_outputs()[0].name
175
+ onnx_result = onnx_session.run([output_name], {input_name: x})
176
+
177
+ onnx_result = np.array(onnx_result).squeeze()
178
+ min_value = np.min(onnx_result)
179
+ max_value = np.max(onnx_result)
180
+ onnx_result = (onnx_result - min_value) / (max_value - min_value)
181
+ onnx_result *= 255
182
+ onnx_result = onnx_result.astype("uint8")
183
+
184
+ return onnx_result
185
+
186
+ def filter_and_prepare_points(predictions, conf_threshold, mask_sky=False, mask_black_bg=False,
187
+ mask_white_bg=False, stride=1, prediction_mode="Depthmap and Camera Branch"):
188
+ """
189
+ Filter points based on confidence and prepare for COLMAP format.
190
+ Implementation matches the conventions in the original VGGT code.
191
+ """
192
+
193
+ if "Pointmap" in prediction_mode:
194
+ print("Using Pointmap Branch")
195
+ if "world_points" in predictions:
196
+ pred_world_points = predictions["world_points"]
197
+ pred_world_points_conf = predictions.get("world_points_conf", np.ones_like(pred_world_points[..., 0]))
198
+ else:
199
+ print("Warning: world_points not found in predictions, falling back to depth-based points")
200
+ pred_world_points = predictions["world_points_from_depth"]
201
+ pred_world_points_conf = predictions.get("depth_conf", np.ones_like(pred_world_points[..., 0]))
202
+ else:
203
+ print("Using Depthmap and Camera Branch")
204
+ pred_world_points = predictions["world_points_from_depth"]
205
+ pred_world_points_conf = predictions.get("depth_conf", np.ones_like(pred_world_points[..., 0]))
206
+
207
+ colors_rgb = predictions["images"]
208
+
209
+ S, H, W = pred_world_points.shape[:3]
210
+ if colors_rgb.shape[:3] != (S, H, W):
211
+ print(f"Reshaping colors_rgb from {colors_rgb.shape} to match {(S, H, W, 3)}")
212
+ reshaped_colors = np.zeros((S, H, W, 3), dtype=np.float32)
213
+ for i in range(S):
214
+ if i < len(colors_rgb):
215
+ reshaped_colors[i] = cv2.resize(colors_rgb[i], (W, H))
216
+ colors_rgb = reshaped_colors
217
+
218
+ colors_rgb = (colors_rgb * 255).astype(np.uint8)
219
+
220
+ if mask_sky:
221
+ print("Applying sky segmentation mask")
222
+ try:
223
+ import onnxruntime
224
+
225
+ with tempfile.TemporaryDirectory() as temp_dir:
226
+ print(f"Created temporary directory for sky segmentation: {temp_dir}")
227
+ temp_images_dir = os.path.join(temp_dir, "images")
228
+ sky_masks_dir = os.path.join(temp_dir, "sky_masks")
229
+ os.makedirs(temp_images_dir, exist_ok=True)
230
+ os.makedirs(sky_masks_dir, exist_ok=True)
231
+
232
+ image_list = []
233
+ for i, img in enumerate(colors_rgb):
234
+ img_path = os.path.join(temp_images_dir, f"image_{i:04d}.png")
235
+ image_list.append(img_path)
236
+ cv2.imwrite(img_path, cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
237
+
238
+
239
+ skyseg_path = os.path.join(temp_dir, "skyseg.onnx")
240
+ if not os.path.exists("skyseg.onnx"):
241
+ print("Downloading skyseg.onnx...")
242
+ download_success = download_file_from_url(
243
+ "https://huggingface.co/JianyuanWang/skyseg/resolve/main/skyseg.onnx",
244
+ skyseg_path
245
+ )
246
+ if not download_success:
247
+ print("Failed to download skyseg model, skipping sky filtering")
248
+ mask_sky = False
249
+ else:
250
+
251
+ import shutil
252
+ shutil.copy("skyseg.onnx", skyseg_path)
253
+
254
+ if mask_sky:
255
+ skyseg_session = onnxruntime.InferenceSession(skyseg_path)
256
+ sky_mask_list = []
257
+
258
+ for img_path in image_list:
259
+ mask_path = os.path.join(sky_masks_dir, os.path.basename(img_path))
260
+ sky_mask = segment_sky(img_path, skyseg_session, mask_path)
261
+
262
+ if sky_mask.shape[0] != H or sky_mask.shape[1] != W:
263
+ sky_mask = cv2.resize(sky_mask, (W, H))
264
+
265
+ sky_mask_list.append(sky_mask)
266
+
267
+ sky_mask_array = np.array(sky_mask_list)
268
+
269
+ sky_mask_binary = (sky_mask_array > 0.1).astype(np.float32)
270
+ pred_world_points_conf = pred_world_points_conf * sky_mask_binary
271
+ print(f"Applied sky mask, shape: {sky_mask_binary.shape}")
272
+
273
+ except Exception as e:
274
+ print(f"Error in sky segmentation: {e}")
275
+ mask_sky = False
276
+
277
+ vertices_3d = pred_world_points.reshape(-1, 3)
278
+ conf = pred_world_points_conf.reshape(-1)
279
+ colors_rgb_flat = colors_rgb.reshape(-1, 3)
280
+
281
+
282
+
283
+ if len(conf) != len(colors_rgb_flat):
284
+ print(f"WARNING: Shape mismatch between confidence ({len(conf)}) and colors ({len(colors_rgb_flat)})")
285
+ min_size = min(len(conf), len(colors_rgb_flat))
286
+ conf = conf[:min_size]
287
+ vertices_3d = vertices_3d[:min_size]
288
+ colors_rgb_flat = colors_rgb_flat[:min_size]
289
+
290
+ if conf_threshold == 0.0:
291
+ conf_thres_value = 0.0
292
+ else:
293
+ conf_thres_value = np.percentile(conf, conf_threshold)
294
+
295
+ print(f"Using confidence threshold: {conf_threshold}% (value: {conf_thres_value:.4f})")
296
+ conf_mask = (conf >= conf_thres_value) & (conf > 1e-5)
297
+
298
+ if mask_black_bg:
299
+ print("Filtering black background")
300
+ black_bg_mask = colors_rgb_flat.sum(axis=1) >= 16
301
+ conf_mask = conf_mask & black_bg_mask
302
+
303
+ if mask_white_bg:
304
+ print("Filtering white background")
305
+ white_bg_mask = ~((colors_rgb_flat[:, 0] > 240) & (colors_rgb_flat[:, 1] > 240) & (colors_rgb_flat[:, 2] > 240))
306
+ conf_mask = conf_mask & white_bg_mask
307
+
308
+ filtered_vertices = vertices_3d[conf_mask]
309
+ filtered_colors = colors_rgb_flat[conf_mask]
310
+
311
+ if len(filtered_vertices) == 0:
312
+ print("Warning: No points remaining after filtering. Using default point.")
313
+ filtered_vertices = np.array([[0, 0, 0]])
314
+ filtered_colors = np.array([[200, 200, 200]])
315
+
316
+ print(f"Filtered to {len(filtered_vertices)} points")
317
+
318
+ points3D = []
319
+ point_indices = {}
320
+ image_points2D = [[] for _ in range(len(pred_world_points))]
321
+
322
+ print(f"Preparing points for COLMAP format with stride {stride}...")
323
+
324
+ total_points = 0
325
+ for img_idx in range(S):
326
+ for y in range(0, H, stride):
327
+ for x in range(0, W, stride):
328
+ flat_idx = img_idx * H * W + y * W + x
329
+
330
+ if flat_idx >= len(conf):
331
+ continue
332
+
333
+ if conf[flat_idx] < conf_thres_value or conf[flat_idx] <= 1e-5:
334
+ continue
335
+
336
+ if mask_black_bg and colors_rgb_flat[flat_idx].sum() < 16:
337
+ continue
338
+
339
+ if mask_white_bg and all(colors_rgb_flat[flat_idx] > 240):
340
+ continue
341
+
342
+ point3D = vertices_3d[flat_idx]
343
+ rgb = colors_rgb_flat[flat_idx]
344
+
345
+ if not np.all(np.isfinite(point3D)):
346
+ continue
347
+
348
+ point_hash = hash_point(point3D, scale=100)
349
+
350
+ if point_hash not in point_indices:
351
+ point_idx = len(points3D)
352
+ point_indices[point_hash] = point_idx
353
+
354
+ point_entry = {
355
+ "id": point_idx,
356
+ "xyz": point3D,
357
+ "rgb": rgb,
358
+ "error": 1.0,
359
+ "track": [(img_idx, len(image_points2D[img_idx]))]
360
+ }
361
+ points3D.append(point_entry)
362
+ total_points += 1
363
+ else:
364
+ point_idx = point_indices[point_hash]
365
+ points3D[point_idx]["track"].append((img_idx, len(image_points2D[img_idx])))
366
+
367
+ image_points2D[img_idx].append((x, y, point_indices[point_hash]))
368
+
369
+ print(f"Prepared {len(points3D)} 3D points with {sum(len(pts) for pts in image_points2D)} observations for COLMAP")
370
+ return points3D, image_points2D
371
+
372
+ def hash_point(point, scale=100):
373
+ """Create a hash for a 3D point by quantizing coordinates."""
374
+ quantized = tuple(np.round(point * scale).astype(int))
375
+ return hash(quantized)
376
+
377
+ def write_colmap_cameras_txt(file_path, intrinsics, image_width, image_height):
378
+ """Write camera intrinsics to COLMAP cameras.txt format."""
379
+ with open(file_path, 'w') as f:
380
+ f.write("# Camera list with one line of data per camera:\n")
381
+ f.write("# CAMERA_ID, MODEL, WIDTH, HEIGHT, PARAMS[]\n")
382
+ f.write(f"# Number of cameras: {len(intrinsics)}\n")
383
+
384
+ for i, intrinsic in enumerate(intrinsics):
385
+ camera_id = i + 1 # COLMAP uses 1-indexed camera IDs
386
+ model = "PINHOLE"
387
+
388
+ fx = intrinsic[0, 0]
389
+ fy = intrinsic[1, 1]
390
+ cx = intrinsic[0, 2]
391
+ cy = intrinsic[1, 2]
392
+
393
+ f.write(f"{camera_id} {model} {image_width} {image_height} {fx} {fy} {cx} {cy}\n")
394
+
395
+ def write_colmap_images_txt(file_path, quaternions, translations, image_points2D, image_names):
396
+ """Write camera poses and keypoints to COLMAP images.txt format."""
397
+ with open(file_path, 'w') as f:
398
+ f.write("# Image list with two lines of data per image:\n")
399
+ f.write("# IMAGE_ID, QW, QX, QY, QZ, TX, TY, TZ, CAMERA_ID, NAME\n")
400
+ f.write("# POINTS2D[] as (X, Y, POINT3D_ID)\n")
401
+
402
+ num_points = sum(len(points) for points in image_points2D)
403
+ avg_points = num_points / len(image_points2D) if image_points2D else 0
404
+ f.write(f"# Number of images: {len(quaternions)}, mean observations per image: {avg_points:.1f}\n")
405
+
406
+ for i in range(len(quaternions)):
407
+ image_id = i + 1
408
+ camera_id = i + 1
409
+
410
+ qw, qx, qy, qz = quaternions[i]
411
+ tx, ty, tz = translations[i]
412
+
413
+ f.write(f"{image_id} {qw} {qx} {qy} {qz} {tx} {ty} {tz} {camera_id} {os.path.basename(image_names[i])}\n")
414
+
415
+ points_line = " ".join([f"{x} {y} {point3d_id+1}" for x, y, point3d_id in image_points2D[i]])
416
+ f.write(f"{points_line}\n")
417
+
418
+ def write_colmap_points3D_txt(file_path, points3D):
419
+ """Write 3D points and tracks to COLMAP points3D.txt format."""
420
+ with open(file_path, 'w') as f:
421
+ f.write("# 3D point list with one line of data per point:\n")
422
+ f.write("# POINT3D_ID, X, Y, Z, R, G, B, ERROR, TRACK[] as (IMAGE_ID, POINT2D_IDX)\n")
423
+
424
+ avg_track_length = sum(len(point["track"]) for point in points3D) / len(points3D) if points3D else 0
425
+ f.write(f"# Number of points: {len(points3D)}, mean track length: {avg_track_length:.4f}\n")
426
+
427
+ for point in points3D:
428
+ point_id = point["id"] + 1
429
+ x, y, z = point["xyz"]
430
+ r, g, b = point["rgb"]
431
+ error = point["error"]
432
+
433
+ track = " ".join([f"{img_id+1} {point2d_idx}" for img_id, point2d_idx in point["track"]])
434
+
435
+ f.write(f"{point_id} {x} {y} {z} {int(r)} {int(g)} {int(b)} {error} {track}\n")
436
+
437
+ def write_colmap_cameras_bin(file_path, intrinsics, image_width, image_height):
438
+ """Write camera intrinsics to COLMAP cameras.bin format."""
439
+ with open(file_path, 'wb') as fid:
440
+ # Write number of cameras (uint64)
441
+ fid.write(struct.pack('<Q', len(intrinsics)))
442
+
443
+ for i, intrinsic in enumerate(intrinsics):
444
+ camera_id = i + 1
445
+ model_id = 1
446
+
447
+ fx = float(intrinsic[0, 0])
448
+ fy = float(intrinsic[1, 1])
449
+ cx = float(intrinsic[0, 2])
450
+ cy = float(intrinsic[1, 2])
451
+
452
+ # Camera ID (uint32)
453
+ fid.write(struct.pack('<I', camera_id))
454
+ # Model ID (uint32)
455
+ fid.write(struct.pack('<I', model_id))
456
+ # Width (uint64)
457
+ fid.write(struct.pack('<Q', image_width))
458
+ # Height (uint64)
459
+ fid.write(struct.pack('<Q', image_height))
460
+
461
+ # Parameters (double)
462
+ fid.write(struct.pack('<dddd', fx, fy, cx, cy))
463
+
464
+ def write_colmap_images_bin(file_path, quaternions, translations, image_points2D, image_names):
465
+ """Write camera poses and keypoints to COLMAP images.bin format."""
466
+ with open(file_path, 'wb') as fid:
467
+ # Write number of images (uint64)
468
+ fid.write(struct.pack('<Q', len(quaternions)))
469
+
470
+ for i in range(len(quaternions)):
471
+ image_id = i + 1
472
+ camera_id = i + 1
473
+
474
+ qw, qx, qy, qz = quaternions[i].astype(float)
475
+ tx, ty, tz = translations[i].astype(float)
476
+
477
+ image_name = os.path.basename(image_names[i]).encode()
478
+ points = image_points2D[i]
479
+
480
+ # Image ID (uint32)
481
+ fid.write(struct.pack('<I', image_id))
482
+ # Quaternion (double): qw, qx, qy, qz
483
+ fid.write(struct.pack('<dddd', qw, qx, qy, qz))
484
+ # Translation (double): tx, ty, tz
485
+ fid.write(struct.pack('<ddd', tx, ty, tz))
486
+ # Camera ID (uint32)
487
+ fid.write(struct.pack('<I', camera_id))
488
+ # Image name
489
+ fid.write(struct.pack('<I', len(image_name)))
490
+ fid.write(image_name)
491
+
492
+ # Write number of 2D points (uint64)
493
+ fid.write(struct.pack('<Q', len(points)))
494
+
495
+ # Write 2D points: x, y, point3D_id
496
+ for x, y, point3d_id in points:
497
+ fid.write(struct.pack('<dd', float(x), float(y)))
498
+ fid.write(struct.pack('<Q', point3d_id + 1))
499
+
500
+ def write_colmap_points3D_bin(file_path, points3D):
501
+ """Write 3D points and tracks to COLMAP points3D.bin format."""
502
+ with open(file_path, 'wb') as fid:
503
+ # Write number of points (uint64)
504
+ fid.write(struct.pack('<Q', len(points3D)))
505
+
506
+ for point in points3D:
507
+ point_id = point["id"] + 1
508
+ x, y, z = point["xyz"].astype(float)
509
+ r, g, b = point["rgb"].astype(np.uint8)
510
+ error = float(point["error"])
511
+ track = point["track"]
512
+
513
+ # Point ID (uint64)
514
+ fid.write(struct.pack('<Q', point_id))
515
+ # Position (double): x, y, z
516
+ fid.write(struct.pack('<ddd', x, y, z))
517
+ # Color (uint8): r, g, b
518
+ fid.write(struct.pack('<BBB', int(r), int(g), int(b)))
519
+ # Error (double)
520
+ fid.write(struct.pack('<d', error))
521
+
522
+ # Track: list of (image_id, point2D_idx)
523
+ fid.write(struct.pack('<Q', len(track)))
524
+ for img_id, point2d_idx in track:
525
+ fid.write(struct.pack('<II', img_id + 1, point2d_idx))
526
+
527
+ def main():
528
+ parser = argparse.ArgumentParser(description="Convert images to COLMAP format using VGGT")
529
+ parser.add_argument("--image_dir", type=str, required=True,
530
+ help="Directory containing input images")
531
+ parser.add_argument("--output_dir", type=str, default="colmap_output",
532
+ help="Directory to save COLMAP files")
533
+ parser.add_argument("--conf_threshold", type=float, default=50.0,
534
+ help="Confidence threshold (0-100%) for including points")
535
+ parser.add_argument("--mask_sky", action="store_true",
536
+ help="Filter out points likely to be sky")
537
+ parser.add_argument("--mask_black_bg", action="store_true",
538
+ help="Filter out points with very dark/black color")
539
+ parser.add_argument("--mask_white_bg", action="store_true",
540
+ help="Filter out points with very bright/white color")
541
+ parser.add_argument("--binary", action="store_true",
542
+ help="Output binary COLMAP files instead of text")
543
+ parser.add_argument("--stride", type=int, default=1,
544
+ help="Stride for point sampling (higher = fewer points)")
545
+ parser.add_argument("--prediction_mode", type=str, default="Depthmap and Camera Branch",
546
+ choices=["Depthmap and Camera Branch", "Pointmap Branch"],
547
+ help="Which prediction branch to use")
548
+
549
+ args = parser.parse_args()
550
+
551
+ os.makedirs(args.output_dir, exist_ok=True)
552
+
553
+ model, device = load_model()
554
+
555
+ predictions, image_names = process_images(args.image_dir, model, device)
556
+
557
+ print("Converting camera parameters to COLMAP format...")
558
+ quaternions, translations = extrinsic_to_colmap_format(predictions["extrinsic"])
559
+
560
+ print(f"Filtering points with confidence threshold {args.conf_threshold}% and stride {args.stride}...")
561
+ points3D, image_points2D = filter_and_prepare_points(
562
+ predictions,
563
+ args.conf_threshold,
564
+ mask_sky=args.mask_sky,
565
+ mask_black_bg=args.mask_black_bg,
566
+ mask_white_bg=args.mask_white_bg,
567
+ stride=args.stride,
568
+ prediction_mode=args.prediction_mode
569
+ )
570
+
571
+ height, width = predictions["depth"].shape[1:3]
572
+
573
+ print(f"Writing {'binary' if args.binary else 'text'} COLMAP files to {args.output_dir}...")
574
+ if args.binary:
575
+ write_colmap_cameras_bin(
576
+ os.path.join(args.output_dir, "cameras.bin"),
577
+ predictions["intrinsic"], width, height)
578
+ write_colmap_images_bin(
579
+ os.path.join(args.output_dir, "images.bin"),
580
+ quaternions, translations, image_points2D, image_names)
581
+ write_colmap_points3D_bin(
582
+ os.path.join(args.output_dir, "points3D.bin"),
583
+ points3D)
584
+ else:
585
+ write_colmap_cameras_txt(
586
+ os.path.join(args.output_dir, "cameras.txt"),
587
+ predictions["intrinsic"], width, height)
588
+ write_colmap_images_txt(
589
+ os.path.join(args.output_dir, "images.txt"),
590
+ quaternions, translations, image_points2D, image_names)
591
+ write_colmap_points3D_txt(
592
+ os.path.join(args.output_dir, "points3D.txt"),
593
+ points3D)
594
+
595
+ print(f"COLMAP files successfully written to {args.output_dir}")
596
+
597
+ if __name__ == "__main__":
598
+ main()
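
Besides the CLI entry point above, the same conversion can be driven programmatically. A hedged sketch mirroring main(); "path/to/images" and "colmap_output" are placeholder paths:

import os
from source.vggt_to_colmap import (load_model, process_images, extrinsic_to_colmap_format,
                                   filter_and_prepare_points, write_colmap_cameras_txt,
                                   write_colmap_images_txt, write_colmap_points3D_txt)

model, device = load_model()
predictions, image_names = process_images("path/to/images", model, device)
quaternions, translations = extrinsic_to_colmap_format(predictions["extrinsic"])
points3D, image_points2D = filter_and_prepare_points(predictions, conf_threshold=50.0, stride=2)

height, width = predictions["depth"].shape[1:3]
os.makedirs("colmap_output", exist_ok=True)
write_colmap_cameras_txt(os.path.join("colmap_output", "cameras.txt"),
                         predictions["intrinsic"], width, height)
write_colmap_images_txt(os.path.join("colmap_output", "images.txt"),
                        quaternions, translations, image_points2D, image_names)
write_colmap_points3D_txt(os.path.join("colmap_output", "points3D.txt"),
                          points3D)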
source/visualization.py ADDED
@@ -0,0 +1,1072 @@
1
+ from matplotlib import pyplot as plt
2
+ import numpy as np
3
+ import torch
4
+
5
+ import numpy as np
6
+ from typing import List
7
+ import sys
8
+ sys.path.append('./submodules/gaussian-splatting/')
9
+ from scene.cameras import Camera
10
+ from PIL import Image
11
+ import imageio
12
+ from scipy.interpolate import splprep, splev
13
+
14
+ import cv2
15
+ import numpy as np
16
+ import plotly.graph_objects as go
17
+ import numpy as np
18
+ from scipy.spatial.transform import Rotation as R, Slerp
19
+ from scipy.spatial import distance_matrix
20
+ from sklearn.decomposition import PCA
21
+ from scipy.interpolate import splprep, splev
22
+ from typing import List
23
+ from sklearn.mixture import GaussianMixture
24
+
25
+ def render_gaussians_rgb(generator3DGS, viewpoint_cam, visualize=False):
26
+ """
27
+ Render the gaussians of generator3DGS from the camera viewpoint_cam.
28
+ Args:
29
+ generator3DGS : instance of the Generator3DGS class from the networks.py file
30
+ viewpoint_cam : camera instance
31
+ visualize : boolean flag. If True, the rendered image is also displayed with matplotlib.
32
+ Returns:
33
+ uint8 numpy array with shape (H, W, 3) representing the image
34
+ """
35
+ with torch.no_grad():
36
+ render_pkg = generator3DGS(viewpoint_cam)
37
+ image = render_pkg["render"]
38
+ image_np = image.clone().detach().cpu().numpy().transpose(1, 2, 0)
39
+
40
+ # Clip values to be in the range [0, 1]
41
+ image_np = np.clip(image_np * 255, 0, 255).astype(np.uint8)
42
+ if visualize:
43
+ plt.figure(figsize=(12, 8))
44
+ plt.imshow(image_np)
45
+ plt.show()
46
+
47
+ return image_np
48
+
49
+ def render_gaussians_D_scores(generator3DGS, viewpoint_cam, mask=None, mask_channel=0, visualize=False):
50
+ """
51
+ Render the D_scores of the gaussians in generator3DGS from the camera viewpoint_cam.
52
+ Args:
53
+ generator3DGS : instance of the Generator3DGS class from the networks.py file
54
+ viewpoint_cam : camera instance
55
+ visualize : boolean flag. If True, the rendered image is also shown with matplotlib.
56
+ mask : optional mask to highlight specific Gaussians. Must be a float torch tensor of shape (N),
57
+ where N is the number of Gaussians in generator3DGS.gaussians. Scale it by how strongly the
58
+ highlight color should show; the recommended value is 10.
59
+ mask_channel : color channel (0, 1, or 2) to which the mask is added
60
+ Returns:
61
+ uint8 numpy array with shape (H, W, 3) representing the generator3DGS.gaussians.D_scores rendered as colors
62
+ """
63
+ with torch.no_grad():
64
+ # Visualize D_scores
65
+ generator3DGS.gaussians._features_dc = generator3DGS.gaussians._features_dc * 1e-4 + \
66
+ torch.stack([generator3DGS.gaussians.D_scores] * 3, axis=-1)
67
+ generator3DGS.gaussians._features_rest = generator3DGS.gaussians._features_rest * 1e-4
68
+ if mask is not None:
69
+ generator3DGS.gaussians._features_dc[..., mask_channel] += mask.unsqueeze(-1)
70
+ render_pkg = generator3DGS(viewpoint_cam)
71
+ image = render_pkg["render"]
72
+ image_np = image.clone().detach().cpu().numpy().transpose(1, 2, 0)
73
+
74
+ # Scale to [0, 255], clip, and convert to uint8
75
+ image_np = np.clip(image_np * 255, 0, 255).astype(np.uint8)
76
+ if visualize:
77
+ plt.figure(figsize=(12, 8))
78
+ plt.imshow(image_np)
79
+ plt.show()
80
+
81
+ if mask is not None:
82
+ generator3DGS.gaussians._features_dc[..., mask_channel] -= mask.unsqueeze(-1)
83
+
84
+ generator3DGS.gaussians._features_dc = (generator3DGS.gaussians._features_dc - \
85
+ torch.stack([generator3DGS.gaussians.D_scores] * 3, axis=-1)) * 1e4
86
+ generator3DGS.gaussians._features_rest = generator3DGS.gaussians._features_rest * 1e4
87
+
88
+ return image_np
89
+
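A sketch of how the `mask` argument above could be used, for example to highlight the highest-scoring Gaussians. It assumes it runs in the same module (so `render_gaussians_D_scores` is in scope) and that `D_scores` is an (N,)- or (N, 1)-shaped tensor; the `k` threshold and channel choice are illustrative only.

```python
# Illustrative sketch: tint the top-k Gaussians by D_score in the green channel.
import torch

def render_topk_D_scores(generator3DGS, viewpoint_cam, k=10000, mask_value=10.0):
    """Render D_scores with the k highest-scoring Gaussians highlighted."""
    scores = generator3DGS.gaussians.D_scores.squeeze()
    mask = torch.zeros_like(scores)
    topk = torch.topk(scores, min(k, scores.numel())).indices
    mask[topk] = mask_value  # the recommended magnitude from the docstring above
    return render_gaussians_D_scores(generator3DGS, viewpoint_cam,
                                     mask=mask, mask_channel=1, visualize=True)
```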
90
+
91
+
92
+ def normalize(v):
93
+ """
94
+ Normalize a vector to unit length.
95
+
96
+ Parameters:
97
+ v (np.ndarray): Input vector.
98
+
99
+ Returns:
100
+ np.ndarray: Unit vector in the same direction as `v`.
101
+ """
102
+ return v / np.linalg.norm(v)
103
+
104
+ def look_at_rotation(camera_position: np.ndarray, target: np.ndarray, world_up=np.array([0, 1, 0])):
105
+ """
106
+ Compute a rotation matrix for a camera looking at a target point.
107
+
108
+ Parameters:
109
+ camera_position (np.ndarray): The 3D position of the camera.
110
+ target (np.ndarray): The point the camera should look at.
111
+ world_up (np.ndarray): A vector that defines the global 'up' direction.
112
+
113
+ Returns:
114
+ np.ndarray: A 3x3 rotation matrix (camera-to-world) with columns [right, up, forward].
115
+ """
116
+ z_axis = normalize(target - camera_position) # Forward direction
117
+ x_axis = normalize(np.cross(world_up, z_axis)) # Right direction
118
+ y_axis = np.cross(z_axis, x_axis) # Recomputed up
119
+ return np.stack([x_axis, y_axis, z_axis], axis=1)
120
+
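A quick self-contained check of the convention used by `look_at_rotation` (camera-to-world, forward in the third column), assuming it runs in the same module so `look_at_rotation` and `normalize` are in scope.

```python
# Sanity check: for a camera at (0, 0, -5) looking at the origin, the third column
# of the returned matrix is the unit vector toward the target, and R is orthonormal.
import numpy as np

cam_pos = np.array([0.0, 0.0, -5.0])
target = np.array([0.0, 0.0, 0.0])
R_c2w = look_at_rotation(cam_pos, target)
assert np.allclose(R_c2w[:, 2], normalize(target - cam_pos))
assert np.allclose(R_c2w.T @ R_c2w, np.eye(3), atol=1e-8)
```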
121
+
122
+ def generate_circular_camera_path(existing_cameras: List[Camera], N: int = 12, radius_scale: float = 1.0, d: float = 2.0) -> List[Camera]:
123
+ """
124
+ Generate a circular path of cameras around an existing camera group,
125
+ with each new camera oriented to look at the average viewing direction.
126
+
127
+ Parameters:
128
+ existing_cameras (List[Camera]): List of existing camera objects to estimate average orientation and layout.
129
+ N (int): Number of new cameras to generate along the circular path.
130
+ radius_scale (float): Scale factor to adjust the radius of the circle.
131
+ d (float): Distance ahead of each camera used to estimate its look-at point.
132
+
133
+ Returns:
134
+ List[Camera]: A list of newly generated Camera objects forming a circular path and oriented toward a shared view center.
135
+ """
136
+ # Step 1: Compute average camera position
137
+ center = np.mean([cam.T for cam in existing_cameras], axis=0)
138
+
139
+ # Estimate where each camera is looking
140
+ # d denotes how far ahead each camera sees — you can scale this
141
+ look_targets = [cam.T + cam.R[:, 2] * d for cam in existing_cameras]
142
+ center_of_view = np.mean(look_targets, axis=0)
143
+
144
+ # Step 2: Define circular plane basis using fixed up vector
145
+ avg_forward = normalize(np.mean([cam.R[:, 2] for cam in existing_cameras], axis=0))
146
+ up_guess = np.array([0, 1, 0])
147
+ right = normalize(np.cross(avg_forward, up_guess))
148
+ up = normalize(np.cross(right, avg_forward))
149
+
150
+ # Step 3: Estimate radius
151
+ avg_radius = np.mean([np.linalg.norm(cam.T - center) for cam in existing_cameras]) * radius_scale
152
+
153
+ # Step 4: Create cameras on a circular path
154
+ angles = np.linspace(0, 2 * np.pi, N, endpoint=False)
155
+ reference_cam = existing_cameras[0]
156
+ new_cameras = []
157
+
158
+
159
+ for i, a in enumerate(angles):
160
+ position = center + avg_radius * (np.cos(a) * right + np.sin(a) * up)
161
+
162
+ if d < 1e-5 or radius_scale < 1e-5:
163
+ # Use same orientation as the first camera
164
+ R = reference_cam.R.copy()
165
+ else:
166
+ # Change orientation
167
+ R = look_at_rotation(position, center_of_view)
168
+ new_cameras.append(Camera(
169
+ R=R,
170
+ T=position, # New position
171
+ FoVx=reference_cam.FoVx,
172
+ FoVy=reference_cam.FoVy,
173
+ resolution=(reference_cam.image_width, reference_cam.image_height),
174
+ colmap_id=-1,
175
+ depth_params=None,
176
+ image=Image.fromarray(np.zeros((reference_cam.image_height, reference_cam.image_width, 3), dtype=np.uint8)),
177
+ invdepthmap=None,
178
+ image_name=f"circular_a={a:.3f}",
179
+ uid=i
180
+ ))
181
+
182
+ return new_cameras
183
+
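A hedged end-to-end sketch combining the path generator above with the renderer and the GIF writer defined below; `generator3DGS` and `train_cameras` are assumed to come from the EDGS pipeline, and the hyperparameters are only examples.

```python
# Sketch (assumed inputs): orbit the training cameras, render every pose, save a GIF.
def render_circular_gif(generator3DGS, train_cameras, out_path="circular.gif"):
    path_cams = generate_circular_camera_path(train_cameras, N=36, radius_scale=1.0, d=2.0)
    frames = [render_gaussians_rgb(generator3DGS, cam) for cam in path_cams]
    save_numpy_frames_as_gif(frames, output_path=out_path, duration=100)
    return frames
```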
184
+
185
+ def save_numpy_frames_as_gif(frames, output_path="animation.gif", duration=100):
186
+ """
187
+ Save a list of RGB NumPy frames as a looping GIF animation.
188
+
189
+ Parameters:
190
+ frames (List[np.ndarray]): List of RGB images as uint8 NumPy arrays (shape HxWx3).
191
+ output_path (str): Path to save the output GIF.
192
+ duration (int): Duration per frame in milliseconds.
193
+
194
+ Returns:
195
+ None
196
+ """
197
+ pil_frames = [Image.fromarray(f) for f in frames]
198
+ pil_frames[0].save(
199
+ output_path,
200
+ save_all=True,
201
+ append_images=pil_frames[1:],
202
+ duration=duration, # duration per frame in ms
203
+ loop=0
204
+ )
205
+ print(f"GIF saved to: {output_path}")
206
+
207
+ def center_crop_frame(frame: np.ndarray, crop_fraction: float) -> np.ndarray:
208
+ """
209
+ Crop the central region of the frame by the given fraction.
210
+
211
+ Parameters:
212
+ frame (np.ndarray): Input RGB image (H, W, 3).
213
+ crop_fraction (float): Fraction of the original size to retain (e.g., 0.8 keeps 80%).
214
+
215
+ Returns:
216
+ np.ndarray: Cropped RGB image.
217
+ """
218
+ if crop_fraction >= 1.0:
219
+ return frame
220
+
221
+ h, w, _ = frame.shape
222
+ new_h, new_w = int(h * crop_fraction), int(w * crop_fraction)
223
+ start_y = (h - new_h) // 2
224
+ start_x = (w - new_w) // 2
225
+ return frame[start_y:start_y + new_h, start_x:start_x + new_w, :]
226
+
227
+
228
+
229
+ def generate_smooth_closed_camera_path(existing_cameras: List[Camera], N: int = 120, d: float = 2.0, s=.25) -> List[Camera]:
230
+ """
231
+ Generate a smooth, closed path interpolating the positions of existing cameras.
232
+
233
+ Parameters:
234
+ existing_cameras (List[Camera]): List of existing cameras.
235
+ N (int): Number of points (cameras) to sample along the smooth path.
236
+ d (float): Distance ahead for estimating the center of view.
+ s (float): Smoothing factor passed to scipy's splprep.
237
+
238
+ Returns:
239
+ List[Camera]: A list of smoothly moving Camera objects along a closed loop.
240
+ """
241
+ # Step 1: Extract camera positions
242
+ positions = np.array([cam.T for cam in existing_cameras])
243
+
244
+ # Step 2: Estimate center of view
245
+ look_targets = [cam.T + cam.R[:, 2] * d for cam in existing_cameras]
246
+ center_of_view = np.mean(look_targets, axis=0)
247
+
248
+ # Step 3: Fit a smooth closed spline through the positions
249
+ positions = np.vstack([positions, positions[0]]) # close the loop
250
+ tck, u = splprep(positions.T, s=s, per=True) # per=True makes the spline periodic (closed loop)
251
+
252
+ # Step 4: Sample points along the spline
253
+ u_fine = np.linspace(0, 1, N)
254
+ smooth_path = np.stack(splev(u_fine, tck), axis=-1)
255
+
256
+ # Step 5: Generate cameras along the smooth path
257
+ reference_cam = existing_cameras[0]
258
+ new_cameras = []
259
+
260
+ for i, pos in enumerate(smooth_path):
261
+ R = look_at_rotation(pos, center_of_view)
262
+ new_cameras.append(Camera(
263
+ R=R,
264
+ T=pos,
265
+ FoVx=reference_cam.FoVx,
266
+ FoVy=reference_cam.FoVy,
267
+ resolution=(reference_cam.image_width, reference_cam.image_height),
268
+ colmap_id=-1,
269
+ depth_params=None,
270
+ image=Image.fromarray(np.zeros((reference_cam.image_height, reference_cam.image_width, 3), dtype=np.uint8)),
271
+ invdepthmap=None,
272
+ image_name=f"smooth_path_i={i}",
273
+ uid=i
274
+ ))
275
+
276
+ return new_cameras
277
+
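A sketch of the intended use, again with assumed inputs from the training pipeline: sample a smooth closed loop through the training cameras and export the rendered frames as a video via `save_numpy_frames_as_mp4`, defined just below.

```python
# Sketch (assumed inputs): fly a smooth closed loop and export it as an MP4.
def render_smooth_loop_mp4(generator3DGS, train_cameras, out_path="loop.mp4"):
    path_cams = generate_smooth_closed_camera_path(train_cameras, N=120, d=2.0, s=0.25)
    frames = [render_gaussians_rgb(generator3DGS, cam) for cam in path_cams]
    save_numpy_frames_as_mp4(frames, output_path=out_path, fps=30, center_crop=0.9)
```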
278
+
279
+ def save_numpy_frames_as_mp4(frames, output_path="animation.mp4", fps=10, center_crop: float = 1.0):
280
+ """
281
+ Save a list of RGB NumPy frames as an MP4 video with optional center cropping.
282
+
283
+ Parameters:
284
+ frames (List[np.ndarray]): List of RGB images as uint8 NumPy arrays (shape HxWx3).
285
+ output_path (str): Path to save the output MP4.
286
+ fps (int): Frames per second for playback speed.
287
+ center_crop (float): Fraction (0 < center_crop <= 1.0) of central region to retain.
288
+ Use 1.0 for no cropping; 0.8 to crop to 80% center region.
289
+
290
+ Returns:
291
+ None
292
+ """
293
+ with imageio.get_writer(output_path, fps=fps, codec='libx264', quality=8) as writer:
294
+ for frame in frames:
295
+ cropped = center_crop_frame(frame, center_crop)
296
+ writer.append_data(cropped)
297
+ print(f"MP4 saved to: {output_path}")
298
+
299
+
300
+
301
+ def put_text_on_image(img: np.ndarray, text: str) -> np.ndarray:
302
+ """
303
+ Draws multiline white text on a copy of the input image, positioned near the bottom
304
+ and around 80% of the image width. Handles '\n' characters to split text into multiple lines.
305
+
306
+ Args:
307
+ img (np.ndarray): Input image as a (H, W, 3) uint8 numpy array.
308
+ text (str): Text string to draw on the image. Newlines '\n' are treated as line breaks.
309
+
310
+ Returns:
311
+ np.ndarray: The output image with the text drawn on it.
312
+
313
+ Notes:
314
+ - The function automatically adjusts line spacing and prevents text from going outside the image.
315
+ - Text is drawn in white (font scale 1.0, thickness 2) for minimal visual impact.
316
+ """
317
+ img = img.copy()
318
+ height, width, _ = img.shape
319
+
320
+ font = cv2.FONT_HERSHEY_SIMPLEX
321
+ font_scale = 1.
322
+ color = (255, 255, 255)
323
+ thickness = 2
324
+ line_spacing = 5 # extra pixels between lines
325
+
326
+ lines = text.split('\n')
327
+
328
+ # Precompute the maximum text width to adjust starting x
329
+ max_text_width = max(cv2.getTextSize(line, font, font_scale, thickness)[0][0] for line in lines)
330
+
331
+ x = int(0.8 * width)
332
+ x = min(x, width - max_text_width - 30) # margin on right
333
+ #x = int(0.03 * width)
334
+
335
+ # Start near the bottom, but move up depending on number of lines
336
+ total_text_height = len(lines) * (cv2.getTextSize('A', font, font_scale, thickness)[0][1] + line_spacing)
337
+ y_start = int(height * 0.9) - total_text_height # start at 90% of the height, shifted up by the text block
338
+
339
+ for i, line in enumerate(lines):
340
+ y = y_start + i * (cv2.getTextSize(line, font, font_scale, thickness)[0][1] + line_spacing)
341
+ cv2.putText(img, line, (x, y), font, font_scale, color, thickness, cv2.LINE_AA)
342
+
343
+ return img
344
+
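For example, frames can be stamped with a (possibly multiline) caption before being written to a GIF or MP4; the caption text below is only illustrative.

```python
# Illustrative helper: annotate every frame with the same caption.
def annotate_frames(frames, text="EDGS\nsmooth camera path"):
    return [put_text_on_image(frame, text) for frame in frames]
```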
345
+
346
+
347
+
348
+ def catmull_rom_spline(P0, P1, P2, P3, n_points=20):
349
+ """
350
+ Compute Catmull-Rom spline segment between P1 and P2.
351
+ """
352
+ t = np.linspace(0, 1, n_points)[:, None]
353
+
354
+ M = 0.5 * np.array([
355
+ [-1, 3, -3, 1],
356
+ [ 2, -5, 4, -1],
357
+ [-1, 0, 1, 0],
358
+ [ 0, 2, 0, 0]
359
+ ])
360
+
361
+ G = np.stack([P0, P1, P2, P3], axis=0)
362
+ T = np.concatenate([t**3, t**2, t, np.ones_like(t)], axis=1)
363
+
364
+ return T @ M @ G
365
+
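A small standalone check of the segment above, assuming it runs in the same module: a Catmull-Rom segment interpolates its two middle control points, so the first and last samples equal P1 and P2.

```python
import numpy as np

P0, P1, P2, P3 = (np.array(p, dtype=float) for p in ([0, 0, 0], [1, 0, 0], [1, 1, 0], [2, 1, 0]))
seg = catmull_rom_spline(P0, P1, P2, P3, n_points=20)
assert seg.shape == (20, 3)
assert np.allclose(seg[0], P1) and np.allclose(seg[-1], P2)  # endpoints hit P1 and P2
```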
366
+ def sort_cameras_pca(existing_cameras: List[Camera]):
367
+ """
368
+ Sort cameras along the main PCA axis.
369
+ """
370
+ positions = np.array([cam.T for cam in existing_cameras])
371
+ pca = PCA(n_components=1)
372
+ scores = pca.fit_transform(positions)
373
+ sorted_indices = np.argsort(scores[:, 0])
374
+ return sorted_indices
375
+
376
+ def generate_fully_smooth_cameras(existing_cameras: List[Camera],
377
+ n_selected: int = 30,
378
+ n_points_per_segment: int = 20,
379
+ d: float = 2.0,
380
+ closed: bool = False) -> List[Camera]:
381
+ """
382
+ Generate a fully smooth camera path using PCA ordering, global Catmull-Rom spline for positions, and global SLERP for orientations.
383
+
384
+ Args:
385
+ existing_cameras (List[Camera]): List of input cameras.
386
+ n_selected (int): Number of cameras to select after sorting.
387
+ n_points_per_segment (int): Number of interpolated points per spline segment.
388
+ d (float): Distance ahead for estimating center of view (currently unused).
389
+ closed (bool): Whether to close the path.
390
+
391
+ Returns:
392
+ List[Camera]: List of smoothly moving Camera objects.
393
+ """
394
+ # 1. Sort cameras along PCA axis
395
+ sorted_indices = sort_cameras_pca(existing_cameras)
396
+ sorted_cameras = [existing_cameras[i] for i in sorted_indices]
397
+ positions = np.array([cam.T for cam in sorted_cameras])
398
+
399
+ # 2. Subsample uniformly
400
+ idx = np.linspace(0, len(positions) - 1, n_selected).astype(int)
401
+ sampled_positions = positions[idx]
402
+ sampled_cameras = [sorted_cameras[i] for i in idx]
403
+
404
+ # 3. Prepare for Catmull-Rom
405
+ if closed:
406
+ sampled_positions = np.vstack([sampled_positions[-1], sampled_positions, sampled_positions[0], sampled_positions[1]])
407
+ else:
408
+ sampled_positions = np.vstack([sampled_positions[0], sampled_positions, sampled_positions[-1], sampled_positions[-1]])
409
+
410
+ # 4. Generate smooth path positions
411
+ path_positions = []
412
+ for i in range(1, len(sampled_positions) - 2):
413
+ segment = catmull_rom_spline(sampled_positions[i-1], sampled_positions[i], sampled_positions[i+1], sampled_positions[i+2], n_points_per_segment)
414
+ path_positions.append(segment)
415
+ path_positions = np.concatenate(path_positions, axis=0)
416
+
417
+ # 5. Global SLERP for rotations
418
+ rotations = R.from_matrix([cam.R for cam in sampled_cameras])
419
+ key_times = np.linspace(0, 1, len(rotations))
420
+ slerp = Slerp(key_times, rotations)
421
+
422
+ query_times = np.linspace(0, 1, len(path_positions))
423
+ interpolated_rotations = slerp(query_times)
424
+
425
+ # 6. Generate Camera objects
426
+ reference_cam = existing_cameras[0]
427
+ smooth_cameras = []
428
+
429
+ for i, pos in enumerate(path_positions):
430
+ R_interp = interpolated_rotations[i].as_matrix()
431
+
432
+ smooth_cameras.append(Camera(
433
+ R=R_interp,
434
+ T=pos,
435
+ FoVx=reference_cam.FoVx,
436
+ FoVy=reference_cam.FoVy,
437
+ resolution=(reference_cam.image_width, reference_cam.image_height),
438
+ colmap_id=-1,
439
+ depth_params=None,
440
+ image=Image.fromarray(np.zeros((reference_cam.image_height, reference_cam.image_width, 3), dtype=np.uint8)),
441
+ invdepthmap=None,
442
+ image_name=f"fully_smooth_path_i={i}",
443
+ uid=i
444
+ ))
445
+
446
+ return smooth_cameras
447
+
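A usage sketch with assumed inputs: build the PCA-ordered smooth path and inspect it with the Plotly helper defined right below before rendering a video along it.

```python
# Sketch (assumed inputs): preview the smooth path against the original cameras.
def preview_smooth_path(train_cameras):
    smooth_cams = generate_fully_smooth_cameras(train_cameras, n_selected=30,
                                                n_points_per_segment=20, closed=False)
    plot_cameras_and_smooth_path_with_orientation(train_cameras, smooth_cams, scale=0.1)
    return smooth_cams
```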
448
+
449
+ def plot_cameras_and_smooth_path_with_orientation(existing_cameras: List[Camera], smooth_cameras: List[Camera], scale: float = 0.1):
450
+ """
451
+ Plot input cameras and smooth path cameras with their orientations in 3D.
452
+
453
+ Args:
454
+ existing_cameras (List[Camera]): List of original input cameras.
455
+ smooth_cameras (List[Camera]): List of smooth path cameras.
456
+ scale (float): Length of orientation arrows.
457
+
458
+ Returns:
459
+ None
460
+ """
461
+ # Input cameras
462
+ input_positions = np.array([cam.T for cam in existing_cameras])
463
+
464
+ # Smooth cameras
465
+ smooth_positions = np.array([cam.T for cam in smooth_cameras])
466
+
467
+ fig = go.Figure()
468
+
469
+ # Plot input camera positions
470
+ fig.add_trace(go.Scatter3d(
471
+ x=input_positions[:, 0], y=input_positions[:, 1], z=input_positions[:, 2],
472
+ mode='markers',
473
+ marker=dict(size=4, color='blue'),
474
+ name='Input Cameras'
475
+ ))
476
+
477
+ # Plot smooth path positions
478
+ fig.add_trace(go.Scatter3d(
479
+ x=smooth_positions[:, 0], y=smooth_positions[:, 1], z=smooth_positions[:, 2],
480
+ mode='lines+markers',
481
+ line=dict(color='red', width=3),
482
+ marker=dict(size=2, color='red'),
483
+ name='Smooth Path Cameras'
484
+ ))
485
+
486
+ # Plot input camera orientations
487
+ for cam in existing_cameras:
488
+ origin = cam.T
489
+ forward = cam.R[:, 2] # Forward direction
490
+
491
+ fig.add_trace(go.Cone(
492
+ x=[origin[0]], y=[origin[1]], z=[origin[2]],
493
+ u=[forward[0]], v=[forward[1]], w=[forward[2]],
494
+ colorscale=[[0, 'blue'], [1, 'blue']],
495
+ sizemode="absolute",
496
+ sizeref=scale,
497
+ anchor="tail",
498
+ showscale=False,
499
+ name='Input Camera Direction'
500
+ ))
501
+
502
+ # Plot smooth camera orientations
503
+ for cam in smooth_cameras:
504
+ origin = cam.T
505
+ forward = cam.R[:, 2] # Forward direction
506
+
507
+ fig.add_trace(go.Cone(
508
+ x=[origin[0]], y=[origin[1]], z=[origin[2]],
509
+ u=[forward[0]], v=[forward[1]], w=[forward[2]],
510
+ colorscale=[[0, 'red'], [1, 'red']],
511
+ sizemode="absolute",
512
+ sizeref=scale,
513
+ anchor="tail",
514
+ showscale=False,
515
+ name='Smooth Camera Direction'
516
+ ))
517
+
518
+ fig.update_layout(
519
+ scene=dict(
520
+ xaxis_title='X',
521
+ yaxis_title='Y',
522
+ zaxis_title='Z',
523
+ aspectmode='data'
524
+ ),
525
+ title="Input Cameras and Smooth Path with Orientations",
526
+ margin=dict(l=0, r=0, b=0, t=30)
527
+ )
528
+
529
+ fig.show()
530
+
531
+
532
+ def solve_tsp_nearest_neighbor(points: np.ndarray):
533
+ """
534
+ Solve TSP approximately using nearest neighbor heuristic.
535
+
536
+ Args:
537
+ points (np.ndarray): (N, 3) array of points.
538
+
539
+ Returns:
540
+ List[int]: Approximate visiting order of the points (greedy, starting at index 0).
541
+ """
542
+ N = points.shape[0]
543
+ dist = distance_matrix(points, points)
544
+ visited = [0]
545
+ unvisited = set(range(1, N))
546
+
547
+ while unvisited:
548
+ last = visited[-1]
549
+ next_city = min(unvisited, key=lambda city: dist[last, city])
550
+ visited.append(next_city)
551
+ unvisited.remove(next_city)
552
+
553
+ return visited
554
+
555
+ def solve_tsp_2opt(points: np.ndarray, n_iter: int = 1000) -> np.ndarray:
556
+ """
557
+ Solve TSP approximately using Nearest Neighbor + 2-Opt.
558
+
559
+ Args:
560
+ points (np.ndarray): Array of shape (N, D) with points.
561
+ n_iter (int): Number of 2-opt iterations.
562
+
563
+ Returns:
564
+ np.ndarray: Ordered list of indices.
565
+ """
566
+ n_points = points.shape[0]
567
+
568
+ # === 1. Start with Nearest Neighbor
569
+ unvisited = list(range(n_points))
570
+ current = unvisited.pop(0)
571
+ path = [current]
572
+
573
+ while unvisited:
574
+ dists = np.linalg.norm(points[unvisited] - points[current], axis=1)
575
+ next_idx = unvisited[np.argmin(dists)]
576
+ unvisited.remove(next_idx)
577
+ path.append(next_idx)
578
+ current = next_idx
579
+
580
+ # === 2. Apply 2-Opt improvements
581
+ def path_length(path):
582
+ return sum(np.linalg.norm(points[path[i]] - points[path[i+1]]) for i in range(len(path)-1))
583
+
584
+ best_length = path_length(path)
585
+ improved = True
586
+
587
+ for _ in range(n_iter):
588
+ if not improved:
589
+ break
590
+ improved = False
591
+ for i in range(1, n_points - 2):
592
+ for j in range(i + 1, n_points):
593
+ if j - i == 1: continue
594
+ new_path = path[:i] + path[i:j][::-1] + path[j:]
595
+ new_length = path_length(new_path)
596
+ if new_length < best_length:
597
+ path = new_path
598
+ best_length = new_length
599
+ improved = True
600
+ break
601
+ if improved:
602
+ break
603
+
604
+ return np.array(path)
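A standalone comparison of the two heuristics above on random points, assuming both functions are in scope; since the 2-opt variant starts from a greedy tour and only accepts improvements, it should not come back with a longer tour.

```python
import numpy as np

rng = np.random.default_rng(0)
pts = rng.random((20, 3))

def open_tour_length(order):
    return sum(np.linalg.norm(pts[order[i]] - pts[order[i + 1]]) for i in range(len(order) - 1))

nn_order = solve_tsp_nearest_neighbor(pts)
two_opt_order = solve_tsp_2opt(pts, n_iter=1000)
print(f"nearest neighbour: {open_tour_length(nn_order):.3f}")
print(f"2-opt:             {open_tour_length(list(two_opt_order)):.3f}")
```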
605
+
606
+ def generate_fully_smooth_cameras_with_tsp(existing_cameras: List[Camera],
607
+ n_selected: int = 30,
608
+ n_points_per_segment: int = 20,
609
+ d: float = 2.0,
610
+ closed: bool = False) -> List[Camera]:
611
+ """
612
+ Generate a fully smooth camera path using TSP ordering, global Catmull-Rom spline for positions, and global SLERP for orientations.
613
+
614
+ Args:
615
+ existing_cameras (List[Camera]): List of input cameras.
616
+ n_selected (int): Number of cameras to select after ordering.
617
+ n_points_per_segment (int): Number of interpolated points per spline segment.
618
+ d (float): Distance ahead for estimating center of view (currently unused).
619
+ closed (bool): Whether to close the path.
620
+
621
+ Returns:
622
+ List[Camera]: List of smoothly moving Camera objects.
623
+ """
624
+ positions = np.array([cam.T for cam in existing_cameras])
625
+
626
+ # 1. Solve approximate TSP
627
+ order = solve_tsp_nearest_neighbor(positions)
628
+ ordered_cameras = [existing_cameras[i] for i in order]
629
+ ordered_positions = positions[order]
630
+
631
+ # 2. Subsample uniformly
632
+ idx = np.linspace(0, len(ordered_positions) - 1, n_selected).astype(int)
633
+ sampled_positions = ordered_positions[idx]
634
+ sampled_cameras = [ordered_cameras[i] for i in idx]
635
+
636
+ # 3. Prepare for Catmull-Rom
637
+ if closed:
638
+ sampled_positions = np.vstack([sampled_positions[-1], sampled_positions, sampled_positions[0], sampled_positions[1]])
639
+ else:
640
+ sampled_positions = np.vstack([sampled_positions[0], sampled_positions, sampled_positions[-1], sampled_positions[-1]])
641
+
642
+ # 4. Generate smooth path positions
643
+ path_positions = []
644
+ for i in range(1, len(sampled_positions) - 2):
645
+ segment = catmull_rom_spline(sampled_positions[i-1], sampled_positions[i], sampled_positions[i+1], sampled_positions[i+2], n_points_per_segment)
646
+ path_positions.append(segment)
647
+ path_positions = np.concatenate(path_positions, axis=0)
648
+
649
+ # 5. Global SLERP for rotations
650
+ rotations = R.from_matrix([cam.R for cam in sampled_cameras])
651
+ key_times = np.linspace(0, 1, len(rotations))
652
+ slerp = Slerp(key_times, rotations)
653
+
654
+ query_times = np.linspace(0, 1, len(path_positions))
655
+ interpolated_rotations = slerp(query_times)
656
+
657
+ # 6. Generate Camera objects
658
+ reference_cam = existing_cameras[0]
659
+ smooth_cameras = []
660
+
661
+ for i, pos in enumerate(path_positions):
662
+ R_interp = interpolated_rotations[i].as_matrix()
663
+
664
+ smooth_cameras.append(Camera(
665
+ R=R_interp,
666
+ T=pos,
667
+ FoVx=reference_cam.FoVx,
668
+ FoVy=reference_cam.FoVy,
669
+ resolution=(reference_cam.image_width, reference_cam.image_height),
670
+ colmap_id=-1,
671
+ depth_params=None,
672
+ image=Image.fromarray(np.zeros((reference_cam.image_height, reference_cam.image_width, 3), dtype=np.uint8)),
673
+ invdepthmap=None,
674
+ image_name=f"fully_smooth_path_i={i}",
675
+ uid=i
676
+ ))
677
+
678
+ return smooth_cameras
679
+
680
+ from typing import List
681
+ import numpy as np
682
+ from sklearn.mixture import GaussianMixture
683
+ from scipy.spatial.transform import Rotation as R, Slerp
684
+ from PIL import Image
685
+
686
+ def generate_clustered_smooth_cameras_with_tsp(existing_cameras: List[Camera],
687
+ n_selected: int = 30,
688
+ n_points_per_segment: int = 20,
689
+ d: float = 2.0,
690
+ n_clusters: int = 5,
691
+ closed: bool = False) -> List[Camera]:
692
+ """
693
+ Generate a fully smooth camera path by clustering camera positions with a GMM, ordering the cluster centers with TSP, and ordering the cameras inside each cluster with TSP.
694
+ Positions are normalized before clustering and denormalized before generating final cameras.
695
+
696
+ Args:
697
+ existing_cameras (List[Camera]): List of input cameras.
698
+ n_selected (int): Number of cameras to select after ordering.
699
+ n_points_per_segment (int): Number of interpolated points per spline segment.
700
+ d (float): Distance ahead for estimating center of view (currently unused).
701
+ n_clusters (int): Number of GMM clusters.
702
+ closed (bool): Whether to close the path.
703
+
704
+ Returns:
705
+ List[Camera]: Smooth path of Camera objects.
706
+ """
707
+ # Extract positions and rotations
708
+ positions = np.array([cam.T for cam in existing_cameras])
709
+ rotations = np.array([R.from_matrix(cam.R).as_quat() for cam in existing_cameras])
710
+
711
+ # === Normalize positions
712
+ mean_pos = np.mean(positions, axis=0)
713
+ scale_pos = np.std(positions, axis=0)
714
+ scale_pos[scale_pos == 0] = 1.0 # avoid division by zero
715
+
716
+ positions_normalized = (positions - mean_pos) / scale_pos
717
+
718
+ # === Features for clustering (only positions, not rotations)
719
+ features = positions_normalized
720
+
721
+ # === 1. GMM clustering
722
+ gmm = GaussianMixture(n_components=n_clusters, covariance_type='full', random_state=42)
723
+ cluster_labels = gmm.fit_predict(features)
724
+
725
+ clusters = {}
726
+ cluster_centers = []
727
+
728
+ for cluster_id in range(n_clusters):
729
+ cluster_indices = np.where(cluster_labels == cluster_id)[0]
730
+ if len(cluster_indices) == 0:
731
+ continue
732
+ clusters[cluster_id] = cluster_indices
733
+ cluster_center = np.mean(features[cluster_indices], axis=0)
734
+ cluster_centers.append(cluster_center)
735
+
736
+ cluster_centers = np.stack(cluster_centers)
737
+
738
+ # === 2. Remap cluster centers to nearest existing cameras (disabled below)
739
+ if False:
740
+ mapped_centers = []
741
+ for center in cluster_centers:
742
+ dists = np.linalg.norm(features - center, axis=1)
743
+ nearest_idx = np.argmin(dists)
744
+ mapped_centers.append(features[nearest_idx])
745
+ mapped_centers = np.stack(mapped_centers)
746
+ cluster_centers = mapped_centers
747
+ # === 3. Solve TSP between mapped cluster centers
748
+ cluster_order = solve_tsp_2opt(cluster_centers)
749
+
750
+ # === 4. For each cluster, solve TSP inside cluster
751
+ final_indices = []
752
+ for cluster_id in cluster_order:
753
+ cluster_indices = clusters[cluster_id]
754
+ cluster_positions = features[cluster_indices]
755
+
756
+ if len(cluster_positions) == 1:
757
+ final_indices.append(cluster_indices[0])
758
+ continue
759
+
760
+ local_order = solve_tsp_nearest_neighbor(cluster_positions)
761
+ ordered_cluster_indices = cluster_indices[local_order]
762
+ final_indices.extend(ordered_cluster_indices)
763
+
764
+ ordered_cameras = [existing_cameras[i] for i in final_indices]
765
+ ordered_positions = positions_normalized[final_indices]
766
+
767
+ # === 5. Subsample uniformly
768
+ idx = np.linspace(0, len(ordered_positions) - 1, n_selected).astype(int)
769
+ sampled_positions = ordered_positions[idx]
770
+ sampled_cameras = [ordered_cameras[i] for i in idx]
771
+
772
+ # === 6. Prepare for Catmull-Rom spline
773
+ if closed:
774
+ sampled_positions = np.vstack([sampled_positions[-1], sampled_positions, sampled_positions[0], sampled_positions[1]])
775
+ else:
776
+ sampled_positions = np.vstack([sampled_positions[0], sampled_positions, sampled_positions[-1], sampled_positions[-1]])
777
+
778
+ # === 7. Smooth path positions
779
+ path_positions = []
780
+ for i in range(1, len(sampled_positions) - 2):
781
+ segment = catmull_rom_spline(sampled_positions[i-1], sampled_positions[i], sampled_positions[i+1], sampled_positions[i+2], n_points_per_segment)
782
+ path_positions.append(segment)
783
+ path_positions = np.concatenate(path_positions, axis=0)
784
+
785
+ # === 8. Denormalize
786
+ path_positions = path_positions * scale_pos + mean_pos
787
+
788
+ # === 9. SLERP for rotations
789
+ rotations = R.from_matrix([cam.R for cam in sampled_cameras])
790
+ key_times = np.linspace(0, 1, len(rotations))
791
+ slerp = Slerp(key_times, rotations)
792
+
793
+ query_times = np.linspace(0, 1, len(path_positions))
794
+ interpolated_rotations = slerp(query_times)
795
+
796
+ # === 10. Generate Camera objects
797
+ reference_cam = existing_cameras[0]
798
+ smooth_cameras = []
799
+
800
+ for i, pos in enumerate(path_positions):
801
+ R_interp = interpolated_rotations[i].as_matrix()
802
+
803
+ smooth_cameras.append(Camera(
804
+ R=R_interp,
805
+ T=pos,
806
+ FoVx=reference_cam.FoVx,
807
+ FoVy=reference_cam.FoVy,
808
+ resolution=(reference_cam.image_width, reference_cam.image_height),
809
+ colmap_id=-1,
810
+ depth_params=None,
811
+ image=Image.fromarray(np.zeros((reference_cam.image_height, reference_cam.image_width, 3), dtype=np.uint8)),
812
+ invdepthmap=None,
813
+ image_name=f"clustered_smooth_path_i={i}",
814
+ uid=i
815
+ ))
816
+
817
+ return smooth_cameras
818
+
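A sketch of a typical call, with assumed inputs from the training pipeline; all hyperparameters below are illustrative defaults for a capture whose cameras form several distinct groups.

```python
# Sketch (assumed inputs): clustered TSP ordering followed by spline/SLERP smoothing.
def build_clustered_path(train_cameras):
    return generate_clustered_smooth_cameras_with_tsp(
        train_cameras,
        n_selected=30,            # key poses kept for the Catmull-Rom spline
        n_points_per_segment=20,  # interpolation density between key poses
        n_clusters=5,             # GMM components over normalized positions
        closed=False,
    )
```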
819
+
820
+ # def generate_clustered_path(existing_cameras: List[Camera],
821
+ # n_points_per_segment: int = 20,
822
+ # d: float = 2.0,
823
+ # n_clusters: int = 5,
824
+ # closed: bool = False) -> List[Camera]:
825
+ # """
826
+ # Generate a smooth camera path using GMM clustering and TSP on cluster centers.
827
+
828
+ # Args:
829
+ # existing_cameras (List[Camera]): List of input cameras.
830
+ # n_points_per_segment (int): Number of interpolated points per spline segment.
831
+ # d (float): Distance ahead for estimating center of view.
832
+ # n_clusters (int): Number of GMM clusters (zones).
833
+ # closed (bool): Whether to close the path.
834
+
835
+ # Returns:
836
+ # List[Camera]: Smooth path of Camera objects.
837
+ # """
838
+ # # Extract positions and rotations
839
+ # positions = np.array([cam.T for cam in existing_cameras])
840
+
841
+ # # === Normalize positions
842
+ # mean_pos = np.mean(positions, axis=0)
843
+ # scale_pos = np.std(positions, axis=0)
844
+ # scale_pos[scale_pos == 0] = 1.0
845
+
846
+ # positions_normalized = (positions - mean_pos) / scale_pos
847
+
848
+ # # === 1. GMM clustering (only positions)
849
+ # gmm = GaussianMixture(n_components=n_clusters, covariance_type='full', random_state=42)
850
+ # cluster_labels = gmm.fit_predict(positions_normalized)
851
+
852
+ # cluster_centers = []
853
+ # for cluster_id in range(n_clusters):
854
+ # cluster_indices = np.where(cluster_labels == cluster_id)[0]
855
+ # if len(cluster_indices) == 0:
856
+ # continue
857
+ # cluster_center = np.mean(positions_normalized[cluster_indices], axis=0)
858
+ # cluster_centers.append(cluster_center)
859
+
860
+ # cluster_centers = np.stack(cluster_centers)
861
+
862
+ # # === 2. Solve TSP between cluster centers
863
+ # cluster_order = solve_tsp_2opt(cluster_centers)
864
+
865
+ # # === 3. Reorder cluster centers
866
+ # ordered_centers = cluster_centers[cluster_order]
867
+
868
+ # # === 4. Prepare Catmull-Rom spline
869
+ # if closed:
870
+ # ordered_centers = np.vstack([ordered_centers[-1], ordered_centers, ordered_centers[0], ordered_centers[1]])
871
+ # else:
872
+ # ordered_centers = np.vstack([ordered_centers[0], ordered_centers, ordered_centers[-1], ordered_centers[-1]])
873
+
874
+ # # === 5. Generate smooth path positions
875
+ # path_positions = []
876
+ # for i in range(1, len(ordered_centers) - 2):
877
+ # segment = catmull_rom_spline(ordered_centers[i-1], ordered_centers[i], ordered_centers[i+1], ordered_centers[i+2], n_points_per_segment)
878
+ # path_positions.append(segment)
879
+ # path_positions = np.concatenate(path_positions, axis=0)
880
+
881
+ # # === 6. Denormalize back
882
+ # path_positions = path_positions * scale_pos + mean_pos
883
+
884
+ # # === 7. Generate dummy rotations (constant forward facing)
885
+ # reference_cam = existing_cameras[0]
886
+ # default_rotation = R.from_matrix(reference_cam.R)
887
+
888
+ # # For simplicity, fixed rotation for all
889
+ # smooth_cameras = []
890
+
891
+ # for i, pos in enumerate(path_positions):
892
+ # R_interp = default_rotation.as_matrix()
893
+
894
+ # smooth_cameras.append(Camera(
895
+ # R=R_interp,
896
+ # T=pos,
897
+ # FoVx=reference_cam.FoVx,
898
+ # FoVy=reference_cam.FoVy,
899
+ # resolution=(reference_cam.image_width, reference_cam.image_height),
900
+ # colmap_id=-1,
901
+ # depth_params=None,
902
+ # image=Image.fromarray(np.zeros((reference_cam.image_height, reference_cam.image_width, 3), dtype=np.uint8)),
903
+ # invdepthmap=None,
904
+ # image_name=f"cluster_path_i={i}",
905
+ # uid=i
906
+ # ))
907
+
908
+ # return smooth_cameras
909
+
910
+ from typing import List
911
+ import numpy as np
912
+ from sklearn.cluster import KMeans
913
+ from scipy.spatial.transform import Rotation as R, Slerp
914
+ from PIL import Image
915
+
916
+ def generate_clustered_path(existing_cameras: List[Camera],
917
+ n_points_per_segment: int = 20,
918
+ d: float = 2.0,
919
+ n_clusters: int = 5,
920
+ closed: bool = False) -> List[Camera]:
921
+ """
922
+ Generate a smooth camera path using K-Means clustering and TSP on cluster centers.
923
+
924
+ Args:
925
+ existing_cameras (List[Camera]): List of input cameras.
926
+ n_points_per_segment (int): Number of interpolated points per spline segment.
927
+ d (float): Distance ahead for estimating center of view (currently unused).
928
+ n_clusters (int): Number of KMeans clusters (zones).
929
+ closed (bool): Whether to close the path.
930
+
931
+ Returns:
932
+ List[Camera]: Smooth path of Camera objects.
933
+ """
934
+ # Extract positions
935
+ positions = np.array([cam.T for cam in existing_cameras])
936
+
937
+ # === Normalize positions
938
+ mean_pos = np.mean(positions, axis=0)
939
+ scale_pos = np.std(positions, axis=0)
940
+ scale_pos[scale_pos == 0] = 1.0
941
+
942
+ positions_normalized = (positions - mean_pos) / scale_pos
943
+
944
+ # === 1. K-Means clustering (only positions)
945
+ kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init='auto')
946
+ cluster_labels = kmeans.fit_predict(positions_normalized)
947
+
948
+ cluster_centers = []
949
+ for cluster_id in range(n_clusters):
950
+ cluster_indices = np.where(cluster_labels == cluster_id)[0]
951
+ if len(cluster_indices) == 0:
952
+ continue
953
+ cluster_center = np.mean(positions_normalized[cluster_indices], axis=0)
954
+ cluster_centers.append(cluster_center)
955
+
956
+ cluster_centers = np.stack(cluster_centers)
957
+
958
+ # === 2. Solve TSP between cluster centers
959
+ cluster_order = solve_tsp_2opt(cluster_centers)
960
+
961
+ # === 3. Reorder cluster centers
962
+ ordered_centers = cluster_centers[cluster_order]
963
+
964
+ # === 4. Prepare Catmull-Rom spline
965
+ if closed:
966
+ ordered_centers = np.vstack([ordered_centers[-1], ordered_centers, ordered_centers[0], ordered_centers[1]])
967
+ else:
968
+ ordered_centers = np.vstack([ordered_centers[0], ordered_centers, ordered_centers[-1], ordered_centers[-1]])
969
+
970
+ # === 5. Generate smooth path positions
971
+ path_positions = []
972
+ for i in range(1, len(ordered_centers) - 2):
973
+ segment = catmull_rom_spline(ordered_centers[i-1], ordered_centers[i], ordered_centers[i+1], ordered_centers[i+2], n_points_per_segment)
974
+ path_positions.append(segment)
975
+ path_positions = np.concatenate(path_positions, axis=0)
976
+
977
+ # === 6. Denormalize back
978
+ path_positions = path_positions * scale_pos + mean_pos
979
+
980
+ # === 7. Generate dummy rotations (constant forward facing)
981
+ reference_cam = existing_cameras[0]
982
+ default_rotation = R.from_matrix(reference_cam.R)
983
+
984
+ # For simplicity, fixed rotation for all
985
+ smooth_cameras = []
986
+
987
+ for i, pos in enumerate(path_positions):
988
+ R_interp = default_rotation.as_matrix()
989
+
990
+ smooth_cameras.append(Camera(
991
+ R=R_interp,
992
+ T=pos,
993
+ FoVx=reference_cam.FoVx,
994
+ FoVy=reference_cam.FoVy,
995
+ resolution=(reference_cam.image_width, reference_cam.image_height),
996
+ colmap_id=-1,
997
+ depth_params=None,
998
+ image=Image.fromarray(np.zeros((reference_cam.image_height, reference_cam.image_width, 3), dtype=np.uint8)),
999
+ invdepthmap=None,
1000
+ image_name=f"cluster_path_i={i}",
1001
+ uid=i
1002
+ ))
1003
+
1004
+ return smooth_cameras
1005
+
1006
+
1007
+
1008
+
1009
+ def visualize_image_with_points(image, points):
1010
+ """
1011
+ Visualize an image with points overlaid on top. This is useful for visualizing correspondences.
1012
+
1013
+ Parameters:
1014
+ - image: PIL Image object
1015
+ - points: Numpy array of shape [N, 2] containing (x, y) coordinates of points
1016
+
1017
+ Returns:
1018
+ - None (displays the visualization)
1019
+ """
1020
+
1021
+ # Convert PIL image to numpy array
1022
+ img_array = np.array(image)
1023
+
1024
+ # Create a figure and axis
1025
+ fig, ax = plt.subplots(figsize=(7,7))
1026
+
1027
+ # Display the image
1028
+ ax.imshow(img_array)
1029
+
1030
+ # Scatter plot the points on top of the image
1031
+ ax.scatter(points[:, 0], points[:, 1], color='red', marker='o', s=1)
1032
+
1033
+ # Show the plot
1034
+ plt.show()
1035
+
1036
+
1037
+ def visualize_correspondences(image1, points1, image2, points2):
1038
+ """
1039
+ Visualize two images concatenated horizontally with key points and correspondences.
1040
+
1041
+ Parameters:
1042
+ - image1: PIL Image object (left image)
1043
+ - points1: Numpy array of shape [N, 2] containing (x, y) coordinates of key points for image1
1044
+ - image2: PIL Image object (right image)
1045
+ - points2: Numpy array of shape [N, 2] containing (x, y) coordinates of key points for image2
1046
+
1047
+ Returns:
1048
+ - None (displays the visualization)
1049
+ """
1050
+
1051
+ # Concatenate images horizontally
1052
+ concatenated_image = np.concatenate((np.array(image1), np.array(image2)), axis=1)
1053
+
1054
+ # Create a figure and axis
1055
+ fig, ax = plt.subplots(figsize=(10,10))
1056
+
1057
+ # Display the concatenated image
1058
+ ax.imshow(concatenated_image)
1059
+
1060
+ # Plot key points on the left image
1061
+ ax.scatter(points1[:, 0], points1[:, 1], color='red', marker='o', s=10)
1062
+
1063
+ # Plot key points on the right image
1064
+ ax.scatter(points2[:, 0] + image1.width, points2[:, 1], color='blue', marker='o', s=10)
1065
+
1066
+ # Draw lines connecting corresponding key points
1067
+ for i in range(len(points1)):
1068
+ ax.plot([points1[i, 0], points2[i, 0] + image1.width], [points1[i, 1], points2[i, 1]])#, color='green')
1069
+
1070
+ # Show the plot
1071
+ plt.show()
1072
+
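A self-contained toy example of the two plotting helpers above using synthetic data, assuming it runs in the same module; with a real matcher (e.g. RoMa, added below), `pts1`/`pts2` would be the sampled pixel correspondences instead.

```python
import numpy as np
from PIL import Image

img1 = Image.fromarray(np.full((240, 320, 3), 80, dtype=np.uint8))
img2 = Image.fromarray(np.full((240, 320, 3), 160, dtype=np.uint8))
pts1 = np.random.rand(25, 2) * np.array([320, 240])  # (x, y) in pixel coordinates
pts2 = pts1 + np.random.randn(25, 2) * 3.0           # jittered "matches"
visualize_image_with_points(img1, pts1)
visualize_correspondences(img1, pts1, img2, pts2)
```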
submodules/RoMa/.gitignore ADDED
@@ -0,0 +1,11 @@
1
+ *.egg-info*
2
+ *.vscode*
3
+ *__pycache__*
4
+ vis*
5
+ workspace*
6
+ .venv
7
+ .DS_Store
8
+ jobs/*
9
+ *ignore_me*
10
+ *.pth
11
+ wandb*
submodules/RoMa/LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2023 Johan Edstedt
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
submodules/RoMa/README.md ADDED
@@ -0,0 +1,123 @@
1
+ #
2
+ <p align="center">
3
+ <h1 align="center"> <ins>RoMa</ins> 🏛️:<br> Robust Dense Feature Matching <br> ⭐CVPR 2024⭐</h1>
4
+ <p align="center">
5
+ <a href="https://scholar.google.com/citations?user=Ul-vMR0AAAAJ">Johan Edstedt</a>
6
+ ·
7
+ <a href="https://scholar.google.com/citations?user=HS2WuHkAAAAJ">Qiyu Sun</a>
8
+ ·
9
+ <a href="https://scholar.google.com/citations?user=FUE3Wd0AAAAJ">Georg Bökman</a>
10
+ ·
11
+ <a href="https://scholar.google.com/citations?user=6WRQpCQAAAAJ">Mårten Wadenbäck</a>
12
+ ·
13
+ <a href="https://scholar.google.com/citations?user=lkWfR08AAAAJ">Michael Felsberg</a>
14
+ </p>
15
+ <h2 align="center"><p>
16
+ <a href="https://arxiv.org/abs/2305.15404" align="center">Paper</a> |
17
+ <a href="https://parskatt.github.io/RoMa" align="center">Project Page</a>
18
+ </p></h2>
19
+ <div align="center"></div>
20
+ </p>
21
+ <br/>
22
+ <p align="center">
23
+ <img src="https://github.com/Parskatt/RoMa/assets/22053118/15d8fea7-aa6d-479f-8a93-350d950d006b" alt="example" width=80%>
24
+ <br>
25
+ <em>RoMa is the robust dense feature matcher capable of estimating pixel-dense warps and reliable certainties for almost any image pair.</em>
26
+ </p>
27
+
28
+ ## Setup/Install
29
+ In your python environment (tested on Linux python 3.10), run:
30
+ ```bash
31
+ pip install -e .
32
+ ```
33
+ ## Demo / How to Use
34
+ We provide two demos in the [demos folder](demo).
35
+ Here's the gist of it:
36
+ ```python
37
+ from romatch import roma_outdoor
38
+ roma_model = roma_outdoor(device=device)
39
+ # Match
40
+ warp, certainty = roma_model.match(imA_path, imB_path, device=device)
41
+ # Sample matches for estimation
42
+ matches, certainty = roma_model.sample(warp, certainty)
43
+ # Convert to pixel coordinates (RoMa produces matches in [-1,1]x[-1,1])
44
+ kptsA, kptsB = roma_model.to_pixel_coordinates(matches, H_A, W_A, H_B, W_B)
45
+ # Find a fundamental matrix (or anything else of interest)
46
+ F, mask = cv2.findFundamentalMat(
47
+ kptsA.cpu().numpy(), kptsB.cpu().numpy(), ransacReprojThreshold=0.2, method=cv2.USAC_MAGSAC, confidence=0.999999, maxIters=10000
48
+ )
49
+ ```
50
+
51
+ **New**: You can also match arbitrary keypoints with RoMa. See [match_keypoints](romatch/models/matcher.py) in RegressionMatcher.
52
+
53
+ ## Settings
54
+
55
+ ### Resolution
56
+ By default RoMa uses an initial resolution of (560,560) which is then upsampled to (864,864).
57
+ You can change this at construction (see roma_outdoor kwargs).
58
+ You can also change this later, by changing the roma_model.w_resized, roma_model.h_resized, and roma_model.upsample_res.
59
+
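For concreteness, both options read roughly as follows; the kwargs mirror those used in the demo scripts and the attribute names are as stated above, but defaults may differ per release, so treat this as a sketch.

```python
from romatch import roma_outdoor

# At construction (values are just examples):
roma_model = roma_outdoor(device="cuda", coarse_res=560, upsample_res=(864, 864))

# Or after construction:
roma_model.w_resized, roma_model.h_resized = 560, 560
roma_model.upsample_res = (864, 864)
```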
60
+ ### Sampling
61
+ roma_model.sample_thresh controls the thresholding used when sampling matches for estimation. In certain cases a lower or higher threshold may improve results.
62
+
63
+
64
+ ## Reproducing Results
65
+ The experiments in the paper are provided in the [experiments folder](experiments).
66
+
67
+ ### Training
68
+ 1. First follow the instructions provided here: https://github.com/Parskatt/DKM for downloading and preprocessing datasets.
69
+ 2. Run the relevant experiment, e.g.,
70
+ ```bash
71
+ torchrun --nproc_per_node=4 --nnodes=1 --rdzv_backend=c10d experiments/roma_outdoor.py
72
+ ```
73
+ ### Testing
74
+ ```bash
75
+ python experiments/roma_outdoor.py --only_test --benchmark mega-1500
76
+ ```
77
+ ## License
78
+ All our code except DINOv2 is MIT license.
79
+ DINOv2 has an Apache 2 license [DINOv2](https://github.com/facebookresearch/dinov2/blob/main/LICENSE).
80
+
81
+ ## Acknowledgement
82
+ Our codebase builds on the code in [DKM](https://github.com/Parskatt/DKM).
83
+
84
+ ## Tiny RoMa
85
+ If you find that RoMa is too heavy, you might want to try Tiny RoMa which is built on top of XFeat.
86
+ ```python
87
+ from romatch import tiny_roma_v1_outdoor
88
+ tiny_roma_model = tiny_roma_v1_outdoor(device=device)
89
+ ```
90
+ Mega1500:
91
+ | | AUC@5 | AUC@10 | AUC@20 |
92
+ |----------|----------|----------|----------|
93
+ | XFeat | 46.4 | 58.9 | 69.2 |
94
+ | XFeat* | 51.9 | 67.2 | 78.9 |
95
+ | Tiny RoMa v1 | 56.4 | 69.5 | 79.5 |
96
+ | RoMa | - | - | - |
97
+
98
+ Mega-8-Scenes (See DKM):
99
+ | | AUC@5 | AUC@10 | AUC@20 |
100
+ |----------|----------|----------|----------|
101
+ | XFeat | - | - | - |
102
+ | XFeat* | 50.1 | 64.4 | 75.2 |
103
+ | Tiny RoMa v1 | 57.7 | 70.5 | 79.6 |
104
+ | RoMa | - | - | - |
105
+
106
+ IMC22 :'):
107
+ | | mAA@10 |
108
+ |----------|----------|
109
+ | XFeat | 42.1 |
110
+ | XFeat* | - |
111
+ | Tiny RoMa v1 | 42.2 |
112
+ | RoMa | - |
113
+
114
+ ## BibTeX
115
+ If you find our models useful, please consider citing our paper!
116
+ ```
117
+ @article{edstedt2024roma,
118
+ title={{RoMa: Robust Dense Feature Matching}},
119
+ author={Edstedt, Johan and Sun, Qiyu and Bökman, Georg and Wadenbäck, Mårten and Felsberg, Michael},
120
+ journal={IEEE Conference on Computer Vision and Pattern Recognition},
121
+ year={2024}
122
+ }
123
+ ```
submodules/RoMa/data/.gitignore ADDED
@@ -0,0 +1,2 @@
1
+ *
2
+ !.gitignore
submodules/RoMa/demo/demo_3D_effect.py ADDED
@@ -0,0 +1,47 @@
1
+ from PIL import Image
2
+ import torch
3
+ import torch.nn.functional as F
4
+ import numpy as np
5
+ from romatch.utils.utils import tensor_to_pil
6
+
7
+ from romatch import roma_outdoor
8
+
9
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
10
+ if torch.backends.mps.is_available():
11
+ device = torch.device('mps')
12
+
13
+ if __name__ == "__main__":
14
+ from argparse import ArgumentParser
15
+ parser = ArgumentParser()
16
+ parser.add_argument("--im_A_path", default="assets/toronto_A.jpg", type=str)
17
+ parser.add_argument("--im_B_path", default="assets/toronto_B.jpg", type=str)
18
+ parser.add_argument("--save_path", default="demo/gif/roma_warp_toronto", type=str)
19
+
20
+ args, _ = parser.parse_known_args()
21
+ im1_path = args.im_A_path
22
+ im2_path = args.im_B_path
23
+ save_path = args.save_path
24
+
25
+ # Create model
26
+ roma_model = roma_outdoor(device=device, coarse_res=560, upsample_res=(864, 1152))
27
+ roma_model.symmetric = False
28
+
29
+ H, W = roma_model.get_output_resolution()
30
+
31
+ im1 = Image.open(im1_path).resize((W, H))
32
+ im2 = Image.open(im2_path).resize((W, H))
33
+
34
+ # Match
35
+ warp, certainty = roma_model.match(im1_path, im2_path, device=device)
36
+ # Sampling not needed, but can be done with model.sample(warp, certainty)
37
+ x1 = (torch.tensor(np.array(im1)) / 255).to(device).permute(2, 0, 1)
38
+ x2 = (torch.tensor(np.array(im2)) / 255).to(device).permute(2, 0, 1)
39
+
40
+ coords_A, coords_B = warp[...,:2], warp[...,2:]
41
+ for i, x in enumerate(np.linspace(0,2*np.pi,200)):
42
+ t = (1 + np.cos(x))/2
43
+ interp_warp = (1-t)*coords_A + t*coords_B
44
+ im2_transfer_rgb = F.grid_sample(
45
+ x2[None], interp_warp[None], mode="bilinear", align_corners=False
46
+ )[0]
47
+ tensor_to_pil(im2_transfer_rgb, unnormalize=False).save(f"{save_path}_{i:03d}.jpg")
submodules/RoMa/demo/demo_fundamental.py ADDED
@@ -0,0 +1,34 @@
1
+ from PIL import Image
2
+ import torch
3
+ import cv2
4
+ from romatch import roma_outdoor
5
+
6
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
7
+ if torch.backends.mps.is_available():
8
+ device = torch.device('mps')
9
+
10
+ if __name__ == "__main__":
11
+ from argparse import ArgumentParser
12
+ parser = ArgumentParser()
13
+ parser.add_argument("--im_A_path", default="assets/sacre_coeur_A.jpg", type=str)
14
+ parser.add_argument("--im_B_path", default="assets/sacre_coeur_B.jpg", type=str)
15
+
16
+ args, _ = parser.parse_known_args()
17
+ im1_path = args.im_A_path
18
+ im2_path = args.im_B_path
19
+
20
+ # Create model
21
+ roma_model = roma_outdoor(device=device)
22
+
23
+
24
+ W_A, H_A = Image.open(im1_path).size
25
+ W_B, H_B = Image.open(im2_path).size
26
+
27
+ # Match
28
+ warp, certainty = roma_model.match(im1_path, im2_path, device=device)
29
+ # Sample matches for estimation
30
+ matches, certainty = roma_model.sample(warp, certainty)
31
+ kpts1, kpts2 = roma_model.to_pixel_coordinates(matches, H_A, W_A, H_B, W_B)
32
+ F, mask = cv2.findFundamentalMat(
33
+ kpts1.cpu().numpy(), kpts2.cpu().numpy(), ransacReprojThreshold=0.2, method=cv2.USAC_MAGSAC, confidence=0.999999, maxIters=10000
34
+ )
submodules/RoMa/demo/demo_match.py ADDED
@@ -0,0 +1,50 @@
1
+ import os
2
+ os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'
3
+ import torch
4
+ from PIL import Image
5
+ import torch.nn.functional as F
6
+ import numpy as np
7
+ from romatch.utils.utils import tensor_to_pil
8
+
9
+ from romatch import roma_outdoor
10
+
11
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
12
+ if torch.backends.mps.is_available():
13
+ device = torch.device('mps')
14
+
15
+ if __name__ == "__main__":
16
+ from argparse import ArgumentParser
17
+ parser = ArgumentParser()
18
+ parser.add_argument("--im_A_path", default="assets/toronto_A.jpg", type=str)
19
+ parser.add_argument("--im_B_path", default="assets/toronto_B.jpg", type=str)
20
+ parser.add_argument("--save_path", default="demo/roma_warp_toronto.jpg", type=str)
21
+
22
+ args, _ = parser.parse_known_args()
23
+ im1_path = args.im_A_path
24
+ im2_path = args.im_B_path
25
+ save_path = args.save_path
26
+
27
+ # Create model
28
+ roma_model = roma_outdoor(device=device, coarse_res=560, upsample_res=(864, 1152))
29
+
30
+ H, W = roma_model.get_output_resolution()
31
+
32
+ im1 = Image.open(im1_path).resize((W, H))
33
+ im2 = Image.open(im2_path).resize((W, H))
34
+
35
+ # Match
36
+ warp, certainty = roma_model.match(im1_path, im2_path, device=device)
37
+ # Sampling not needed, but can be done with model.sample(warp, certainty)
38
+ x1 = (torch.tensor(np.array(im1)) / 255).to(device).permute(2, 0, 1)
39
+ x2 = (torch.tensor(np.array(im2)) / 255).to(device).permute(2, 0, 1)
40
+
41
+ im2_transfer_rgb = F.grid_sample(
42
+ x2[None], warp[:,:W, 2:][None], mode="bilinear", align_corners=False
43
+ )[0]
44
+ im1_transfer_rgb = F.grid_sample(
45
+ x1[None], warp[:, W:, :2][None], mode="bilinear", align_corners=False
46
+ )[0]
47
+ warp_im = torch.cat((im2_transfer_rgb,im1_transfer_rgb),dim=2)
48
+ white_im = torch.ones((H,2*W),device=device)
49
+ vis_im = certainty * warp_im + (1 - certainty) * white_im
50
+ tensor_to_pil(vis_im, unnormalize=False).save(save_path)
submodules/RoMa/demo/demo_match_opencv_sift.py ADDED
@@ -0,0 +1,43 @@
1
+ from PIL import Image
2
+ import numpy as np
3
+
4
+ import numpy as np
5
+ import cv2 as cv
6
+ import matplotlib.pyplot as plt
7
+
8
+
9
+
10
+ if __name__ == "__main__":
11
+ from argparse import ArgumentParser
12
+ parser = ArgumentParser()
13
+ parser.add_argument("--im_A_path", default="assets/toronto_A.jpg", type=str)
14
+ parser.add_argument("--im_B_path", default="assets/toronto_B.jpg", type=str)
15
+ parser.add_argument("--save_path", default="demo/roma_warp_toronto.jpg", type=str)
16
+
17
+ args, _ = parser.parse_known_args()
18
+ im1_path = args.im_A_path
19
+ im2_path = args.im_B_path
20
+ save_path = args.save_path
21
+
22
+ img1 = cv.imread(im1_path,cv.IMREAD_GRAYSCALE) # queryImage
23
+ img2 = cv.imread(im2_path,cv.IMREAD_GRAYSCALE) # trainImage
24
+ # Initiate SIFT detector
25
+ sift = cv.SIFT_create()
26
+ # find the keypoints and descriptors with SIFT
27
+ kp1, des1 = sift.detectAndCompute(img1,None)
28
+ kp2, des2 = sift.detectAndCompute(img2,None)
29
+ # BFMatcher with default params
30
+ bf = cv.BFMatcher()
31
+ matches = bf.knnMatch(des1,des2,k=2)
32
+ # Apply ratio test
33
+ good = []
34
+ for m,n in matches:
35
+ if m.distance < 0.75*n.distance:
36
+ good.append([m])
37
+ # cv.drawMatchesKnn expects list of lists as matches.
38
+ draw_params = dict(matchColor = (255,0,0), # draw matches in red color
39
+ singlePointColor = None,
40
+ flags = 2)
41
+
42
+ img3 = cv.drawMatchesKnn(img1,kp1,img2,kp2,good,None,**draw_params)
43
+ Image.fromarray(img3).save("demo/sift_matches.png")
submodules/RoMa/demo/demo_match_tiny.py ADDED
@@ -0,0 +1,77 @@
1
+ import os
2
+ os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'
3
+ import torch
4
+ from PIL import Image
5
+ import torch.nn.functional as F
6
+ import numpy as np
7
+ from romatch.utils.utils import tensor_to_pil
8
+
9
+ from romatch import tiny_roma_v1_outdoor
10
+
11
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
12
+ if torch.backends.mps.is_available():
13
+ device = torch.device('mps')
14
+
15
+ if __name__ == "__main__":
16
+ from argparse import ArgumentParser
17
+ parser = ArgumentParser()
18
+ parser.add_argument("--im_A_path", default="assets/sacre_coeur_A.jpg", type=str)
19
+ parser.add_argument("--im_B_path", default="assets/sacre_coeur_B.jpg", type=str)
20
+ parser.add_argument("--save_A_path", default="demo/tiny_roma_warp_A.jpg", type=str)
21
+ parser.add_argument("--save_B_path", default="demo/tiny_roma_warp_B.jpg", type=str)
22
+
23
+ args, _ = parser.parse_known_args()
24
+ im1_path = args.im_A_path
25
+ im2_path = args.im_B_path
26
+
27
+ # Create model
28
+ roma_model = tiny_roma_v1_outdoor(device=device)
29
+
30
+ # Match
31
+ warp, certainty1 = roma_model.match(im1_path, im2_path)
32
+
33
+ h1, w1 = warp.shape[:2]
34
+
35
+ # maybe im1.size != im2.size
36
+ im1 = Image.open(im1_path).resize((w1, h1))
37
+ im2 = Image.open(im2_path)
38
+ x1 = (torch.tensor(np.array(im1)) / 255).to(device).permute(2, 0, 1)
39
+ x2 = (torch.tensor(np.array(im2)) / 255).to(device).permute(2, 0, 1)
40
+
41
+ h2, w2 = x2.shape[1:]
42
+ g1_p2x = w2 / 2 * (warp[..., 2] + 1)
43
+ g1_p2y = h2 / 2 * (warp[..., 3] + 1)
44
+ g2_p1x = torch.zeros((h2, w2), dtype=torch.float32).to(device) - 2
45
+ g2_p1y = torch.zeros((h2, w2), dtype=torch.float32).to(device) - 2
46
+
47
+ x, y = torch.meshgrid(
48
+ torch.arange(w1, device=device),
49
+ torch.arange(h1, device=device),
50
+ indexing="xy",
51
+ )
52
+ g2x = torch.round(g1_p2x[y, x]).long()
53
+ g2y = torch.round(g1_p2y[y, x]).long()
54
+ idx_x = torch.bitwise_and(0 <= g2x, g2x < w2)
55
+ idx_y = torch.bitwise_and(0 <= g2y, g2y < h2)
56
+ idx = torch.bitwise_and(idx_x, idx_y)
57
+ g2_p1x[g2y[idx], g2x[idx]] = x[idx].float() * 2 / w1 - 1
58
+ g2_p1y[g2y[idx], g2x[idx]] = y[idx].float() * 2 / h1 - 1
59
+
60
+ certainty2 = F.grid_sample(
61
+ certainty1[None][None],
62
+ torch.stack([g2_p1x, g2_p1y], dim=2)[None],
63
+ mode="bilinear",
64
+ align_corners=False,
65
+ )[0]
66
+
67
+ white_im1 = torch.ones((h1, w1), device = device)
68
+ white_im2 = torch.ones((h2, w2), device = device)
69
+
70
+ certainty1 = F.avg_pool2d(certainty1[None], kernel_size=5, stride=1, padding=2)[0]
71
+ certainty2 = F.avg_pool2d(certainty2[None], kernel_size=5, stride=1, padding=2)[0]
72
+
73
+ vis_im1 = certainty1 * x1 + (1 - certainty1) * white_im1
74
+ vis_im2 = certainty2 * x2 + (1 - certainty2) * white_im2
75
+
76
+ tensor_to_pil(vis_im1, unnormalize=False).save(args.save_A_path)
77
+ tensor_to_pil(vis_im2, unnormalize=False).save(args.save_B_path)
submodules/RoMa/demo/gif/.gitignore ADDED
@@ -0,0 +1,2 @@
1
+ *
2
+ !.gitignore
submodules/RoMa/experiments/eval_roma_outdoor.py ADDED
@@ -0,0 +1,57 @@
1
+ import json
2
+
3
+ from romatch.benchmarks import MegadepthDenseBenchmark
4
+ from romatch.benchmarks import MegaDepthPoseEstimationBenchmark, HpatchesHomogBenchmark
5
+ from romatch.benchmarks import Mega1500PoseLibBenchmark
6
+
7
+ def test_mega_8_scenes(model, name):
8
+ mega_8_scenes_benchmark = MegaDepthPoseEstimationBenchmark("data/megadepth",
9
+ scene_names=['mega_8_scenes_0019_0.1_0.3.npz',
10
+ 'mega_8_scenes_0025_0.1_0.3.npz',
11
+ 'mega_8_scenes_0021_0.1_0.3.npz',
12
+ 'mega_8_scenes_0008_0.1_0.3.npz',
13
+ 'mega_8_scenes_0032_0.1_0.3.npz',
14
+ 'mega_8_scenes_1589_0.1_0.3.npz',
15
+ 'mega_8_scenes_0063_0.1_0.3.npz',
16
+ 'mega_8_scenes_0024_0.1_0.3.npz',
17
+ 'mega_8_scenes_0019_0.3_0.5.npz',
18
+ 'mega_8_scenes_0025_0.3_0.5.npz',
19
+ 'mega_8_scenes_0021_0.3_0.5.npz',
20
+ 'mega_8_scenes_0008_0.3_0.5.npz',
21
+ 'mega_8_scenes_0032_0.3_0.5.npz',
22
+ 'mega_8_scenes_1589_0.3_0.5.npz',
23
+ 'mega_8_scenes_0063_0.3_0.5.npz',
24
+ 'mega_8_scenes_0024_0.3_0.5.npz'])
25
+ mega_8_scenes_results = mega_8_scenes_benchmark.benchmark(model, model_name=name)
26
+ print(mega_8_scenes_results)
27
+ json.dump(mega_8_scenes_results, open(f"results/mega_8_scenes_{name}.json", "w"))
28
+
29
+ def test_mega1500(model, name):
30
+ mega1500_benchmark = MegaDepthPoseEstimationBenchmark("data/megadepth")
31
+ mega1500_results = mega1500_benchmark.benchmark(model, model_name=name)
32
+ json.dump(mega1500_results, open(f"results/mega1500_{name}.json", "w"))
33
+
34
+ def test_mega1500_poselib(model, name):
35
+ mega1500_benchmark = Mega1500PoseLibBenchmark("data/megadepth")
36
+ mega1500_results = mega1500_benchmark.benchmark(model, model_name=name)
37
+ json.dump(mega1500_results, open(f"results/mega1500_{name}.json", "w"))
38
+
39
+ def test_mega_dense(model, name):
40
+ megadense_benchmark = MegadepthDenseBenchmark("data/megadepth", num_samples = 1000)
41
+ megadense_results = megadense_benchmark.benchmark(model)
42
+ json.dump(megadense_results, open(f"results/mega_dense_{name}.json", "w"))
43
+
44
+ def test_hpatches(model, name):
45
+ hpatches_benchmark = HpatchesHomogBenchmark("data/hpatches")
46
+ hpatches_results = hpatches_benchmark.benchmark(model)
47
+ json.dump(hpatches_results, open(f"results/hpatches_{name}.json", "w"))
48
+
49
+
50
+ if __name__ == "__main__":
51
+ from romatch import roma_outdoor
52
+ device = "cuda"
53
+ model = roma_outdoor(device = device, coarse_res = 672, upsample_res = 1344)
54
+ experiment_name = "roma_latest"
55
+ test_mega1500(model, experiment_name)
56
+ #test_mega1500_poselib(model, experiment_name)
57
+
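A hedged usage sketch for the helpers above (it assumes the data/megadepth and data/hpatches roots exist and that a results/ directory is writable, which the script itself does not create):

import os
from romatch import roma_outdoor

os.makedirs("results", exist_ok=True)                 # every test_* helper json.dumps into results/
model = roma_outdoor(device="cuda", coarse_res=672, upsample_res=1344)
test_hpatches(model, "roma_latest")                   # -> results/hpatches_roma_latest.json
test_mega_dense(model, "roma_latest")                 # -> results/mega_dense_roma_latest.json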
submodules/RoMa/experiments/eval_tiny_roma_v1_outdoor.py ADDED
@@ -0,0 +1,84 @@
1
+ import torch
2
+ import os
3
+ from pathlib import Path
4
+ import json
5
+ from romatch.benchmarks import ScanNetBenchmark
6
+ from romatch.benchmarks import Mega1500PoseLibBenchmark, ScanNetPoselibBenchmark
7
+ from romatch.benchmarks import MegaDepthPoseEstimationBenchmark
8
+
9
+ def test_mega_8_scenes(model, name):
10
+ mega_8_scenes_benchmark = MegaDepthPoseEstimationBenchmark("data/megadepth",
11
+ scene_names=['mega_8_scenes_0019_0.1_0.3.npz',
12
+ 'mega_8_scenes_0025_0.1_0.3.npz',
13
+ 'mega_8_scenes_0021_0.1_0.3.npz',
14
+ 'mega_8_scenes_0008_0.1_0.3.npz',
15
+ 'mega_8_scenes_0032_0.1_0.3.npz',
16
+ 'mega_8_scenes_1589_0.1_0.3.npz',
17
+ 'mega_8_scenes_0063_0.1_0.3.npz',
18
+ 'mega_8_scenes_0024_0.1_0.3.npz',
19
+ 'mega_8_scenes_0019_0.3_0.5.npz',
20
+ 'mega_8_scenes_0025_0.3_0.5.npz',
21
+ 'mega_8_scenes_0021_0.3_0.5.npz',
22
+ 'mega_8_scenes_0008_0.3_0.5.npz',
23
+ 'mega_8_scenes_0032_0.3_0.5.npz',
24
+ 'mega_8_scenes_1589_0.3_0.5.npz',
25
+ 'mega_8_scenes_0063_0.3_0.5.npz',
26
+ 'mega_8_scenes_0024_0.3_0.5.npz'])
27
+ mega_8_scenes_results = mega_8_scenes_benchmark.benchmark(model, model_name=name)
28
+ print(mega_8_scenes_results)
29
+ json.dump(mega_8_scenes_results, open(f"results/mega_8_scenes_{name}.json", "w"))
30
+
31
+ def test_mega1500(model, name):
32
+ mega1500_benchmark = MegaDepthPoseEstimationBenchmark("data/megadepth")
33
+ mega1500_results = mega1500_benchmark.benchmark(model, model_name=name)
34
+ json.dump(mega1500_results, open(f"results/mega1500_{name}.json", "w"))
35
+
36
+ def test_mega1500_poselib(model, name):
37
+ #model.exact_softmax = True
38
+ mega1500_benchmark = Mega1500PoseLibBenchmark("data/megadepth", num_ransac_iter = 1, test_every = 1)
39
+ mega1500_results = mega1500_benchmark.benchmark(model, model_name=name)
40
+ json.dump(mega1500_results, open(f"results/mega1500_poselib_{name}.json", "w"))
41
+
42
+ def test_mega_8_scenes_poselib(model, name):
43
+ mega1500_benchmark = Mega1500PoseLibBenchmark("data/megadepth", num_ransac_iter = 1, test_every = 1,
44
+ scene_names=['mega_8_scenes_0019_0.1_0.3.npz',
45
+ 'mega_8_scenes_0025_0.1_0.3.npz',
46
+ 'mega_8_scenes_0021_0.1_0.3.npz',
47
+ 'mega_8_scenes_0008_0.1_0.3.npz',
48
+ 'mega_8_scenes_0032_0.1_0.3.npz',
49
+ 'mega_8_scenes_1589_0.1_0.3.npz',
50
+ 'mega_8_scenes_0063_0.1_0.3.npz',
51
+ 'mega_8_scenes_0024_0.1_0.3.npz',
52
+ 'mega_8_scenes_0019_0.3_0.5.npz',
53
+ 'mega_8_scenes_0025_0.3_0.5.npz',
54
+ 'mega_8_scenes_0021_0.3_0.5.npz',
55
+ 'mega_8_scenes_0008_0.3_0.5.npz',
56
+ 'mega_8_scenes_0032_0.3_0.5.npz',
57
+ 'mega_8_scenes_1589_0.3_0.5.npz',
58
+ 'mega_8_scenes_0063_0.3_0.5.npz',
59
+ 'mega_8_scenes_0024_0.3_0.5.npz'])
60
+ mega1500_results = mega1500_benchmark.benchmark(model, model_name=name)
61
+ json.dump(mega1500_results, open(f"results/mega_8_scenes_poselib_{name}.json", "w"))
62
+
63
+ def test_scannet_poselib(model, name):
64
+ scannet_benchmark = ScanNetPoselibBenchmark("data/scannet")
65
+ scannet_results = scannet_benchmark.benchmark(model)
66
+ json.dump(scannet_results, open(f"results/scannet_{name}.json", "w"))
67
+
68
+ def test_scannet(model, name):
69
+ scannet_benchmark = ScanNetBenchmark("data/scannet")
70
+ scannet_results = scannet_benchmark.benchmark(model)
71
+ json.dump(scannet_results, open(f"results/scannet_{name}.json", "w"))
72
+
73
+ if __name__ == "__main__":
74
+ os.environ["TORCH_CUDNN_V8_API_ENABLED"] = "1" # For BF16 computations
75
+ os.environ["OMP_NUM_THREADS"] = "16"
76
+ torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
77
+ from romatch import tiny_roma_v1_outdoor
78
+
79
+ experiment_name = Path(__file__).stem
80
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
81
+ model = tiny_roma_v1_outdoor(device)
82
+ #test_mega1500_poselib(model, experiment_name)
83
+ test_mega_8_scenes_poselib(model, experiment_name)
84
+
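Outside the benchmark wrappers, the tiny model exposes the match/sample interface the benchmarks rely on; a minimal sketch (image paths are placeholders, and the pixel-coordinate step assumes the to_pixel_coordinates helper shown in the training script further below):

import torch
from PIL import Image
from romatch import tiny_roma_v1_outdoor

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = tiny_roma_v1_outdoor(device)
warp, certainty = model.match("imA.jpg", "imB.jpg")        # dense warp + per-pixel certainty
matches, conf = model.sample(warp, certainty, num=5000)    # certainty-weighted sparse correspondences
W_A, H_A = Image.open("imA.jpg").size
W_B, H_B = Image.open("imB.jpg").size
kpts_A, kpts_B = model.to_pixel_coordinates(matches, H_A, W_A, H_B, W_B)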
submodules/RoMa/experiments/roma_indoor.py ADDED
@@ -0,0 +1,320 @@
1
+ import os
2
+ import torch
3
+ from argparse import ArgumentParser
4
+
5
+ from torch import nn
6
+ from torch.utils.data import ConcatDataset
7
+ import torch.distributed as dist
8
+ from torch.nn.parallel import DistributedDataParallel as DDP
9
+
10
+ import json
11
+ import wandb
12
+ from tqdm import tqdm
13
+
14
+ from romatch.benchmarks import MegadepthDenseBenchmark
15
+ from romatch.datasets.megadepth import MegadepthBuilder
16
+ from romatch.datasets.scannet import ScanNetBuilder
17
+ from romatch.losses.robust_loss import RobustLosses
18
+ from romatch.benchmarks import MegadepthDenseBenchmark, ScanNetBenchmark
19
+ from romatch.train.train import train_k_steps
20
+ from romatch.models.matcher import *
21
+ from romatch.models.transformer import Block, TransformerDecoder, MemEffAttention
22
+ from romatch.models.encoders import *
23
+ from romatch.checkpointing import CheckPoint
24
+
25
+ resolutions = {"low":(448, 448), "medium":(14*8*5, 14*8*5), "high":(14*8*6, 14*8*6)}
26
+
27
+ def get_model(pretrained_backbone=True, resolution = "medium", **kwargs):
28
+ gp_dim = 512
29
+ feat_dim = 512
30
+ decoder_dim = gp_dim + feat_dim
31
+ cls_to_coord_res = 64
32
+ coordinate_decoder = TransformerDecoder(
33
+ nn.Sequential(*[Block(decoder_dim, 8, attn_class=MemEffAttention) for _ in range(5)]),
34
+ decoder_dim,
35
+ cls_to_coord_res**2 + 1,
36
+ is_classifier=True,
37
+ amp = True,
38
+ pos_enc = False,)
39
+ dw = True
40
+ hidden_blocks = 8
41
+ kernel_size = 5
42
+ displacement_emb = "linear"
43
+ disable_local_corr_grad = True
44
+
45
+ conv_refiner = nn.ModuleDict(
46
+ {
47
+ "16": ConvRefiner(
48
+ 2 * 512+128+(2*7+1)**2,
49
+ 2 * 512+128+(2*7+1)**2,
50
+ 2 + 1,
51
+ kernel_size=kernel_size,
52
+ dw=dw,
53
+ hidden_blocks=hidden_blocks,
54
+ displacement_emb=displacement_emb,
55
+ displacement_emb_dim=128,
56
+ local_corr_radius = 7,
57
+ corr_in_other = True,
58
+ amp = True,
59
+ disable_local_corr_grad = disable_local_corr_grad,
60
+ bn_momentum = 0.01,
61
+ ),
62
+ "8": ConvRefiner(
63
+ 2 * 512+64+(2*3+1)**2,
64
+ 2 * 512+64+(2*3+1)**2,
65
+ 2 + 1,
66
+ kernel_size=kernel_size,
67
+ dw=dw,
68
+ hidden_blocks=hidden_blocks,
69
+ displacement_emb=displacement_emb,
70
+ displacement_emb_dim=64,
71
+ local_corr_radius = 3,
72
+ corr_in_other = True,
73
+ amp = True,
74
+ disable_local_corr_grad = disable_local_corr_grad,
75
+ bn_momentum = 0.01,
76
+ ),
77
+ "4": ConvRefiner(
78
+ 2 * 256+32+(2*2+1)**2,
79
+ 2 * 256+32+(2*2+1)**2,
80
+ 2 + 1,
81
+ kernel_size=kernel_size,
82
+ dw=dw,
83
+ hidden_blocks=hidden_blocks,
84
+ displacement_emb=displacement_emb,
85
+ displacement_emb_dim=32,
86
+ local_corr_radius = 2,
87
+ corr_in_other = True,
88
+ amp = True,
89
+ disable_local_corr_grad = disable_local_corr_grad,
90
+ bn_momentum = 0.01,
91
+ ),
92
+ "2": ConvRefiner(
93
+ 2 * 64+16,
94
+ 128+16,
95
+ 2 + 1,
96
+ kernel_size=kernel_size,
97
+ dw=dw,
98
+ hidden_blocks=hidden_blocks,
99
+ displacement_emb=displacement_emb,
100
+ displacement_emb_dim=16,
101
+ amp = True,
102
+ disable_local_corr_grad = disable_local_corr_grad,
103
+ bn_momentum = 0.01,
104
+ ),
105
+ "1": ConvRefiner(
106
+ 2 * 9 + 6,
107
+ 24,
108
+ 2 + 1,
109
+ kernel_size=kernel_size,
110
+ dw=dw,
111
+ hidden_blocks = hidden_blocks,
112
+ displacement_emb = displacement_emb,
113
+ displacement_emb_dim = 6,
114
+ amp = True,
115
+ disable_local_corr_grad = disable_local_corr_grad,
116
+ bn_momentum = 0.01,
117
+ ),
118
+ }
119
+ )
120
+ kernel_temperature = 0.2
121
+ learn_temperature = False
122
+ no_cov = True
123
+ kernel = CosKernel
124
+ only_attention = False
125
+ basis = "fourier"
126
+ gp16 = GP(
127
+ kernel,
128
+ T=kernel_temperature,
129
+ learn_temperature=learn_temperature,
130
+ only_attention=only_attention,
131
+ gp_dim=gp_dim,
132
+ basis=basis,
133
+ no_cov=no_cov,
134
+ )
135
+ gps = nn.ModuleDict({"16": gp16})
136
+ proj16 = nn.Sequential(nn.Conv2d(1024, 512, 1, 1), nn.BatchNorm2d(512))
137
+ proj8 = nn.Sequential(nn.Conv2d(512, 512, 1, 1), nn.BatchNorm2d(512))
138
+ proj4 = nn.Sequential(nn.Conv2d(256, 256, 1, 1), nn.BatchNorm2d(256))
139
+ proj2 = nn.Sequential(nn.Conv2d(128, 64, 1, 1), nn.BatchNorm2d(64))
140
+ proj1 = nn.Sequential(nn.Conv2d(64, 9, 1, 1), nn.BatchNorm2d(9))
141
+ proj = nn.ModuleDict({
142
+ "16": proj16,
143
+ "8": proj8,
144
+ "4": proj4,
145
+ "2": proj2,
146
+ "1": proj1,
147
+ })
148
+ displacement_dropout_p = 0.0
149
+ gm_warp_dropout_p = 0.0
150
+ decoder = Decoder(coordinate_decoder,
151
+ gps,
152
+ proj,
153
+ conv_refiner,
154
+ detach=True,
155
+ scales=["16", "8", "4", "2", "1"],
156
+ displacement_dropout_p = displacement_dropout_p,
157
+ gm_warp_dropout_p = gm_warp_dropout_p)
158
+ h,w = resolutions[resolution]
159
+ encoder = CNNandDinov2(
160
+ cnn_kwargs = dict(
161
+ pretrained=pretrained_backbone,
162
+ amp = True),
163
+ amp = True,
164
+ use_vgg = True,
165
+ )
166
+ matcher = RegressionMatcher(encoder, decoder, h=h, w=w, alpha=1, beta=0,**kwargs)
167
+ return matcher
168
+
169
+ def train(args):
170
+ dist.init_process_group('nccl')
171
+ #torch._dynamo.config.verbose=True
172
+ gpus = int(os.environ['WORLD_SIZE'])
173
+ # create model and move it to GPU with id rank
174
+ rank = dist.get_rank()
175
+ print(f"Start running DDP on rank {rank}")
176
+ device_id = rank % torch.cuda.device_count()
177
+ romatch.LOCAL_RANK = device_id
178
+ torch.cuda.set_device(device_id)
179
+
180
+ resolution = args.train_resolution
181
+ wandb_log = not args.dont_log_wandb
182
+ experiment_name = os.path.splitext(os.path.basename(__file__))[0]
183
+ wandb_mode = "online" if wandb_log and rank == 0 and False else "disabled"
184
+ wandb.init(project="romatch", entity=args.wandb_entity, name=experiment_name, reinit=False, mode = wandb_mode)
185
+ checkpoint_dir = "workspace/checkpoints/"
186
+ h,w = resolutions[resolution]
187
+ model = get_model(pretrained_backbone=True, resolution=resolution, attenuate_cert = False).to(device_id)
188
+ # Num steps
189
+ global_step = 0
190
+ batch_size = args.gpu_batch_size
191
+ step_size = gpus*batch_size
192
+ romatch.STEP_SIZE = step_size
193
+
194
+ N = (32 * 250000) # 250k steps of batch size 32
195
+ # checkpoint every
196
+ k = 25000 // romatch.STEP_SIZE
197
+
198
+ # Data
199
+ mega = MegadepthBuilder(data_root="data/megadepth", loftr_ignore=True, imc21_ignore = True)
200
+ use_horizontal_flip_aug = True
201
+ rot_prob = 0
202
+ depth_interpolation_mode = "bilinear"
203
+ megadepth_train1 = mega.build_scenes(
204
+ split="train_loftr", min_overlap=0.01, shake_t=32, use_horizontal_flip_aug = use_horizontal_flip_aug, rot_prob = rot_prob,
205
+ ht=h,wt=w,
206
+ )
207
+ megadepth_train2 = mega.build_scenes(
208
+ split="train_loftr", min_overlap=0.35, shake_t=32, use_horizontal_flip_aug = use_horizontal_flip_aug, rot_prob = rot_prob,
209
+ ht=h,wt=w,
210
+ )
211
+ megadepth_train = ConcatDataset(megadepth_train1 + megadepth_train2)
212
+ mega_ws = mega.weight_scenes(megadepth_train, alpha=0.75)
213
+
214
+ scannet = ScanNetBuilder(data_root="data/scannet")
215
+ scannet_train = scannet.build_scenes(split="train", ht=h, wt=w, use_horizontal_flip_aug = use_horizontal_flip_aug)
216
+ scannet_train = ConcatDataset(scannet_train)
217
+ scannet_ws = scannet.weight_scenes(scannet_train, alpha=0.75)
218
+
219
+ # Loss and optimizer
220
+ depth_loss_scannet = RobustLosses(
221
+ ce_weight=0.0,
222
+ local_dist={1:4, 2:4, 4:8, 8:8},
223
+ local_largest_scale=8,
224
+ depth_interpolation_mode=depth_interpolation_mode,
225
+ alpha = 0.5,
226
+ c = 1e-4,)
227
+ # Loss and optimizer
228
+ depth_loss_mega = RobustLosses(
229
+ ce_weight=0.01,
230
+ local_dist={1:4, 2:4, 4:8, 8:8},
231
+ local_largest_scale=8,
232
+ depth_interpolation_mode=depth_interpolation_mode,
233
+ alpha = 0.5,
234
+ c = 1e-4,)
235
+ parameters = [
236
+ {"params": model.encoder.parameters(), "lr": romatch.STEP_SIZE * 5e-6 / 8},
237
+ {"params": model.decoder.parameters(), "lr": romatch.STEP_SIZE * 1e-4 / 8},
238
+ ]
239
+ optimizer = torch.optim.AdamW(parameters, weight_decay=0.01)
240
+ lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
241
+ optimizer, milestones=[(9*N/romatch.STEP_SIZE)//10])
242
+ megadense_benchmark = MegadepthDenseBenchmark("data/megadepth", num_samples = 1000, h=h,w=w)
243
+ checkpointer = CheckPoint(checkpoint_dir, experiment_name)
244
+ model, optimizer, lr_scheduler, global_step = checkpointer.load(model, optimizer, lr_scheduler, global_step)
245
+ romatch.GLOBAL_STEP = global_step
246
+ ddp_model = DDP(model, device_ids=[device_id], find_unused_parameters = False, gradient_as_bucket_view=True)
247
+ grad_scaler = torch.cuda.amp.GradScaler(growth_interval=1_000_000)
248
+ grad_clip_norm = 0.01
249
+ for n in range(romatch.GLOBAL_STEP, N, k * romatch.STEP_SIZE):
250
+ mega_sampler = torch.utils.data.WeightedRandomSampler(
251
+ mega_ws, num_samples = batch_size * k, replacement=False
252
+ )
253
+ mega_dataloader = iter(
254
+ torch.utils.data.DataLoader(
255
+ megadepth_train,
256
+ batch_size = batch_size,
257
+ sampler = mega_sampler,
258
+ num_workers = 8,
259
+ )
260
+ )
261
+ scannet_ws_sampler = torch.utils.data.WeightedRandomSampler(
262
+ scannet_ws, num_samples=batch_size * k, replacement=False
263
+ )
264
+ scannet_dataloader = iter(
265
+ torch.utils.data.DataLoader(
266
+ scannet_train,
267
+ batch_size=batch_size,
268
+ sampler=scannet_ws_sampler,
269
+ num_workers=gpus * 8,
270
+ )
271
+ )
272
+ for n_k in tqdm(range(n, n + 2 * k, 2),disable = romatch.RANK > 0):
273
+ train_k_steps(
274
+ n_k, 1, mega_dataloader, ddp_model, depth_loss_mega, optimizer, lr_scheduler, grad_scaler, grad_clip_norm = grad_clip_norm, progress_bar=False
275
+ )
276
+ train_k_steps(
277
+ n_k + 1, 1, scannet_dataloader, ddp_model, depth_loss_scannet, optimizer, lr_scheduler, grad_scaler, grad_clip_norm = grad_clip_norm, progress_bar=False
278
+ )
279
+ checkpointer.save(model, optimizer, lr_scheduler, romatch.GLOBAL_STEP)
280
+ wandb.log(megadense_benchmark.benchmark(model), step = romatch.GLOBAL_STEP)
281
+
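The per-group learning rates above scale linearly with the effective batch size: with, say, 4 GPUs at the default --gpu_batch_size 4, romatch.STEP_SIZE = 16, so the encoder trains at 16 * 5e-6 / 8 = 1e-5 and the decoder at 16 * 1e-4 / 8 = 2e-4 (the /8 rescales rates that appear to be tuned for a reference batch of 8). N = 32 * 250000 counts samples, so the single MultiStepLR milestone at (9*N/STEP_SIZE)//10 lands at 90% of the roughly 500,000 optimizer steps in that configuration.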
282
+ def test_scannet(model, name, resolution, sample_mode):
283
+ scannet_benchmark = ScanNetBenchmark("data/scannet")
284
+ scannet_results = scannet_benchmark.benchmark(model)
285
+ json.dump(scannet_results, open(f"results/scannet_{name}.json", "w"))
286
+
287
+ if __name__ == "__main__":
288
+ import warnings
289
+ warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
290
+ warnings.filterwarnings('ignore')#, category=UserWarning)#, message='WARNING batched routines are designed for small sizes.')
291
+ os.environ["TORCH_CUDNN_V8_API_ENABLED"] = "1" # For BF16 computations
292
+ os.environ["OMP_NUM_THREADS"] = "16"
293
+
294
+ import romatch
295
+ parser = ArgumentParser()
296
+ parser.add_argument("--test", action='store_true')
297
+ parser.add_argument("--debug_mode", action='store_true')
298
+ parser.add_argument("--dont_log_wandb", action='store_true')
299
+ parser.add_argument("--train_resolution", default='medium')
300
+ parser.add_argument("--gpu_batch_size", default=4, type=int)
301
+ parser.add_argument("--wandb_entity", required = False)
302
+
303
+ args, _ = parser.parse_known_args()
304
+ romatch.DEBUG_MODE = args.debug_mode
305
+ if not args.test:
306
+ train(args)
307
+ experiment_name = os.path.splitext(os.path.basename(__file__))[0]
308
+ checkpoint_dir = "workspace/"
309
+ checkpoint_name = checkpoint_dir + experiment_name + ".pth"
310
+ test_resolution = "medium"
311
+ sample_mode = "threshold_balanced"
312
+ symmetric = True
313
+ upsample_preds = False
314
+ attenuate_cert = True
315
+
316
+ model = get_model(pretrained_backbone=False, resolution = test_resolution, sample_mode = sample_mode, upsample_preds = upsample_preds, symmetric=symmetric, name=experiment_name, attenuate_cert = attenuate_cert)
317
+ model = model.cuda()
318
+ states = torch.load(checkpoint_name)
319
+ model.load_state_dict(states["model"])
320
+ test_scannet(model, experiment_name, resolution = test_resolution, sample_mode = sample_mode)
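Because train() calls dist.init_process_group('nccl') and reads WORLD_SIZE from the environment, the script is meant to be launched through a torch.distributed launcher, e.g. torchrun --nproc_per_node=4 submodules/RoMa/experiments/roma_indoor.py --gpu_batch_size 4 --wandb_entity <entity> (GPU count and entity are placeholders). Passing --test skips training and loads workspace/roma_indoor.pth for the ScanNet evaluation at the end of the __main__ block.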
submodules/RoMa/experiments/train_roma_outdoor.py ADDED
@@ -0,0 +1,307 @@
1
+ import os
2
+ import torch
3
+ from argparse import ArgumentParser
4
+
5
+ from torch import nn
6
+ from torch.utils.data import ConcatDataset
7
+ import torch.distributed as dist
8
+ from torch.nn.parallel import DistributedDataParallel as DDP
9
+ import json
10
+ import wandb
11
+
12
+ from romatch.benchmarks import MegadepthDenseBenchmark
13
+ from romatch.datasets.megadepth import MegadepthBuilder
14
+ from romatch.losses.robust_loss import RobustLosses
15
+ from romatch.benchmarks import MegaDepthPoseEstimationBenchmark, MegadepthDenseBenchmark, HpatchesHomogBenchmark
16
+
17
+ from romatch.train.train import train_k_steps
18
+ from romatch.models.matcher import *
19
+ from romatch.models.transformer import Block, TransformerDecoder, MemEffAttention
20
+ from romatch.models.encoders import *
21
+ from romatch.checkpointing import CheckPoint
22
+
23
+ resolutions = {"low":(448, 448), "medium":(14*8*5, 14*8*5), "high":(14*8*6, 14*8*6)}
24
+
25
+ def get_model(pretrained_backbone=True, resolution = "medium", **kwargs):
26
+ import warnings
27
+ warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
28
+ gp_dim = 512
29
+ feat_dim = 512
30
+ decoder_dim = gp_dim + feat_dim
31
+ cls_to_coord_res = 64
32
+ coordinate_decoder = TransformerDecoder(
33
+ nn.Sequential(*[Block(decoder_dim, 8, attn_class=MemEffAttention) for _ in range(5)]),
34
+ decoder_dim,
35
+ cls_to_coord_res**2 + 1,
36
+ is_classifier=True,
37
+ amp = True,
38
+ pos_enc = False,)
39
+ dw = True
40
+ hidden_blocks = 8
41
+ kernel_size = 5
42
+ displacement_emb = "linear"
43
+ disable_local_corr_grad = True
44
+
45
+ conv_refiner = nn.ModuleDict(
46
+ {
47
+ "16": ConvRefiner(
48
+ 2 * 512+128+(2*7+1)**2,
49
+ 2 * 512+128+(2*7+1)**2,
50
+ 2 + 1,
51
+ kernel_size=kernel_size,
52
+ dw=dw,
53
+ hidden_blocks=hidden_blocks,
54
+ displacement_emb=displacement_emb,
55
+ displacement_emb_dim=128,
56
+ local_corr_radius = 7,
57
+ corr_in_other = True,
58
+ amp = True,
59
+ disable_local_corr_grad = disable_local_corr_grad,
60
+ bn_momentum = 0.01,
61
+ ),
62
+ "8": ConvRefiner(
63
+ 2 * 512+64+(2*3+1)**2,
64
+ 2 * 512+64+(2*3+1)**2,
65
+ 2 + 1,
66
+ kernel_size=kernel_size,
67
+ dw=dw,
68
+ hidden_blocks=hidden_blocks,
69
+ displacement_emb=displacement_emb,
70
+ displacement_emb_dim=64,
71
+ local_corr_radius = 3,
72
+ corr_in_other = True,
73
+ amp = True,
74
+ disable_local_corr_grad = disable_local_corr_grad,
75
+ bn_momentum = 0.01,
76
+ ),
77
+ "4": ConvRefiner(
78
+ 2 * 256+32+(2*2+1)**2,
79
+ 2 * 256+32+(2*2+1)**2,
80
+ 2 + 1,
81
+ kernel_size=kernel_size,
82
+ dw=dw,
83
+ hidden_blocks=hidden_blocks,
84
+ displacement_emb=displacement_emb,
85
+ displacement_emb_dim=32,
86
+ local_corr_radius = 2,
87
+ corr_in_other = True,
88
+ amp = True,
89
+ disable_local_corr_grad = disable_local_corr_grad,
90
+ bn_momentum = 0.01,
91
+ ),
92
+ "2": ConvRefiner(
93
+ 2 * 64+16,
94
+ 128+16,
95
+ 2 + 1,
96
+ kernel_size=kernel_size,
97
+ dw=dw,
98
+ hidden_blocks=hidden_blocks,
99
+ displacement_emb=displacement_emb,
100
+ displacement_emb_dim=16,
101
+ amp = True,
102
+ disable_local_corr_grad = disable_local_corr_grad,
103
+ bn_momentum = 0.01,
104
+ ),
105
+ "1": ConvRefiner(
106
+ 2 * 9 + 6,
107
+ 24,
108
+ 2 + 1,
109
+ kernel_size=kernel_size,
110
+ dw=dw,
111
+ hidden_blocks = hidden_blocks,
112
+ displacement_emb = displacement_emb,
113
+ displacement_emb_dim = 6,
114
+ amp = True,
115
+ disable_local_corr_grad = disable_local_corr_grad,
116
+ bn_momentum = 0.01,
117
+ ),
118
+ }
119
+ )
120
+ kernel_temperature = 0.2
121
+ learn_temperature = False
122
+ no_cov = True
123
+ kernel = CosKernel
124
+ only_attention = False
125
+ basis = "fourier"
126
+ gp16 = GP(
127
+ kernel,
128
+ T=kernel_temperature,
129
+ learn_temperature=learn_temperature,
130
+ only_attention=only_attention,
131
+ gp_dim=gp_dim,
132
+ basis=basis,
133
+ no_cov=no_cov,
134
+ )
135
+ gps = nn.ModuleDict({"16": gp16})
136
+ proj16 = nn.Sequential(nn.Conv2d(1024, 512, 1, 1), nn.BatchNorm2d(512))
137
+ proj8 = nn.Sequential(nn.Conv2d(512, 512, 1, 1), nn.BatchNorm2d(512))
138
+ proj4 = nn.Sequential(nn.Conv2d(256, 256, 1, 1), nn.BatchNorm2d(256))
139
+ proj2 = nn.Sequential(nn.Conv2d(128, 64, 1, 1), nn.BatchNorm2d(64))
140
+ proj1 = nn.Sequential(nn.Conv2d(64, 9, 1, 1), nn.BatchNorm2d(9))
141
+ proj = nn.ModuleDict({
142
+ "16": proj16,
143
+ "8": proj8,
144
+ "4": proj4,
145
+ "2": proj2,
146
+ "1": proj1,
147
+ })
148
+ displacement_dropout_p = 0.0
149
+ gm_warp_dropout_p = 0.0
150
+ decoder = Decoder(coordinate_decoder,
151
+ gps,
152
+ proj,
153
+ conv_refiner,
154
+ detach=True,
155
+ scales=["16", "8", "4", "2", "1"],
156
+ displacement_dropout_p = displacement_dropout_p,
157
+ gm_warp_dropout_p = gm_warp_dropout_p)
158
+ h,w = resolutions[resolution]
159
+ encoder = CNNandDinov2(
160
+ cnn_kwargs = dict(
161
+ pretrained=pretrained_backbone,
162
+ amp = True),
163
+ amp = True,
164
+ use_vgg = True,
165
+ )
166
+ matcher = RegressionMatcher(encoder, decoder, h=h, w=w,**kwargs)
167
+ return matcher
168
+
169
+ def train(args):
170
+ dist.init_process_group('nccl')
171
+ #torch._dynamo.config.verbose=True
172
+ gpus = int(os.environ['WORLD_SIZE'])
173
+ # create model and move it to GPU with id rank
174
+ rank = dist.get_rank()
175
+ print(f"Start running DDP on rank {rank}")
176
+ device_id = rank % torch.cuda.device_count()
177
+ romatch.LOCAL_RANK = device_id
178
+ torch.cuda.set_device(device_id)
179
+
180
+ resolution = args.train_resolution
181
+ wandb_log = not args.dont_log_wandb
182
+ experiment_name = os.path.splitext(os.path.basename(__file__))[0]
183
+ wandb_mode = "online" if wandb_log and rank == 0 else "disabled"
184
+ wandb.init(project="romatch", entity=args.wandb_entity, name=experiment_name, reinit=False, mode = wandb_mode)
185
+ checkpoint_dir = "workspace/checkpoints/"
186
+ h,w = resolutions[resolution]
187
+ model = get_model(pretrained_backbone=True, resolution=resolution, attenuate_cert = False).to(device_id)
188
+ # Num steps
189
+ global_step = 0
190
+ batch_size = args.gpu_batch_size
191
+ step_size = gpus*batch_size
192
+ romatch.STEP_SIZE = step_size
193
+
194
+ N = (32 * 250000) # 250k steps of batch size 32
195
+ # checkpoint every
196
+ k = 25000 // romatch.STEP_SIZE
197
+
198
+ # Data
199
+ mega = MegadepthBuilder(data_root="data/megadepth", loftr_ignore=True, imc21_ignore = True)
200
+ use_horizontal_flip_aug = True
201
+ rot_prob = 0
202
+ depth_interpolation_mode = "bilinear"
203
+ megadepth_train1 = mega.build_scenes(
204
+ split="train_loftr", min_overlap=0.01, shake_t=32, use_horizontal_flip_aug = use_horizontal_flip_aug, rot_prob = rot_prob,
205
+ ht=h,wt=w,
206
+ )
207
+ megadepth_train2 = mega.build_scenes(
208
+ split="train_loftr", min_overlap=0.35, shake_t=32, use_horizontal_flip_aug = use_horizontal_flip_aug, rot_prob = rot_prob,
209
+ ht=h,wt=w,
210
+ )
211
+ megadepth_train = ConcatDataset(megadepth_train1 + megadepth_train2)
212
+ mega_ws = mega.weight_scenes(megadepth_train, alpha=0.75)
213
+ # Loss and optimizer
214
+ depth_loss = RobustLosses(
215
+ ce_weight=0.01,
216
+ local_dist={1:4, 2:4, 4:8, 8:8},
217
+ local_largest_scale=8,
218
+ depth_interpolation_mode=depth_interpolation_mode,
219
+ alpha = 0.5,
220
+ c = 1e-4,)
221
+ parameters = [
222
+ {"params": model.encoder.parameters(), "lr": romatch.STEP_SIZE * 5e-6 / 8},
223
+ {"params": model.decoder.parameters(), "lr": romatch.STEP_SIZE * 1e-4 / 8},
224
+ ]
225
+ optimizer = torch.optim.AdamW(parameters, weight_decay=0.01)
226
+ lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
227
+ optimizer, milestones=[(9*N/romatch.STEP_SIZE)//10])
228
+ megadense_benchmark = MegadepthDenseBenchmark("data/megadepth", num_samples = 1000, h=h,w=w)
229
+ checkpointer = CheckPoint(checkpoint_dir, experiment_name)
230
+ model, optimizer, lr_scheduler, global_step = checkpointer.load(model, optimizer, lr_scheduler, global_step)
231
+ romatch.GLOBAL_STEP = global_step
232
+ ddp_model = DDP(model, device_ids=[device_id], find_unused_parameters = False, gradient_as_bucket_view=True)
233
+ grad_scaler = torch.cuda.amp.GradScaler(growth_interval=1_000_000)
234
+ grad_clip_norm = 0.01
235
+ for n in range(romatch.GLOBAL_STEP, N, k * romatch.STEP_SIZE):
236
+ mega_sampler = torch.utils.data.WeightedRandomSampler(
237
+ mega_ws, num_samples = batch_size * k, replacement=False
238
+ )
239
+ mega_dataloader = iter(
240
+ torch.utils.data.DataLoader(
241
+ megadepth_train,
242
+ batch_size = batch_size,
243
+ sampler = mega_sampler,
244
+ num_workers = 8,
245
+ )
246
+ )
247
+ train_k_steps(
248
+ n, k, mega_dataloader, ddp_model, depth_loss, optimizer, lr_scheduler, grad_scaler, grad_clip_norm = grad_clip_norm,
249
+ )
250
+ checkpointer.save(model, optimizer, lr_scheduler, romatch.GLOBAL_STEP)
251
+ wandb.log(megadense_benchmark.benchmark(model), step = romatch.GLOBAL_STEP)
252
+
253
+ def test_mega_8_scenes(model, name):
254
+ mega_8_scenes_benchmark = MegaDepthPoseEstimationBenchmark("data/megadepth",
255
+ scene_names=['mega_8_scenes_0019_0.1_0.3.npz',
256
+ 'mega_8_scenes_0025_0.1_0.3.npz',
257
+ 'mega_8_scenes_0021_0.1_0.3.npz',
258
+ 'mega_8_scenes_0008_0.1_0.3.npz',
259
+ 'mega_8_scenes_0032_0.1_0.3.npz',
260
+ 'mega_8_scenes_1589_0.1_0.3.npz',
261
+ 'mega_8_scenes_0063_0.1_0.3.npz',
262
+ 'mega_8_scenes_0024_0.1_0.3.npz',
263
+ 'mega_8_scenes_0019_0.3_0.5.npz',
264
+ 'mega_8_scenes_0025_0.3_0.5.npz',
265
+ 'mega_8_scenes_0021_0.3_0.5.npz',
266
+ 'mega_8_scenes_0008_0.3_0.5.npz',
267
+ 'mega_8_scenes_0032_0.3_0.5.npz',
268
+ 'mega_8_scenes_1589_0.3_0.5.npz',
269
+ 'mega_8_scenes_0063_0.3_0.5.npz',
270
+ 'mega_8_scenes_0024_0.3_0.5.npz'])
271
+ mega_8_scenes_results = mega_8_scenes_benchmark.benchmark(model, model_name=name)
272
+ print(mega_8_scenes_results)
273
+ json.dump(mega_8_scenes_results, open(f"results/mega_8_scenes_{name}.json", "w"))
274
+
275
+ def test_mega1500(model, name):
276
+ mega1500_benchmark = MegaDepthPoseEstimationBenchmark("data/megadepth")
277
+ mega1500_results = mega1500_benchmark.benchmark(model, model_name=name)
278
+ json.dump(mega1500_results, open(f"results/mega1500_{name}.json", "w"))
279
+
280
+ def test_mega_dense(model, name):
281
+ megadense_benchmark = MegadepthDenseBenchmark("data/megadepth", num_samples = 1000)
282
+ megadense_results = megadense_benchmark.benchmark(model)
283
+ json.dump(megadense_results, open(f"results/mega_dense_{name}.json", "w"))
284
+
285
+ def test_hpatches(model, name):
286
+ hpatches_benchmark = HpatchesHomogBenchmark("data/hpatches")
287
+ hpatches_results = hpatches_benchmark.benchmark(model)
288
+ json.dump(hpatches_results, open(f"results/hpatches_{name}.json", "w"))
289
+
290
+
291
+ if __name__ == "__main__":
292
+ os.environ["TORCH_CUDNN_V8_API_ENABLED"] = "1" # For BF16 computations
293
+ os.environ["OMP_NUM_THREADS"] = "16"
294
+ torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
295
+ import romatch
296
+ parser = ArgumentParser()
297
+ parser.add_argument("--only_test", action='store_true')
298
+ parser.add_argument("--debug_mode", action='store_true')
299
+ parser.add_argument("--dont_log_wandb", action='store_true')
300
+ parser.add_argument("--train_resolution", default='medium')
301
+ parser.add_argument("--gpu_batch_size", default=8, type=int)
302
+ parser.add_argument("--wandb_entity", required = False)
303
+
304
+ args, _ = parser.parse_known_args()
305
+ romatch.DEBUG_MODE = args.debug_mode
306
+ if not args.only_test:
307
+ train(args)
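For scale, N = 32 * 250000 counts training samples rather than optimizer steps: with the default --gpu_batch_size 8 on, for example, 4 GPUs, romatch.STEP_SIZE = 32, so k = 25000 // 32 = 781 optimizer steps (just under 25k samples) separate consecutive checkpoint/benchmark rounds, and the outer loop covers the full 8M samples in roughly 320 such rounds.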
submodules/RoMa/experiments/train_tiny_roma_v1_outdoor.py ADDED
@@ -0,0 +1,498 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import os
5
+ import torch
6
+ from argparse import ArgumentParser
7
+ from pathlib import Path
8
+ import math
9
+ import numpy as np
10
+
11
+ from torch import nn
12
+ from torch.utils.data import ConcatDataset
13
+ import torch.distributed as dist
14
+ from torch.nn.parallel import DistributedDataParallel as DDP
15
+ import json
16
+ import wandb
17
+ from PIL import Image
18
+ from torchvision.transforms import ToTensor
19
+
20
+ from romatch.benchmarks import MegadepthDenseBenchmark, ScanNetBenchmark
21
+ from romatch.benchmarks import Mega1500PoseLibBenchmark, ScanNetPoselibBenchmark
22
+ from romatch.datasets.megadepth import MegadepthBuilder
23
+ from romatch.losses.robust_loss_tiny_roma import RobustLosses
24
+ from romatch.benchmarks import MegaDepthPoseEstimationBenchmark, MegadepthDenseBenchmark, HpatchesHomogBenchmark
25
+ from romatch.train.train import train_k_steps
26
+ from romatch.checkpointing import CheckPoint
27
+
28
+ resolutions = {"low":(448, 448), "medium":(14*8*5, 14*8*5), "high":(14*8*6, 14*8*6), "xfeat": (600,800), "big": (768, 1024)}
29
+
30
+ def kde(x, std = 0.1):
31
+ # use a gaussian kernel to estimate density
32
+ x = x.half() # Do it in half precision TODO: remove hardcoding
33
+ scores = (-torch.cdist(x,x)**2/(2*std**2)).exp()
34
+ density = scores.sum(dim=-1)
35
+ return density
36
+
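To make the density estimate concrete: with std = 0.1, two matches 0.01 apart in normalized coordinates contribute exp(-0.01**2 / 0.02) ≈ 0.995 to each other's density, while a match 1.0 away contributes exp(-50) ≈ 0, so a tight pair ends up with density ≈ 2 and an isolated match ≈ 1. The sample() method further down turns this into p = 1 / (density + 1), down-weighting over-represented clusters when drawing balanced correspondences.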
37
+ class BasicLayer(nn.Module):
38
+ """
39
+ Basic Convolutional Layer: Conv2d -> BatchNorm -> ReLU
40
+ """
41
+ def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, dilation=1, bias=False, relu = True):
42
+ super().__init__()
43
+ self.layer = nn.Sequential(
44
+ nn.Conv2d( in_channels, out_channels, kernel_size, padding = padding, stride=stride, dilation=dilation, bias = bias),
45
+ nn.BatchNorm2d(out_channels, affine=False),
46
+ nn.ReLU(inplace = True) if relu else nn.Identity()
47
+ )
48
+
49
+ def forward(self, x):
50
+ return self.layer(x)
51
+
52
+ class XFeatModel(nn.Module):
53
+ """
54
+ Implementation of architecture described in
55
+ "XFeat: Accelerated Features for Lightweight Image Matching, CVPR 2024."
56
+ """
57
+
58
+ def __init__(self, xfeat = None,
59
+ freeze_xfeat = True,
60
+ sample_mode = "threshold_balanced",
61
+ symmetric = False,
62
+ exact_softmax = False):
63
+ super().__init__()
64
+ if xfeat is None:
65
+ xfeat = torch.hub.load('verlab/accelerated_features', 'XFeat', pretrained = True, top_k = 4096).net
66
+ del xfeat.heatmap_head, xfeat.keypoint_head, xfeat.fine_matcher
67
+ if freeze_xfeat:
68
+ xfeat.train(False)
69
+ self.xfeat = [xfeat]# hide params from ddp
70
+ else:
71
+ self.xfeat = nn.ModuleList([xfeat])
72
+ self.freeze_xfeat = freeze_xfeat
73
+ match_dim = 256
74
+ self.coarse_matcher = nn.Sequential(
75
+ BasicLayer(64+64+2, match_dim,),
76
+ BasicLayer(match_dim, match_dim,),
77
+ BasicLayer(match_dim, match_dim,),
78
+ BasicLayer(match_dim, match_dim,),
79
+ nn.Conv2d(match_dim, 3, kernel_size=1, bias=True, padding=0))
80
+ fine_match_dim = 64
81
+ self.fine_matcher = nn.Sequential(
82
+ BasicLayer(24+24+2, fine_match_dim,),
83
+ BasicLayer(fine_match_dim, fine_match_dim,),
84
+ BasicLayer(fine_match_dim, fine_match_dim,),
85
+ BasicLayer(fine_match_dim, fine_match_dim,),
86
+ nn.Conv2d(fine_match_dim, 3, kernel_size=1, bias=True, padding=0),)
87
+ self.sample_mode = sample_mode
88
+ self.sample_thresh = 0.2
89
+ self.symmetric = symmetric
90
+ self.exact_softmax = exact_softmax
91
+
92
+ @property
93
+ def device(self):
94
+ return self.fine_matcher[-1].weight.device
95
+
96
+ def preprocess_tensor(self, x):
97
+ """ Guarantee that image is divisible by 32 to avoid aliasing artifacts. """
98
+ H, W = x.shape[-2:]
99
+ _H, _W = (H//32) * 32, (W//32) * 32
100
+ rh, rw = H/_H, W/_W
101
+
102
+ x = F.interpolate(x, (_H, _W), mode='bilinear', align_corners=False)
103
+ return x, rh, rw
104
+
105
+ def forward_single(self, x):
106
+ with torch.inference_mode(self.freeze_xfeat or not self.training):
107
+ xfeat = self.xfeat[0]
108
+ with torch.no_grad():
109
+ x = x.mean(dim=1, keepdim = True)
110
+ x = xfeat.norm(x)
111
+
112
+ #main backbone
113
+ x1 = xfeat.block1(x)
114
+ x2 = xfeat.block2(x1 + xfeat.skip1(x))
115
+ x3 = xfeat.block3(x2)
116
+ x4 = xfeat.block4(x3)
117
+ x5 = xfeat.block5(x4)
118
+ x4 = F.interpolate(x4, (x3.shape[-2], x3.shape[-1]), mode='bilinear')
119
+ x5 = F.interpolate(x5, (x3.shape[-2], x3.shape[-1]), mode='bilinear')
120
+ feats = xfeat.block_fusion( x3 + x4 + x5 )
121
+ if self.freeze_xfeat:
122
+ return x2.clone(), feats.clone()
123
+ return x2, feats
124
+
125
+ def to_pixel_coordinates(self, coords, H_A, W_A, H_B = None, W_B = None):
126
+ if coords.shape[-1] == 2:
127
+ return self._to_pixel_coordinates(coords, H_A, W_A)
128
+
129
+ if isinstance(coords, (list, tuple)):
130
+ kpts_A, kpts_B = coords[0], coords[1]
131
+ else:
132
+ kpts_A, kpts_B = coords[...,:2], coords[...,2:]
133
+ return self._to_pixel_coordinates(kpts_A, H_A, W_A), self._to_pixel_coordinates(kpts_B, H_B, W_B)
134
+
135
+ def _to_pixel_coordinates(self, coords, H, W):
136
+ kpts = torch.stack((W/2 * (coords[...,0]+1), H/2 * (coords[...,1]+1)),axis=-1)
137
+ return kpts
138
+
139
+ def pos_embed(self, corr_volume: torch.Tensor):
140
+ B, H1, W1, H0, W0 = corr_volume.shape
141
+ grid = torch.stack(
142
+ torch.meshgrid(
143
+ torch.linspace(-1+1/W1,1-1/W1, W1),
144
+ torch.linspace(-1+1/H1,1-1/H1, H1),
145
+ indexing = "xy"),
146
+ dim = -1).float().to(corr_volume).reshape(H1*W1, 2)
147
+ down = 4
148
+ if not self.training and not self.exact_softmax:
149
+ grid_lr = torch.stack(
150
+ torch.meshgrid(
151
+ torch.linspace(-1+down/W1,1-down/W1, W1//down),
152
+ torch.linspace(-1+down/H1,1-down/H1, H1//down),
153
+ indexing = "xy"),
154
+ dim = -1).float().to(corr_volume).reshape(H1*W1 //down**2, 2)
155
+ cv = corr_volume
156
+ best_match = cv.reshape(B,H1*W1,H0,W0).amax(dim=1) # B, HW, H, W
157
+ P_lowres = torch.cat((cv[:,::down,::down].reshape(B,H1*W1 // down**2,H0,W0), best_match[:,None]),dim=1).softmax(dim=1)
158
+ pos_embeddings = torch.einsum('bchw,cd->bdhw', P_lowres[:,:-1], grid_lr)
159
+ pos_embeddings += P_lowres[:,-1] * grid[best_match].permute(0,3,1,2)
160
+ else:
161
+ P = corr_volume.reshape(B,H1*W1,H0,W0).softmax(dim=1) # B, HW, H, W
162
+ pos_embeddings = torch.einsum('bchw,cd->bdhw', P, grid)
163
+ return pos_embeddings
164
+
165
+ def visualize_warp(self, warp, certainty, im_A = None, im_B = None,
166
+ im_A_path = None, im_B_path = None, symmetric = True, save_path = None, unnormalize = False):
167
+ device = warp.device
168
+ H,W2,_ = warp.shape
169
+ W = W2//2 if symmetric else W2
170
+ if im_A is None:
171
+ from PIL import Image
172
+ im_A, im_B = Image.open(im_A_path).convert("RGB"), Image.open(im_B_path).convert("RGB")
173
+ if not isinstance(im_A, torch.Tensor):
174
+ im_A = im_A.resize((W,H))
175
+ im_B = im_B.resize((W,H))
176
+ x_B = (torch.tensor(np.array(im_B)) / 255).to(device).permute(2, 0, 1)
177
+ if symmetric:
178
+ x_A = (torch.tensor(np.array(im_A)) / 255).to(device).permute(2, 0, 1)
179
+ else:
180
+ if symmetric:
181
+ x_A = im_A
182
+ x_B = im_B
183
+ im_A_transfer_rgb = F.grid_sample(
184
+ x_B[None], warp[:,:W, 2:][None], mode="bilinear", align_corners=False
185
+ )[0]
186
+ if symmetric:
187
+ im_B_transfer_rgb = F.grid_sample(
188
+ x_A[None], warp[:, W:, :2][None], mode="bilinear", align_corners=False
189
+ )[0]
190
+ warp_im = torch.cat((im_A_transfer_rgb,im_B_transfer_rgb),dim=2)
191
+ white_im = torch.ones((H,2*W),device=device)
192
+ else:
193
+ warp_im = im_A_transfer_rgb
194
+ white_im = torch.ones((H, W), device = device)
195
+ vis_im = certainty * warp_im + (1 - certainty) * white_im
196
+ if save_path is not None:
197
+ from romatch.utils import tensor_to_pil
198
+ tensor_to_pil(vis_im, unnormalize=unnormalize).save(save_path)
199
+ return vis_im
200
+
201
+ def corr_volume(self, feat0, feat1):
202
+ """
203
+ input:
204
+ feat0 -> torch.Tensor(B, C, H, W)
205
+ feat1 -> torch.Tensor(B, C, H, W)
206
+ return:
207
+ corr_volume -> torch.Tensor(B, H, W, H, W)
208
+ """
209
+ B, C, H0, W0 = feat0.shape
210
+ B, C, H1, W1 = feat1.shape
211
+ feat0 = feat0.view(B, C, H0*W0)
212
+ feat1 = feat1.view(B, C, H1*W1)
213
+ corr_volume = torch.einsum('bci,bcj->bji', feat0, feat1).reshape(B, H1, W1, H0 , W0)/math.sqrt(C) #16*16*16
214
+ return corr_volume
215
+
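A hedged shape check for the correlation volume (corr_volume never touches self, so it can be exercised unbound on random features; the sizes are illustrative only):

import torch
f0 = torch.randn(1, 64, 24, 32)              # B, C, H0, W0 (image A features)
f1 = torch.randn(1, 64, 24, 32)              # B, C, H1, W1 (image B features)
cv = XFeatModel.corr_volume(None, f0, f1)    # unbound call; self is unused inside
assert cv.shape == (1, 24, 32, 24, 32)       # (B, H1, W1, H0, W0), scaled by 1/sqrt(C)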
216
+ @torch.inference_mode()
217
+ def match_from_path(self, im0_path, im1_path):
218
+ device = self.device
219
+ im0 = ToTensor()(Image.open(im0_path))[None].to(device)
220
+ im1 = ToTensor()(Image.open(im1_path))[None].to(device)
221
+ return self.match(im0, im1, batched = False)
222
+
223
+ @torch.inference_mode()
224
+ def match(self, im0, im1, *args, batched = True):
225
+ # stupid
226
+ if isinstance(im0, (str, Path)):
227
+ return self.match_from_path(im0, im1)
228
+ elif isinstance(im0, Image.Image):
229
+ batched = False
230
+ device = self.device
231
+ im0 = ToTensor()(im0)[None].to(device)
232
+ im1 = ToTensor()(im1)[None].to(device)
233
+
234
+ B,C,H0,W0 = im0.shape
235
+ B,C,H1,W1 = im1.shape
236
+ self.train(False)
237
+ corresps = self.forward({"im_A":im0, "im_B":im1})
238
+ #return 1,1
239
+ flow = F.interpolate(
240
+ corresps[4]["flow"],
241
+ size = (H0, W0),
242
+ mode = "bilinear", align_corners = False).permute(0,2,3,1).reshape(B,H0,W0,2)
243
+ grid = torch.stack(
244
+ torch.meshgrid(
245
+ torch.linspace(-1+1/W0,1-1/W0, W0),
246
+ torch.linspace(-1+1/H0,1-1/H0, H0),
247
+ indexing = "xy"),
248
+ dim = -1).float().to(flow.device).expand(B, H0, W0, 2)
249
+
250
+ certainty = F.interpolate(corresps[4]["certainty"], size = (H0,W0), mode = "bilinear", align_corners = False)
251
+ warp, cert = torch.cat((grid, flow), dim = -1), certainty[:,0].sigmoid()
252
+ if batched:
253
+ return warp, cert
254
+ else:
255
+ return warp[0], cert[0]
256
+
257
+ def sample(
258
+ self,
259
+ matches,
260
+ certainty,
261
+ num=10000,
262
+ ):
263
+ if "threshold" in self.sample_mode:
264
+ upper_thresh = self.sample_thresh
265
+ certainty = certainty.clone()
266
+ certainty[certainty > upper_thresh] = 1
267
+ matches, certainty = (
268
+ matches.reshape(-1, 4),
269
+ certainty.reshape(-1),
270
+ )
271
+ expansion_factor = 4 if "balanced" in self.sample_mode else 1
272
+ good_samples = torch.multinomial(certainty,
273
+ num_samples = min(expansion_factor*num, len(certainty)),
274
+ replacement=False)
275
+ good_matches, good_certainty = matches[good_samples], certainty[good_samples]
276
+ if "balanced" not in self.sample_mode:
277
+ return good_matches, good_certainty
278
+ density = kde(good_matches, std=0.1)
279
+ p = 1 / (density+1)
280
+ p[density < 10] = 1e-7 # Basically should have at least 10 perfect neighbours, or around 100 ok ones
281
+ balanced_samples = torch.multinomial(p,
282
+ num_samples = min(num,len(good_certainty)),
283
+ replacement=False)
284
+ return good_matches[balanced_samples], good_certainty[balanced_samples]
285
+
286
+ def forward(self, batch):
287
+ """
288
+ input:
289
+ x -> torch.Tensor(B, C, H, W) grayscale or rgb images
290
+ return:
291
+
292
+ """
293
+ im0 = batch["im_A"]
294
+ im1 = batch["im_B"]
295
+ corresps = {}
296
+ im0, rh0, rw0 = self.preprocess_tensor(im0)
297
+ im1, rh1, rw1 = self.preprocess_tensor(im1)
298
+ B, C, H0, W0 = im0.shape
299
+ B, C, H1, W1 = im1.shape
300
+ to_normalized = torch.tensor((2/W1, 2/H1, 1)).to(im0.device)[None,:,None,None]
301
+
302
+ if im0.shape[-2:] == im1.shape[-2:]:
303
+ x = torch.cat([im0, im1], dim=0)
304
+ x = self.forward_single(x)
305
+ feats_x0_c, feats_x1_c = x[1].chunk(2)
306
+ feats_x0_f, feats_x1_f = x[0].chunk(2)
307
+ else:
308
+ feats_x0_f, feats_x0_c = self.forward_single(im0)
309
+ feats_x1_f, feats_x1_c = self.forward_single(im1)
310
+ corr_volume = self.corr_volume(feats_x0_c, feats_x1_c)
311
+ coarse_warp = self.pos_embed(corr_volume)
312
+ coarse_matches = torch.cat((coarse_warp, torch.zeros_like(coarse_warp[:,-1:])), dim=1)
313
+ feats_x1_c_warped = F.grid_sample(feats_x1_c, coarse_matches.permute(0, 2, 3, 1)[...,:2], mode = 'bilinear', align_corners = False)
314
+ coarse_matches_delta = self.coarse_matcher(torch.cat((feats_x0_c, feats_x1_c_warped, coarse_warp), dim=1))
315
+ coarse_matches = coarse_matches + coarse_matches_delta * to_normalized
316
+ corresps[8] = {"flow": coarse_matches[:,:2], "certainty": coarse_matches[:,2:]}
317
+ coarse_matches_up = F.interpolate(coarse_matches, size = feats_x0_f.shape[-2:], mode = "bilinear", align_corners = False)
318
+ coarse_matches_up_detach = coarse_matches_up.detach()#note the detach
319
+ feats_x1_f_warped = F.grid_sample(feats_x1_f, coarse_matches_up_detach.permute(0, 2, 3, 1)[...,:2], mode = 'bilinear', align_corners = False)
320
+ fine_matches_delta = self.fine_matcher(torch.cat((feats_x0_f, feats_x1_f_warped, coarse_matches_up_detach[:,:2]), dim=1))
321
+ fine_matches = coarse_matches_up_detach+fine_matches_delta * to_normalized
322
+ corresps[4] = {"flow": fine_matches[:,:2], "certainty": fine_matches[:,2:]}
323
+ return corresps
324
+
325
+
326
+
327
+
328
+
329
+ def train(args):
330
+ rank = 0
331
+ gpus = 1
332
+ device_id = rank % torch.cuda.device_count()
333
+ romatch.LOCAL_RANK = 0
334
+ torch.cuda.set_device(device_id)
335
+
336
+ resolution = "big"
337
+ wandb_log = not args.dont_log_wandb
338
+ experiment_name = Path(__file__).stem
339
+ wandb_mode = "online" if wandb_log and rank == 0 else "disabled"
340
+ wandb.init(project="romatch", entity=args.wandb_entity, name=experiment_name, reinit=False, mode = wandb_mode)
341
+ checkpoint_dir = "workspace/checkpoints/"
342
+ h,w = resolutions[resolution]
343
+ model = XFeatModel(freeze_xfeat = False).to(device_id)
344
+ # Num steps
345
+ global_step = 0
346
+ batch_size = args.gpu_batch_size
347
+ step_size = gpus*batch_size
348
+ romatch.STEP_SIZE = step_size
349
+
350
+ N = 2_000_000 # 2M pairs
351
+ # checkpoint every
352
+ k = 25000 // romatch.STEP_SIZE
353
+
354
+ # Data
355
+ mega = MegadepthBuilder(data_root="data/megadepth", loftr_ignore=True, imc21_ignore = True)
356
+ use_horizontal_flip_aug = True
357
+ normalize = False # don't ImageNet-normalize
358
+ rot_prob = 0
359
+ depth_interpolation_mode = "bilinear"
360
+ megadepth_train1 = mega.build_scenes(
361
+ split="train_loftr", min_overlap=0.01, shake_t=32, use_horizontal_flip_aug = use_horizontal_flip_aug, rot_prob = rot_prob,
362
+ ht=h,wt=w, normalize = normalize
363
+ )
364
+ megadepth_train2 = mega.build_scenes(
365
+ split="train_loftr", min_overlap=0.35, shake_t=32, use_horizontal_flip_aug = use_horizontal_flip_aug, rot_prob = rot_prob,
366
+ ht=h,wt=w, normalize = normalize
367
+ )
368
+ megadepth_train = ConcatDataset(megadepth_train1 + megadepth_train2)
369
+ mega_ws = mega.weight_scenes(megadepth_train, alpha=0.75)
370
+ # Loss and optimizer
371
+ depth_loss = RobustLosses(
372
+ ce_weight=0.01,
373
+ local_dist={4:4},
374
+ depth_interpolation_mode=depth_interpolation_mode,
375
+ alpha = {4:0.15, 8:0.15},
376
+ c = 1e-4,
377
+ epe_mask_prob_th = 0.001,
378
+ )
379
+ parameters = [
380
+ {"params": model.parameters(), "lr": romatch.STEP_SIZE * 1e-4 / 8},
381
+ ]
382
+ optimizer = torch.optim.AdamW(parameters, weight_decay=0.01)
383
+ lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
384
+ optimizer, milestones=[(9*N/romatch.STEP_SIZE)//10])
385
+ #megadense_benchmark = MegadepthDenseBenchmark("data/megadepth", num_samples = 1000, h=h,w=w)
386
+ mega1500_benchmark = Mega1500PoseLibBenchmark("data/megadepth", num_ransac_iter = 1, test_every = 30)
387
+
388
+ checkpointer = CheckPoint(checkpoint_dir, experiment_name)
389
+ model, optimizer, lr_scheduler, global_step = checkpointer.load(model, optimizer, lr_scheduler, global_step)
390
+ romatch.GLOBAL_STEP = global_step
391
+ grad_scaler = torch.cuda.amp.GradScaler(growth_interval=1_000_000)
392
+ grad_clip_norm = 0.01
393
+ #megadense_benchmark.benchmark(model)
394
+ for n in range(romatch.GLOBAL_STEP, N, k * romatch.STEP_SIZE):
395
+ mega_sampler = torch.utils.data.WeightedRandomSampler(
396
+ mega_ws, num_samples = batch_size * k, replacement=False
397
+ )
398
+ mega_dataloader = iter(
399
+ torch.utils.data.DataLoader(
400
+ megadepth_train,
401
+ batch_size = batch_size,
402
+ sampler = mega_sampler,
403
+ num_workers = 8,
404
+ )
405
+ )
406
+ train_k_steps(
407
+ n, k, mega_dataloader, model, depth_loss, optimizer, lr_scheduler, grad_scaler, grad_clip_norm = grad_clip_norm,
408
+ )
409
+ checkpointer.save(model, optimizer, lr_scheduler, romatch.GLOBAL_STEP)
410
+ wandb.log(mega1500_benchmark.benchmark(model, model_name=experiment_name), step = romatch.GLOBAL_STEP)
411
+
412
+ def test_mega_8_scenes(model, name):
413
+ mega_8_scenes_benchmark = MegaDepthPoseEstimationBenchmark("data/megadepth",
414
+ scene_names=['mega_8_scenes_0019_0.1_0.3.npz',
415
+ 'mega_8_scenes_0025_0.1_0.3.npz',
416
+ 'mega_8_scenes_0021_0.1_0.3.npz',
417
+ 'mega_8_scenes_0008_0.1_0.3.npz',
418
+ 'mega_8_scenes_0032_0.1_0.3.npz',
419
+ 'mega_8_scenes_1589_0.1_0.3.npz',
420
+ 'mega_8_scenes_0063_0.1_0.3.npz',
421
+ 'mega_8_scenes_0024_0.1_0.3.npz',
422
+ 'mega_8_scenes_0019_0.3_0.5.npz',
423
+ 'mega_8_scenes_0025_0.3_0.5.npz',
424
+ 'mega_8_scenes_0021_0.3_0.5.npz',
425
+ 'mega_8_scenes_0008_0.3_0.5.npz',
426
+ 'mega_8_scenes_0032_0.3_0.5.npz',
427
+ 'mega_8_scenes_1589_0.3_0.5.npz',
428
+ 'mega_8_scenes_0063_0.3_0.5.npz',
429
+ 'mega_8_scenes_0024_0.3_0.5.npz'])
430
+ mega_8_scenes_results = mega_8_scenes_benchmark.benchmark(model, model_name=name)
431
+ print(mega_8_scenes_results)
432
+ json.dump(mega_8_scenes_results, open(f"results/mega_8_scenes_{name}.json", "w"))
433
+
434
+ def test_mega1500(model, name):
435
+ mega1500_benchmark = MegaDepthPoseEstimationBenchmark("data/megadepth")
436
+ mega1500_results = mega1500_benchmark.benchmark(model, model_name=name)
437
+ json.dump(mega1500_results, open(f"results/mega1500_{name}.json", "w"))
438
+
439
+ def test_mega1500_poselib(model, name):
440
+ mega1500_benchmark = Mega1500PoseLibBenchmark("data/megadepth", num_ransac_iter = 1, test_every = 1)
441
+ mega1500_results = mega1500_benchmark.benchmark(model, model_name=name)
442
+ json.dump(mega1500_results, open(f"results/mega1500_poselib_{name}.json", "w"))
443
+
444
+ def test_mega_8_scenes_poselib(model, name):
445
+ mega1500_benchmark = Mega1500PoseLibBenchmark("data/megadepth", num_ransac_iter = 1, test_every = 1,
446
+ scene_names=['mega_8_scenes_0019_0.1_0.3.npz',
447
+ 'mega_8_scenes_0025_0.1_0.3.npz',
448
+ 'mega_8_scenes_0021_0.1_0.3.npz',
449
+ 'mega_8_scenes_0008_0.1_0.3.npz',
450
+ 'mega_8_scenes_0032_0.1_0.3.npz',
451
+ 'mega_8_scenes_1589_0.1_0.3.npz',
452
+ 'mega_8_scenes_0063_0.1_0.3.npz',
453
+ 'mega_8_scenes_0024_0.1_0.3.npz',
454
+ 'mega_8_scenes_0019_0.3_0.5.npz',
455
+ 'mega_8_scenes_0025_0.3_0.5.npz',
456
+ 'mega_8_scenes_0021_0.3_0.5.npz',
457
+ 'mega_8_scenes_0008_0.3_0.5.npz',
458
+ 'mega_8_scenes_0032_0.3_0.5.npz',
459
+ 'mega_8_scenes_1589_0.3_0.5.npz',
460
+ 'mega_8_scenes_0063_0.3_0.5.npz',
461
+ 'mega_8_scenes_0024_0.3_0.5.npz'])
462
+ mega1500_results = mega1500_benchmark.benchmark(model, model_name=name)
463
+ json.dump(mega1500_results, open(f"results/mega_8_scenes_poselib_{name}.json", "w"))
464
+
465
+ def test_scannet_poselib(model, name):
466
+ scannet_benchmark = ScanNetPoselibBenchmark("data/scannet")
467
+ scannet_results = scannet_benchmark.benchmark(model)
468
+ json.dump(scannet_results, open(f"results/scannet_{name}.json", "w"))
469
+
470
+ def test_scannet(model, name):
471
+ scannet_benchmark = ScanNetBenchmark("data/scannet")
472
+ scannet_results = scannet_benchmark.benchmark(model)
473
+ json.dump(scannet_results, open(f"results/scannet_{name}.json", "w"))
474
+
475
+ if __name__ == "__main__":
476
+ os.environ["TORCH_CUDNN_V8_API_ENABLED"] = "1" # For BF16 computations
477
+ os.environ["OMP_NUM_THREADS"] = "16"
478
+ torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
479
+ import romatch
480
+ parser = ArgumentParser()
481
+ parser.add_argument("--only_test", action='store_true')
482
+ parser.add_argument("--debug_mode", action='store_true')
483
+ parser.add_argument("--dont_log_wandb", action='store_true')
484
+ parser.add_argument("--train_resolution", default='medium')
485
+ parser.add_argument("--gpu_batch_size", default=8, type=int)
486
+ parser.add_argument("--wandb_entity", required = False)
487
+
488
+ args, _ = parser.parse_known_args()
489
+ romatch.DEBUG_MODE = args.debug_mode
490
+ if not args.only_test:
491
+ train(args)
492
+
493
+ experiment_name = "tiny_roma_v1_outdoor"#Path(__file__).stem
494
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
495
+ model = XFeatModel(freeze_xfeat=False, exact_softmax=False).to(device)
496
+ model.load_state_dict(torch.load(f"{experiment_name}.pth"))
497
+ test_mega1500_poselib(model, experiment_name)
498
+
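Note that the test block loads a bare state_dict from tiny_roma_v1_outdoor.pth, whereas the training loop checkpoints through CheckPoint into workspace/checkpoints/ with the weights stored under a "model" key (compare the analogous load in roma_indoor.py above). A hedged conversion sketch, with a hypothetical checkpoint filename:

import torch
states = torch.load("workspace/checkpoints/train_tiny_roma_v1_outdoor_latest.pth")  # hypothetical path
torch.save(states["model"], "tiny_roma_v1_outdoor.pth")                             # bare weights for the test block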
submodules/RoMa/requirements.txt ADDED
@@ -0,0 +1,14 @@
1
+ torch
2
+ einops
3
+ torchvision
4
+ opencv-python
5
+ kornia
6
+ albumentations
7
+ loguru
8
+ tqdm
9
+ matplotlib
10
+ h5py
11
+ wandb
12
+ timm
13
+ poselib
14
+ #xformers # Optional, used for memory-efficient attention
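The usual setup for the submodule is pip install -r submodules/RoMa/requirements.txt; xformers is left commented out because it is optional and only enables memory-efficient attention.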
submodules/RoMa/romatch/__init__.py ADDED
@@ -0,0 +1,8 @@
1
+ import os
2
+ from .models import roma_outdoor, tiny_roma_v1_outdoor, roma_indoor
3
+
4
+ DEBUG_MODE = False
5
+ RANK = int(os.environ.get('RANK', default = 0))
6
+ GLOBAL_STEP = 0
7
+ STEP_SIZE = 1
8
+ LOCAL_RANK = -1
submodules/RoMa/romatch/benchmarks/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ from .hpatches_sequences_homog_benchmark import HpatchesHomogBenchmark
2
+ from .scannet_benchmark import ScanNetBenchmark
3
+ from .megadepth_pose_estimation_benchmark import MegaDepthPoseEstimationBenchmark
4
+ from .megadepth_dense_benchmark import MegadepthDenseBenchmark
5
+ from .megadepth_pose_estimation_benchmark_poselib import Mega1500PoseLibBenchmark
6
+ #from .scannet_benchmark_poselib import ScanNetPoselibBenchmark
submodules/RoMa/romatch/benchmarks/hpatches_sequences_homog_benchmark.py ADDED
@@ -0,0 +1,113 @@
1
+ from PIL import Image
2
+ import numpy as np
3
+
4
+ import os
5
+
6
+ from tqdm import tqdm
7
+ from romatch.utils import pose_auc
8
+ import cv2
9
+
10
+
11
+ class HpatchesHomogBenchmark:
12
+ """Hpatches grid goes from [0,n-1] instead of [0.5,n-0.5]"""
13
+
14
+ def __init__(self, dataset_path) -> None:
15
+ seqs_dir = "hpatches-sequences-release"
16
+ self.seqs_path = os.path.join(dataset_path, seqs_dir)
17
+ self.seq_names = sorted(os.listdir(self.seqs_path))
18
+ # Ignore seqs is same as LoFTR.
19
+ self.ignore_seqs = set(
20
+ [
21
+ "i_contruction",
22
+ "i_crownnight",
23
+ "i_dc",
24
+ "i_pencils",
25
+ "i_whitebuilding",
26
+ "v_artisans",
27
+ "v_astronautis",
28
+ "v_talent",
29
+ ]
30
+ )
31
+
32
+ def convert_coordinates(self, im_A_coords, im_A_to_im_B, wq, hq, wsup, hsup):
33
+ offset = 0.5 # Hpatches assumes that the center of the top-left pixel is at [0,0] (I think)
34
+ im_A_coords = (
35
+ np.stack(
36
+ (
37
+ wq * (im_A_coords[..., 0] + 1) / 2,
38
+ hq * (im_A_coords[..., 1] + 1) / 2,
39
+ ),
40
+ axis=-1,
41
+ )
42
+ - offset
43
+ )
44
+ im_A_to_im_B = (
45
+ np.stack(
46
+ (
47
+ wsup * (im_A_to_im_B[..., 0] + 1) / 2,
48
+ hsup * (im_A_to_im_B[..., 1] + 1) / 2,
49
+ ),
50
+ axis=-1,
51
+ )
52
+ - offset
53
+ )
54
+ return im_A_coords, im_A_to_im_B
55
+
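As a worked example of the conversion: for an 800-pixel-wide image, a normalized x of -1 maps to 800 * 0 / 2 - 0.5 = -0.5, and the first grid point at -1 + 1/800 maps to ≈ 0.0, i.e. the centre of the first pixel, which is exactly the [0, n-1] pixel-centre convention the class docstring refers to.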
56
+ def benchmark(self, model, model_name = None):
57
+ n_matches = []
58
+ homog_dists = []
59
+ for seq_idx, seq_name in tqdm(
60
+ enumerate(self.seq_names), total=len(self.seq_names)
61
+ ):
62
+ im_A_path = os.path.join(self.seqs_path, seq_name, "1.ppm")
63
+ im_A = Image.open(im_A_path)
64
+ w1, h1 = im_A.size
65
+ for im_idx in range(2, 7):
66
+ im_B_path = os.path.join(self.seqs_path, seq_name, f"{im_idx}.ppm")
67
+ im_B = Image.open(im_B_path)
68
+ w2, h2 = im_B.size
69
+ H = np.loadtxt(
70
+ os.path.join(self.seqs_path, seq_name, "H_1_" + str(im_idx))
71
+ )
72
+ dense_matches, dense_certainty = model.match(
73
+ im_A_path, im_B_path
74
+ )
75
+ good_matches, _ = model.sample(dense_matches, dense_certainty, 5000)
76
+ pos_a, pos_b = self.convert_coordinates(
77
+ good_matches[:, :2], good_matches[:, 2:], w1, h1, w2, h2
78
+ )
79
+ try:
80
+ H_pred, inliers = cv2.findHomography(
81
+ pos_a,
82
+ pos_b,
83
+ method = cv2.RANSAC,
84
+ confidence = 0.99999,
85
+ ransacReprojThreshold = 3 * min(w2, h2) / 480,
86
+ )
87
+ except:
88
+ H_pred = None
89
+ if H_pred is None:
90
+ H_pred = np.zeros((3, 3))
91
+ H_pred[2, 2] = 1.0
92
+ corners = np.array(
93
+ [[0, 0, 1], [0, h1 - 1, 1], [w1 - 1, 0, 1], [w1 - 1, h1 - 1, 1]]
94
+ )
95
+ real_warped_corners = np.dot(corners, np.transpose(H))
96
+ real_warped_corners = (
97
+ real_warped_corners[:, :2] / real_warped_corners[:, 2:]
98
+ )
99
+ warped_corners = np.dot(corners, np.transpose(H_pred))
100
+ warped_corners = warped_corners[:, :2] / warped_corners[:, 2:]
101
+ mean_dist = np.mean(
102
+ np.linalg.norm(real_warped_corners - warped_corners, axis=1)
103
+ ) / (min(w2, h2) / 480.0)
104
+ homog_dists.append(mean_dist)
105
+
106
+ n_matches = np.array(n_matches)
107
+ thresholds = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
108
+ auc = pose_auc(np.array(homog_dists), thresholds)
109
+ return {
110
+ "hpatches_homog_auc_3": auc[2],
111
+ "hpatches_homog_auc_5": auc[4],
112
+ "hpatches_homog_auc_10": auc[9],
113
+ }
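For orientation, a minimal usage sketch of the benchmark above (a sketch only: it assumes the HPatches sequences are extracted under a hypothetical data/hpatches directory, and that roma_outdoor and HpatchesHomogBenchmark are exported by the romatch package as in the RoMa README):

    from romatch import roma_outdoor
    from romatch.benchmarks import HpatchesHomogBenchmark

    # Hypothetical layout: data/hpatches/hpatches-sequences-release/<seq>/{1..6}.ppm plus H_1_<i> files.
    model = roma_outdoor(device="cuda")
    bench = HpatchesHomogBenchmark("data/hpatches")
    results = bench.benchmark(model)
    print(results)  # {"hpatches_homog_auc_3": ..., "hpatches_homog_auc_5": ..., "hpatches_homog_auc_10": ...}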
submodules/RoMa/romatch/benchmarks/megadepth_dense_benchmark.py ADDED
@@ -0,0 +1,106 @@
+ import torch
+ import numpy as np
+ import tqdm
+ from romatch.datasets import MegadepthBuilder
+ from romatch.utils import warp_kpts
+ from torch.utils.data import ConcatDataset
+ import romatch
+
+ class MegadepthDenseBenchmark:
+     def __init__(self, data_root="data/megadepth", h = 384, w = 512, num_samples = 2000) -> None:
+         mega = MegadepthBuilder(data_root=data_root)
+         self.dataset = ConcatDataset(
+             mega.build_scenes(split="test_loftr", ht=h, wt=w)
+         )  # fixed resolution of 384,512
+         self.num_samples = num_samples
+
+     def geometric_dist(self, depth1, depth2, T_1to2, K1, K2, dense_matches):
+         b, h1, w1, d = dense_matches.shape
+         with torch.no_grad():
+             x1 = dense_matches[..., :2].reshape(b, h1 * w1, 2)
+             mask, x2 = warp_kpts(
+                 x1.double(),
+                 depth1.double(),
+                 depth2.double(),
+                 T_1to2.double(),
+                 K1.double(),
+                 K2.double(),
+             )
+             x2 = torch.stack(
+                 (w1 * (x2[..., 0] + 1) / 2, h1 * (x2[..., 1] + 1) / 2), dim=-1
+             )
+             prob = mask.float().reshape(b, h1, w1)
+         x2_hat = dense_matches[..., 2:]
+         x2_hat = torch.stack(
+             (w1 * (x2_hat[..., 0] + 1) / 2, h1 * (x2_hat[..., 1] + 1) / 2), dim=-1
+         )
+         gd = (x2_hat - x2.reshape(b, h1, w1, 2)).norm(dim=-1)
+         gd = gd[prob == 1]
+         pck_1 = (gd < 1.0).float().mean()
+         pck_3 = (gd < 3.0).float().mean()
+         pck_5 = (gd < 5.0).float().mean()
+         return gd, pck_1, pck_3, pck_5, prob
+
+     def benchmark(self, model, batch_size=8):
+         model.train(False)
+         with torch.no_grad():
+             gd_tot = 0.0
+             pck_1_tot = 0.0
+             pck_3_tot = 0.0
+             pck_5_tot = 0.0
+             sampler = torch.utils.data.WeightedRandomSampler(
+                 torch.ones(len(self.dataset)), replacement=False, num_samples=self.num_samples
+             )
+             B = batch_size
+             dataloader = torch.utils.data.DataLoader(
+                 self.dataset, batch_size=B, num_workers=batch_size, sampler=sampler
+             )
+             for idx, data in tqdm.tqdm(enumerate(dataloader), disable = romatch.RANK > 0):
+                 im_A, im_B, depth1, depth2, T_1to2, K1, K2 = (
+                     data["im_A"].cuda(),
+                     data["im_B"].cuda(),
+                     data["im_A_depth"].cuda(),
+                     data["im_B_depth"].cuda(),
+                     data["T_1to2"].cuda(),
+                     data["K1"].cuda(),
+                     data["K2"].cuda(),
+                 )
+                 matches, certainty = model.match(im_A, im_B, batched=True)
+                 gd, pck_1, pck_3, pck_5, prob = self.geometric_dist(
+                     depth1, depth2, T_1to2, K1, K2, matches
+                 )
+                 if romatch.DEBUG_MODE:
+                     from romatch.utils.utils import tensor_to_pil
+                     import torch.nn.functional as F
+                     path = "vis"
+                     H, W = model.get_output_resolution()
+                     white_im = torch.ones((B,1,H,W),device="cuda")
+                     im_B_transfer_rgb = F.grid_sample(
+                         im_B.cuda(), matches[:,:,:W, 2:], mode="bilinear", align_corners=False
+                     )
+                     warp_im = im_B_transfer_rgb
+                     c_b = certainty[:,None]#(certainty*0.9 + 0.1*torch.ones_like(certainty))[:,None]
+                     vis_im = c_b * warp_im + (1 - c_b) * white_im
+                     for b in range(B):
+                         import os
+                         os.makedirs(f"{path}/{model.name}/{idx}_{b}_{H}_{W}",exist_ok=True)
+                         tensor_to_pil(vis_im[b], unnormalize=True).save(
+                             f"{path}/{model.name}/{idx}_{b}_{H}_{W}/warp.jpg")
+                         tensor_to_pil(im_A[b].cuda(), unnormalize=True).save(
+                             f"{path}/{model.name}/{idx}_{b}_{H}_{W}/im_A.jpg")
+                         tensor_to_pil(im_B[b].cuda(), unnormalize=True).save(
+                             f"{path}/{model.name}/{idx}_{b}_{H}_{W}/im_B.jpg")
+
+
+                 gd_tot, pck_1_tot, pck_3_tot, pck_5_tot = (
+                     gd_tot + gd.mean(),
+                     pck_1_tot + pck_1,
+                     pck_3_tot + pck_3,
+                     pck_5_tot + pck_5,
+                 )
+         return {
+             "epe": gd_tot.item() / len(dataloader),
+             "mega_pck_1": pck_1_tot.item() / len(dataloader),
+             "mega_pck_3": pck_3_tot.item() / len(dataloader),
+             "mega_pck_5": pck_5_tot.item() / len(dataloader),
+         }
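A similar sketch for the dense benchmark above (assumes a CUDA device and MegaDepth prepared under data/megadepth in the layout expected by MegadepthBuilder; num_samples pairs are drawn from the LoFTR test split, and dense PCK at 1/3/5 px plus mean end-point error are reported):

    from romatch import roma_outdoor
    from romatch.benchmarks import MegadepthDenseBenchmark

    model = roma_outdoor(device="cuda")
    # num_samples controls how many image pairs are sampled for evaluation.
    bench = MegadepthDenseBenchmark(data_root="data/megadepth", num_samples=2000)
    results = bench.benchmark(model, batch_size=8)
    print(results)  # {"epe": ..., "mega_pck_1": ..., "mega_pck_3": ..., "mega_pck_5": ...}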
submodules/RoMa/romatch/benchmarks/megadepth_pose_estimation_benchmark.py ADDED
@@ -0,0 +1,118 @@
+ import numpy as np
+ import torch
+ from romatch.utils import *
+ from PIL import Image
+ from tqdm import tqdm
+ import torch.nn.functional as F
+ import romatch
+ import kornia.geometry.epipolar as kepi
+
+ class MegaDepthPoseEstimationBenchmark:
+     def __init__(self, data_root="data/megadepth", scene_names = None) -> None:
+         if scene_names is None:
+             self.scene_names = [
+                 "0015_0.1_0.3.npz",
+                 "0015_0.3_0.5.npz",
+                 "0022_0.1_0.3.npz",
+                 "0022_0.3_0.5.npz",
+                 "0022_0.5_0.7.npz",
+             ]
+         else:
+             self.scene_names = scene_names
+         self.scenes = [
+             np.load(f"{data_root}/{scene}", allow_pickle=True)
+             for scene in self.scene_names
+         ]
+         self.data_root = data_root
+
+     def benchmark(self, model, model_name = None):
+         with torch.no_grad():
+             data_root = self.data_root
+             tot_e_t, tot_e_R, tot_e_pose = [], [], []
+             thresholds = [5, 10, 20]
+             for scene_ind in range(len(self.scenes)):
+                 import os
+                 scene_name = os.path.splitext(self.scene_names[scene_ind])[0]
+                 scene = self.scenes[scene_ind]
+                 pairs = scene["pair_infos"]
+                 intrinsics = scene["intrinsics"]
+                 poses = scene["poses"]
+                 im_paths = scene["image_paths"]
+                 pair_inds = range(len(pairs))
+                 for pairind in tqdm(pair_inds):
+                     idx1, idx2 = pairs[pairind][0]
+                     K1 = intrinsics[idx1].copy()
+                     T1 = poses[idx1].copy()
+                     R1, t1 = T1[:3, :3], T1[:3, 3]
+                     K2 = intrinsics[idx2].copy()
+                     T2 = poses[idx2].copy()
+                     R2, t2 = T2[:3, :3], T2[:3, 3]
+                     R, t = compute_relative_pose(R1, t1, R2, t2)
+                     T1_to_2 = np.concatenate((R,t[:,None]), axis=-1)
+                     im_A_path = f"{data_root}/{im_paths[idx1]}"
+                     im_B_path = f"{data_root}/{im_paths[idx2]}"
+                     dense_matches, dense_certainty = model.match(
+                         im_A_path, im_B_path, K1.copy(), K2.copy(), T1_to_2.copy()
+                     )
+                     sparse_matches,_ = model.sample(
+                         dense_matches, dense_certainty, 5_000
+                     )
+
+                     im_A = Image.open(im_A_path)
+                     w1, h1 = im_A.size
+                     im_B = Image.open(im_B_path)
+                     w2, h2 = im_B.size
+                     if True:  # Note: we keep this true as it was used in DKM/RoMa papers. There is very little difference compared to setting to False.
+                         scale1 = 1200 / max(w1, h1)
+                         scale2 = 1200 / max(w2, h2)
+                         w1, h1 = scale1 * w1, scale1 * h1
+                         w2, h2 = scale2 * w2, scale2 * h2
+                         K1, K2 = K1.copy(), K2.copy()
+                         K1[:2] = K1[:2] * scale1
+                         K2[:2] = K2[:2] * scale2
+
+                     kpts1, kpts2 = model.to_pixel_coordinates(sparse_matches, h1, w1, h2, w2)
+                     kpts1, kpts2 = kpts1.cpu().numpy(), kpts2.cpu().numpy()
+                     for _ in range(5):
+                         shuffling = np.random.permutation(np.arange(len(kpts1)))
+                         kpts1 = kpts1[shuffling]
+                         kpts2 = kpts2[shuffling]
+                         try:
+                             threshold = 0.5
+                             norm_threshold = threshold / (np.mean(np.abs(K1[:2, :2])) + np.mean(np.abs(K2[:2, :2])))
+                             R_est, t_est, mask = estimate_pose(
+                                 kpts1,
+                                 kpts2,
+                                 K1,
+                                 K2,
+                                 norm_threshold,
+                                 conf=0.99999,
+                             )
+                             T1_to_2_est = np.concatenate((R_est, t_est), axis=-1)
+                             e_t, e_R = compute_pose_error(T1_to_2_est, R, t)
+                             e_pose = max(e_t, e_R)
+                         except Exception as e:
+                             print(repr(e))
+                             e_t, e_R = 90, 90
+                             e_pose = max(e_t, e_R)
+                         tot_e_t.append(e_t)
+                         tot_e_R.append(e_R)
+                         tot_e_pose.append(e_pose)
+             tot_e_pose = np.array(tot_e_pose)
+             auc = pose_auc(tot_e_pose, thresholds)
+             acc_5 = (tot_e_pose < 5).mean()
+             acc_10 = (tot_e_pose < 10).mean()
+             acc_15 = (tot_e_pose < 15).mean()
+             acc_20 = (tot_e_pose < 20).mean()
+             map_5 = acc_5
+             map_10 = np.mean([acc_5, acc_10])
+             map_20 = np.mean([acc_5, acc_10, acc_15, acc_20])
+             print(f"{model_name} auc: {auc}")
+             return {
+                 "auc_5": auc[0],
+                 "auc_10": auc[1],
+                 "auc_20": auc[2],
+                 "map_5": map_5,
+                 "map_10": map_10,
+                 "map_20": map_20,
+             }
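And a sketch for the pose-estimation benchmark above (assumes the MegaDepth-1500 scene files such as 0015_0.1_0.3.npz and the corresponding images live under data/megadepth; each pair is evaluated with five shuffled RANSAC runs, and pose AUC is computed at the 5/10/20 degree thresholds):

    from romatch import roma_outdoor
    from romatch.benchmarks import MegaDepthPoseEstimationBenchmark

    model = roma_outdoor(device="cuda")
    bench = MegaDepthPoseEstimationBenchmark(data_root="data/megadepth")
    results = bench.benchmark(model, model_name="roma_outdoor")
    print(results)  # {"auc_5": ..., "auc_10": ..., "auc_20": ..., "map_5": ..., "map_10": ..., "map_20": ...}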