Spaces: Running on Zero

init
- .gitignore +23 -0
- README.md +5 -4
- arguments/__init__.py +187 -0
- command +33 -0
- gaussian_renderer/__init__.py +245 -0
- gaussian_renderer/network_gui.py +86 -0
- lpipsPyTorch/__init__.py +23 -0
- lpipsPyTorch/modules/lpips.py +44 -0
- lpipsPyTorch/modules/networks.py +96 -0
- lpipsPyTorch/modules/utils.py +30 -0
- requirements.txt +21 -0
- run_video.py +275 -0
- scene/__init__.py +105 -0
- scene/cameras.py +71 -0
- scene/colmap_loader.py +294 -0
- scene/dataset_readers.py +382 -0
- scene/gaussian_model.py +830 -0
- train_feat2gs.py +243 -0
- utils/camera_traj_config.py +655 -0
- utils/camera_utils.py +481 -0
- utils/dust3r_utils.py +432 -0
- utils/feat_utils.py +827 -0
- utils/general_utils.py +133 -0
- utils/graphics_utils.py +210 -0
- utils/image_utils.py +118 -0
- utils/loss_utils.py +247 -0
- utils/pose_utils.py +570 -0
- utils/sh_utils.py +118 -0
- utils/stepfun.py +403 -0
- utils/system_utils.py +28 -0
- utils/trajectories.py +243 -0
- utils/utils_poses/ATE/align_trajectory.py +80 -0
- utils/utils_poses/ATE/align_utils.py +144 -0
- utils/utils_poses/ATE/compute_trajectory_errors.py +89 -0
- utils/utils_poses/ATE/results_writer.py +75 -0
- utils/utils_poses/ATE/trajectory_utils.py +46 -0
- utils/utils_poses/ATE/transformations.py +1974 -0
- utils/utils_poses/align_traj.py +97 -0
- utils/utils_poses/comp_ate.py +74 -0
- utils/utils_poses/lie_group_helper.py +81 -0
- utils/utils_poses/relative_pose.py +20 -0
- utils/utils_poses/vis_cam_traj.py +138 -0
- utils/utils_poses/vis_pose_utils.py +270 -0
.gitignore
ADDED
@@ -0,0 +1,23 @@
/.idea/
/work_dirs*
.vscode/
/tmp
/data
# /checkpoints
*.so
*.patch
__pycache__/
*.egg-info/
/viz*
/submit*
build/
*.pyd
/cache*
*.stl
# *.pth
/venv/
.nk8s
*.mp4
.vs
/exp/
/dev/
README.md
CHANGED
@@ -1,10 +1,11 @@
 ---
 title: Feat2GS
-emoji:
-colorFrom:
-colorTo:
+emoji: ✨
+colorFrom: yellow
+colorTo: green
 sdk: gradio
-sdk_version:
+sdk_version: 4.20.1
+python_version: 3.10.13
 app_file: app.py
 pinned: false
 license: apache-2.0
arguments/__init__.py
ADDED
@@ -0,0 +1,187 @@
#
# Copyright (C) 2023, Inria
# GRAPHDECO research group, https://team.inria.fr/graphdeco
# All rights reserved.
#
# This software is free for non-commercial, research and evaluation use
# under the terms of the LICENSE.md file.
#
# For inquiries contact [email protected]
#

from argparse import ArgumentParser, Namespace
import sys
import os

class GroupParams:
    pass

class ParamGroup:
    def __init__(self, parser: ArgumentParser, name : str, fill_none = False):
        group = parser.add_argument_group(name)
        for key, value in vars(self).items():
            shorthand = False
            if key.startswith("_"):
                shorthand = True
                key = key[1:]
            t = type(value)
            value = value if not fill_none else None
            if shorthand:
                if t == bool:
                    group.add_argument("--" + key, ("-" + key[0:1]), default=value, action="store_true")
                else:
                    group.add_argument("--" + key, ("-" + key[0:1]), default=value, type=t)
            else:
                if t == bool:
                    group.add_argument("--" + key, default=value, action="store_true")
                else:
                    group.add_argument("--" + key, default=value, type=t)

    def extract(self, args):
        group = GroupParams()
        for arg in vars(args).items():
            if arg[0] in vars(self) or ("_" + arg[0]) in vars(self):
                setattr(group, arg[0], arg[1])
        return group

class ModelParams(ParamGroup):
    def __init__(self, parser, sentinel=False):
        self.sh_degree = 3
        self._source_path = ""
        self._model_path = ""
        self._images = "images"
        self._resolution = -1
        self._white_background = False
        self.data_device = "cuda"
        self.eval = False
        self.feat_default_dim = {
            'iuv': 3,
            'iuvrgb': 6,
            'mast3r': 1024,
            'dust3r': 1024,
            'dift': 1280,
            'dino_b16': 768,
            'dinov2_b14': 768,
            'radio': 1280,
            'clip_b16': 512,
            'mae_b16': 768,
            'midas_l16': 1024,
            'sam_base': 768,
            # 'dino16': 384,
            # 'dinov2': 384,
            # 'clip': 512,
            # 'maskclip': 512,
            # 'vit': 384,
            # 'resnet50': 2048,
            # 'midas': 768,
            # 'mae': 1024,
        }
        self.gs_params_group = {
            'G':{
                'head': ['xyz', 'scaling', 'rotation', 'opacity'],
                'opt':['f_dc', 'f_rest']
            },
            'T':{
                'head': ['f_dc', 'f_rest'],
                'opt':['xyz', 'scaling', 'rotation', 'opacity']
            },
            'A':{
                'head': ['xyz', 'scaling', 'rotation', 'opacity', 'f_dc', 'f_rest'],
                'opt':[]
            },
            'Gft':{
                'head': ['xyz', 'scaling', 'rotation', 'opacity'],
                'opt':['f_dc', 'f_rest', 'pc_feat']
            },
            'Tft':{
                'head': ['f_dc', 'f_rest'],
                'opt':['xyz', 'scaling', 'rotation', 'opacity', 'pc_feat']
            },
            'Aft':{
                'head': ['xyz', 'scaling', 'rotation', 'opacity', 'f_dc', 'f_rest'],
                'opt':['pc_feat']
            },
        }
        super().__init__(parser, "Loading Parameters", sentinel)

    def extract(self, args):
        g = super().extract(args)
        g.source_path = os.path.abspath(g.source_path)
        return g

class PipelineParams(ParamGroup):
    def __init__(self, parser):
        self.convert_SHs_python = False
        self.compute_cov3D_python = False
        self.debug = False
        super().__init__(parser, "Pipeline Parameters")

class DefualtOptimizationParams(ParamGroup):
    def __init__(self, parser):
        self.lr_multiplier = 1.
        self.iterations = 30_000
        self.position_lr_init = 0.00016 * self.lr_multiplier
        self.position_lr_final = 0.0000016 * self.lr_multiplier
        self.position_lr_delay_mult = 0.01
        self.position_lr_max_steps = 30_000
        self.feature_lr = 0.0025 * self.lr_multiplier
        self.opacity_lr = 0.05 * self.lr_multiplier
        self.scaling_lr = 0.005 * self.lr_multiplier
        self.rotation_lr = 0.001 * self.lr_multiplier
        self.percent_dense = 0.01
        self.lambda_dssim = 0.2
        self.densification_interval = 100
        self.opacity_reset_interval = 3000
        self.densify_from_iter = 500
        self.densify_until_iter = 15_000
        self.densify_grad_threshold = 0.0002
        self.random_background = False
        super().__init__(parser, "Optimization Parameters")


class OptimizationParams(ParamGroup):
    def __init__(self, parser):
        self.lr_multiplier = 0.1
        self.iterations = 30_000
        self.pose_lr_init = 0.0001  # 0.0001
        self.pose_lr_final = 0.000001  # 0.0001
        self.position_lr_init = 0.00016 * self.lr_multiplier  # 0.000001
        self.position_lr_final = 0.0000016 * self.lr_multiplier  # 0.000001
        self.position_lr_delay_mult = 0.01
        self.position_lr_max_steps = 30_000
        self.feature_lr = 0.0025 * self.lr_multiplier  # 0.001
        self.feature_sh_lr = (0.0025/20.) * self.lr_multiplier  # 0.000001
        self.opacity_lr = 0.05 * self.lr_multiplier  # 0.0001
        self.scaling_lr = 0.005 * self.lr_multiplier  # 0.001
        self.rotation_lr = 0.001 * self.lr_multiplier  # 0.00001
        self.percent_dense = 0.01
        self.lambda_dssim = 0.2
        self.densification_interval = 100
        self.opacity_reset_interval = 3000
        # self.densify_from_iter = 500
        # self.densify_until_iter = 15_000
        # self.densify_grad_threshold = 0.0002
        self.random_background = False
        super().__init__(parser, "Optimization Parameters")

def get_combined_args(parser : ArgumentParser):
    cmdlne_string = sys.argv[1:]
    cfgfile_string = "Namespace()"
    args_cmdline = parser.parse_args(cmdlne_string)

    try:
        cfgfilepath = os.path.join(args_cmdline.model_path, "cfg_args")
        print("Looking for config file in", cfgfilepath)
        with open(cfgfilepath) as cfg_file:
            print("Config file found: {}".format(cfgfilepath))
            cfgfile_string = cfg_file.read()
    except TypeError:
        print("Config file not found at")
        pass
    args_cfgfile = eval(cfgfile_string)

    merged_dict = vars(args_cfgfile).copy()
    for k,v in vars(args_cmdline).items():
        if v != None:
            merged_dict[k] = v
    return Namespace(**merged_dict)
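
ParamGroup registers one CLI flag per instance attribute (a leading underscore also adds a one-letter shorthand), and extract() copies the parsed values back into a plain GroupParams namespace. A minimal usage sketch of that pattern, in the style of the 3DGS training scripts (paths and flag values below are illustrative, not from the repo):

# Minimal sketch of the ParamGroup pattern; paths are illustrative only.
from argparse import ArgumentParser
from arguments import ModelParams, OptimizationParams, PipelineParams

parser = ArgumentParser(description="Training script parameters")
lp = ModelParams(parser)         # adds --source_path/-s, --model_path/-m, --sh_degree, ...
op = OptimizationParams(parser)  # adds --iterations, --position_lr_init, ...
pp = PipelineParams(parser)      # adds --convert_SHs_python, --debug, ...

args = parser.parse_args(["-s", "data/scene", "-m", "output/scene"])
dataset = lp.extract(args)       # GroupParams holding only the ModelParams fields
opt, pipe = op.extract(args), pp.extract(args)
print(dataset.source_path, opt.iterations, pipe.debug)
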
command
ADDED
@@ -0,0 +1,33 @@
conda activate feat2gs
cd Feat2GS/

bash scripts/run_feat2gs_eval_parallel.sh
bash scripts/run_feat2gs_eval.sh
bash scripts/run_instantsplat_eval_parallel.sh
bash scripts/run_feat2gs_eval_dtu_parallel.sh

python video/generate_video.py

bash scripts/run_all_trajectories.sh
bash scripts/run_video_render.sh
bash scripts/run_video_render_instantsplat.sh
bash scripts/run_video_render_dtu.sh

tensorboard --logdir=/home/chenyue/output/Feat2gs/output/eval/ --port=7001

cd /home/chenyue/output/Feat2gs/output/eval/Tanks/Train/6_views/feat2gs-G/dust3r/
tensorboard --logdir_spec \
radio:radio,\
dust3r:dust3r,\
dino_b16:dino_b16,\
mast3r:mast3r,\
dift:dift,\
dinov2:dinov2_b14,\
clip:clip_b16,\
mae:mae_b16,\
midas:midas_l16,\
sam:sam_base,\
iuvrgb:iuvrgb \
--port 7002

CUDA_VISIBLE_DEVICES=7 gradio demo.py
gaussian_renderer/__init__.py
ADDED
@@ -0,0 +1,245 @@
#
# Copyright (C) 2023, Inria
# GRAPHDECO research group, https://team.inria.fr/graphdeco
# All rights reserved.
#
# This software is free for non-commercial, research and evaluation use
# under the terms of the LICENSE.md file.
#
# For inquiries contact [email protected]
#

import torch
import math
from scene.gaussian_model import GaussianModel
from utils.pose_utils import get_camera_from_tensor, quadmultiply
from utils.graphics_utils import depth_to_normal


### if use [diff-gaussian-rasterization](https://github.com/graphdeco-inria/diff-gaussian-rasterization)

# from diff_gaussian_rasterization import (
#     GaussianRasterizationSettings,
#     GaussianRasterizer,
# )
# from utils.sh_utils import eval_sh

# def render(
#     viewpoint_camera,
#     pc: GaussianModel,
#     pipe,
#     bg_color: torch.Tensor,
#     scaling_modifier=1.0,
#     override_color=None,
#     camera_pose=None,
# ):
#     """
#     Render the scene.

#     Background tensor (bg_color) must be on GPU!
#     """

#     # Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means
#     screenspace_points = (
#         torch.zeros_like(
#             pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device="cuda"
#         )
#         + 0
#     )
#     try:
#         screenspace_points.retain_grad()
#     except:
#         pass

#     # Set up rasterization configuration
#     tanfovx = math.tan(viewpoint_camera.FoVx * 0.5)
#     tanfovy = math.tan(viewpoint_camera.FoVy * 0.5)

#     # Set camera pose as identity. Then, we will transform the Gaussians around camera_pose
#     w2c = torch.eye(4).cuda()
#     projmatrix = (
#         w2c.unsqueeze(0).bmm(viewpoint_camera.projection_matrix.unsqueeze(0))
#     ).squeeze(0)
#     camera_pos = w2c.inverse()[3, :3]
#     raster_settings = GaussianRasterizationSettings(
#         image_height=int(viewpoint_camera.image_height),
#         image_width=int(viewpoint_camera.image_width),
#         tanfovx=tanfovx,
#         tanfovy=tanfovy,
#         bg=bg_color,
#         scale_modifier=scaling_modifier,
#         # viewmatrix=viewpoint_camera.world_view_transform,
#         # projmatrix=viewpoint_camera.full_proj_transform,
#         viewmatrix=w2c,
#         projmatrix=projmatrix,
#         sh_degree=pc.active_sh_degree,
#         # campos=viewpoint_camera.camera_center,
#         campos=camera_pos,
#         prefiltered=False,
#         debug=pipe.debug,
#     )

#     rasterizer = GaussianRasterizer(raster_settings=raster_settings)

#     # means3D = pc.get_xyz
#     rel_w2c = get_camera_from_tensor(camera_pose)
#     # Transform mean and rot of Gaussians to camera frame
#     gaussians_xyz = pc._xyz.clone()
#     gaussians_rot = pc._rotation.clone()

#     xyz_ones = torch.ones(gaussians_xyz.shape[0], 1).cuda().float()
#     xyz_homo = torch.cat((gaussians_xyz, xyz_ones), dim=1)
#     gaussians_xyz_trans = (rel_w2c @ xyz_homo.T).T[:, :3]
#     gaussians_rot_trans = quadmultiply(camera_pose[:4], gaussians_rot)
#     means3D = gaussians_xyz_trans
#     means2D = screenspace_points
#     opacity = pc.get_opacity

#     # If precomputed 3d covariance is provided, use it. If not, then it will be computed from
#     # scaling / rotation by the rasterizer.
#     scales = None
#     rotations = None
#     cov3D_precomp = None
#     if pipe.compute_cov3D_python:
#         cov3D_precomp = pc.get_covariance(scaling_modifier)
#     else:
#         scales = pc.get_scaling
#         rotations = gaussians_rot_trans  # pc.get_rotation

#     # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors
#     # from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer.
#     shs = None
#     colors_precomp = None
#     if override_color is None:
#         if pipe.convert_SHs_python:
#             shs_view = pc.get_features.transpose(1, 2).view(
#                 -1, 3, (pc.max_sh_degree + 1) ** 2
#             )
#             dir_pp = pc.get_xyz - viewpoint_camera.camera_center.repeat(
#                 pc.get_features.shape[0], 1
#             )
#             dir_pp_normalized = dir_pp / dir_pp.norm(dim=1, keepdim=True)
#             sh2rgb = eval_sh(pc.active_sh_degree, shs_view, dir_pp_normalized)
#             colors_precomp = torch.clamp_min(sh2rgb + 0.5, 0.0)
#         else:
#             shs = pc.get_features
#     else:
#         colors_precomp = override_color

#     # Rasterize visible Gaussians to image, obtain their radii (on screen).
#     rendered_image, radii = rasterizer(
#         means3D=means3D,
#         means2D=means2D,
#         shs=shs,
#         colors_precomp=colors_precomp,
#         opacities=opacity,
#         scales=scales,
#         rotations=rotations,
#         cov3D_precomp=cov3D_precomp,
#     )

#     # Those Gaussians that were frustum culled or had a radius of 0 were not visible.
#     # They will be excluded from value updates used in the splitting criteria.
#     return {
#         "render": rendered_image,
#         "viewspace_points": screenspace_points,
#         "visibility_filter": radii > 0,
#         "radii": radii,
#     }


### if use [gsplat](https://github.com/nerfstudio-project/gsplat)

from gsplat import rasterization

def render_gsplat(
    viewpoint_camera,
    pc : GaussianModel,
    pipe,
    bg_color : torch.Tensor,
    scaling_modifier = 1.0,
    override_color = None,
    camera_pose = None,
    fov = None,
    render_mode="RGB"):
    """
    Render the scene.

    Background tensor (bg_color) must be on GPU!
    """
    if fov is None:
        FoVx = viewpoint_camera.FoVx
        FoVy = viewpoint_camera.FoVy
    else:
        FoVx = fov[0]
        FoVy = fov[1]
    tanfovx = math.tan(FoVx * 0.5)
    tanfovy = math.tan(FoVy * 0.5)
    focal_length_x = viewpoint_camera.image_width / (2 * tanfovx)
    focal_length_y = viewpoint_camera.image_height / (2 * tanfovy)
    K = torch.tensor(
        [
            [focal_length_x, 0, viewpoint_camera.image_width / 2.0],
            [0, focal_length_y, viewpoint_camera.image_height / 2.0],
            [0, 0, 1],
        ],
        device="cuda",
    )

    means3D = pc.get_xyz
    opacity = pc.get_opacity
    scales = pc.get_scaling * scaling_modifier
    rotations = pc.get_rotation
    if override_color is not None:
        colors = override_color  # [N, 3]
        sh_degree = None
    else:
        colors = pc.get_features  # [N, K, 3]
        sh_degree = pc.active_sh_degree

    if camera_pose is None:
        viewmat = viewpoint_camera.world_view_transform.transpose(0, 1)  # [4, 4]
    else:
        viewmat = get_camera_from_tensor(camera_pose)
    render_colors, render_alphas, info = rasterization(
        means=means3D,  # [N, 3]
        quats=rotations,  # [N, 4]
        scales=scales,  # [N, 3]
        opacities=opacity.squeeze(-1),  # [N,]
        colors=colors,
        viewmats=viewmat[None],  # [1, 4, 4]
        Ks=K[None],  # [1, 3, 3]
        backgrounds=bg_color[None],
        width=int(viewpoint_camera.image_width),
        height=int(viewpoint_camera.image_height),
        packed=False,
        sh_degree=sh_degree,
        render_mode=render_mode,
    )

    if "D" in render_mode:
        if "+" in render_mode:
            depth_map = render_colors[..., -1:]
        else:
            depth_map = render_colors

        normals_surf = depth_to_normal(
            depth_map, torch.inverse(viewmat[None]), K[None])
        normals_surf = normals_surf * (render_alphas).detach()
        render_colors = torch.cat([render_colors, normals_surf], dim=-1)

    # [1, H, W, 3] -> [3, H, W]
    rendered_image = render_colors[0].permute(2, 0, 1)

    radii = info["radii"].squeeze(0)  # [N,]
    try:
        info["means2d"].retain_grad()  # [1, N, 2]
    except:
        pass

    # Those Gaussians that were frustum culled or had a radius of 0 were not visible.
    # They will be excluded from value updates used in the splitting criteria.
    return {"render": rendered_image,
            "viewspace_points": info["means2d"],
            "visibility_filter" : radii > 0,
            "radii": radii}
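
render_gsplat() derives the pinhole intrinsics K from the camera's field of view, hands the Gaussian parameters to gsplat's rasterization() with a single view matrix (either the camera's stored transform or an optimizable pose tensor), and returns the image together with the per-Gaussian visibility used for densification. A minimal call sketch, assuming a Scene and GaussianModel already loaded as in run_video.py further below:

# Minimal sketch, assuming `scene`, `gaussians`, and `pipe` already exist.
import torch
from gaussian_renderer import render_gsplat

background = torch.tensor([1.0, 1.0, 1.0], device="cuda")  # white background
view = scene.getTrainCameras()[0]

out = render_gsplat(view, gaussians, pipe, background)  # default render_mode="RGB"
image = out["render"]                 # [3, H, W], roughly in [0, 1]
visible = out["visibility_filter"]    # Gaussians with on-screen radius > 0
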
gaussian_renderer/network_gui.py
ADDED
@@ -0,0 +1,86 @@
#
# Copyright (C) 2023, Inria
# GRAPHDECO research group, https://team.inria.fr/graphdeco
# All rights reserved.
#
# This software is free for non-commercial, research and evaluation use
# under the terms of the LICENSE.md file.
#
# For inquiries contact [email protected]
#

import torch
import traceback
import socket
import json
from scene.cameras import MiniCam

host = "127.0.0.1"
port = 6009

conn = None
addr = None

listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)

def init(wish_host, wish_port):
    global host, port, listener
    host = wish_host
    port = wish_port
    listener.bind((host, port))
    listener.listen()
    listener.settimeout(0)

def try_connect():
    global conn, addr, listener
    try:
        conn, addr = listener.accept()
        print(f"\nConnected by {addr}")
        conn.settimeout(None)
    except Exception as inst:
        pass

def read():
    global conn
    messageLength = conn.recv(4)
    messageLength = int.from_bytes(messageLength, 'little')
    message = conn.recv(messageLength)
    return json.loads(message.decode("utf-8"))

def send(message_bytes, verify):
    global conn
    if message_bytes != None:
        conn.sendall(message_bytes)
    conn.sendall(len(verify).to_bytes(4, 'little'))
    conn.sendall(bytes(verify, 'ascii'))

def receive():
    message = read()

    width = message["resolution_x"]
    height = message["resolution_y"]

    if width != 0 and height != 0:
        try:
            do_training = bool(message["train"])
            fovy = message["fov_y"]
            fovx = message["fov_x"]
            znear = message["z_near"]
            zfar = message["z_far"]
            do_shs_python = bool(message["shs_python"])
            do_rot_scale_python = bool(message["rot_scale_python"])
            keep_alive = bool(message["keep_alive"])
            scaling_modifier = message["scaling_modifier"]
            world_view_transform = torch.reshape(torch.tensor(message["view_matrix"]), (4, 4)).cuda()
            world_view_transform[:,1] = -world_view_transform[:,1]
            world_view_transform[:,2] = -world_view_transform[:,2]
            full_proj_transform = torch.reshape(torch.tensor(message["view_projection_matrix"]), (4, 4)).cuda()
            full_proj_transform[:,1] = -full_proj_transform[:,1]
            custom_cam = MiniCam(width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform)
        except Exception as e:
            print("")
            traceback.print_exc()
            raise e
        return custom_cam, do_training, do_shs_python, do_rot_scale_python, keep_alive, scaling_modifier
    else:
        return None, None, None, None, None, None
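
network_gui speaks the SIBR viewer protocol: a non-blocking listener accepts one connection, read() parses a 4-byte little-endian length followed by a JSON request, and send() answers with raw image bytes plus a length-prefixed verification string. A sketch of the polling pattern the original 3DGS trainer uses (not part of this commit; gaussians, pipe, background, and source_path are assumed to exist):

# Sketch of the viewer polling step; assumes gaussians, pipe, background, source_path.
import torch
from gaussian_renderer import network_gui, render_gsplat

network_gui.init("127.0.0.1", 6009)
if network_gui.conn is None:
    network_gui.try_connect()
if network_gui.conn is not None:
    cam, do_training, _, _, keep_alive, scale_mod = network_gui.receive()
    if cam is not None:
        img = render_gsplat(cam, gaussians, pipe, background,
                            scaling_modifier=scale_mod)["render"]
        buf = (torch.clamp(img, 0.0, 1.0) * 255).byte().permute(1, 2, 0).contiguous()
        network_gui.send(memoryview(buf.cpu().numpy()), source_path)
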
lpipsPyTorch/__init__.py
ADDED
@@ -0,0 +1,23 @@
import torch

from .modules.lpips import LPIPS


def lpips(x: torch.Tensor,
          y: torch.Tensor,
          net_type: str = 'alex',
          version: str = '0.1',
          return_spatial_map=False):
    r"""Function that measures
    Learned Perceptual Image Patch Similarity (LPIPS).

    Arguments:
        x, y (torch.Tensor): the input tensors to compare.
        net_type (str): the network type to compare the features:
                        'alex' | 'squeeze' | 'vgg'. Default: 'alex'.
        version (str): the version of LPIPS. Default: 0.1.
        return_spatial_map (bool): whether to return the spatial map. Default: False.
    """
    device = x.device
    criterion = LPIPS(net_type, version).to(device)
    return criterion(x, y, return_spatial_map)
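
The lpips() helper builds an LPIPS criterion on the fly (downloading the linear-layer weights on first use) and either averages the per-layer distances into a scalar or, with return_spatial_map=True, upsamples them into a per-pixel error map. A usage sketch, assuming inputs normalized to [-1, 1] as in the reference LPIPS implementation:

# Usage sketch; random tensors stand in for real renders and ground-truth images.
import torch
from lpipsPyTorch import lpips

x = torch.rand(1, 3, 256, 256, device="cuda") * 2 - 1
y = torch.rand(1, 3, 256, 256, device="cuda") * 2 - 1

score = lpips(x, y, net_type="vgg")                              # scalar distance
err_map = lpips(x, y, net_type="vgg", return_spatial_map=True)   # [1, 1, 256, 256]
print(score.item(), err_map.shape)
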
lpipsPyTorch/modules/lpips.py
ADDED
@@ -0,0 +1,44 @@
import torch
import torch.nn as nn

from .networks import get_network, LinLayers
from .utils import get_state_dict


class LPIPS(nn.Module):
    r"""Creates a criterion that measures
    Learned Perceptual Image Patch Similarity (LPIPS).

    Arguments:
        net_type (str): the network type to compare the features:
                        'alex' | 'squeeze' | 'vgg'. Default: 'alex'.
        version (str): the version of LPIPS. Default: 0.1.
    """
    def __init__(self, net_type: str = 'alex', version: str = '0.1'):

        assert version in ['0.1'], 'v0.1 is only supported now'

        super(LPIPS, self).__init__()

        # pretrained network
        self.net = get_network(net_type)

        # linear layers
        self.lin = LinLayers(self.net.n_channels_list)
        self.lin.load_state_dict(get_state_dict(net_type, version))

    def forward(self, x: torch.Tensor, y: torch.Tensor, return_spatial_map=False):
        feat_x, feat_y = self.net(x), self.net(y)

        diff = [(fx - fy) ** 2 for fx, fy in zip(feat_x, feat_y)]
        res = [l(d) for d, l in zip(diff, self.lin)]

        if return_spatial_map:
            target_size = (x.shape[2], x.shape[3])
            res_upsampled = [torch.nn.functional.interpolate(r, size=target_size, mode='bilinear', align_corners=False)
                             for r in res]
            spatial_map = torch.sum(torch.cat(res_upsampled, 1), 1, keepdim=True)
            return spatial_map
        else:
            res = [r.mean((2, 3), True) for r in res]
            return torch.sum(torch.cat(res, 0), 0, True)
lpipsPyTorch/modules/networks.py
ADDED
@@ -0,0 +1,96 @@
from typing import Sequence

from itertools import chain

import torch
import torch.nn as nn
from torchvision import models

from .utils import normalize_activation


def get_network(net_type: str):
    if net_type == 'alex':
        return AlexNet()
    elif net_type == 'squeeze':
        return SqueezeNet()
    elif net_type == 'vgg':
        return VGG16()
    else:
        raise NotImplementedError('choose net_type from [alex, squeeze, vgg].')


class LinLayers(nn.ModuleList):
    def __init__(self, n_channels_list: Sequence[int]):
        super(LinLayers, self).__init__([
            nn.Sequential(
                nn.Identity(),
                nn.Conv2d(nc, 1, 1, 1, 0, bias=False)
            ) for nc in n_channels_list
        ])

        for param in self.parameters():
            param.requires_grad = False


class BaseNet(nn.Module):
    def __init__(self):
        super(BaseNet, self).__init__()

        # register buffer
        self.register_buffer(
            'mean', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
        self.register_buffer(
            'std', torch.Tensor([.458, .448, .450])[None, :, None, None])

    def set_requires_grad(self, state: bool):
        for param in chain(self.parameters(), self.buffers()):
            param.requires_grad = state

    def z_score(self, x: torch.Tensor):
        return (x - self.mean) / self.std

    def forward(self, x: torch.Tensor):
        x = self.z_score(x)

        output = []
        for i, (_, layer) in enumerate(self.layers._modules.items(), 1):
            x = layer(x)
            if i in self.target_layers:
                output.append(normalize_activation(x))
            if len(output) == len(self.target_layers):
                break
        return output


class SqueezeNet(BaseNet):
    def __init__(self):
        super(SqueezeNet, self).__init__()

        self.layers = models.squeezenet1_1(True).features
        self.target_layers = [2, 5, 8, 10, 11, 12, 13]
        self.n_channels_list = [64, 128, 256, 384, 384, 512, 512]

        self.set_requires_grad(False)


class AlexNet(BaseNet):
    def __init__(self):
        super(AlexNet, self).__init__()

        self.layers = models.alexnet(True).features
        self.target_layers = [2, 5, 8, 10, 12]
        self.n_channels_list = [64, 192, 384, 256, 256]

        self.set_requires_grad(False)


class VGG16(BaseNet):
    def __init__(self):
        super(VGG16, self).__init__()

        self.layers = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1).features
        self.target_layers = [4, 9, 16, 23, 30]
        self.n_channels_list = [64, 128, 256, 512, 512]

        self.set_requires_grad(False)
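
Each backbone walks its torchvision feature stack and emits one channel-normalized activation per target layer; the channel counts in n_channels_list are exactly what LinLayers above expects. A small check sketch, assuming torchvision weights can be downloaded:

# Sketch: one normalized activation per target layer of the AlexNet backbone.
import torch
from lpipsPyTorch.modules.networks import AlexNet

net = AlexNet().eval()
x = torch.rand(1, 3, 64, 64) * 2 - 1
feats = net(x)
print([f.shape[1] for f in feats])  # [64, 192, 384, 256, 256] == net.n_channels_list
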
lpipsPyTorch/modules/utils.py
ADDED
@@ -0,0 +1,30 @@
from collections import OrderedDict

import torch


def normalize_activation(x, eps=1e-10):
    norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True))
    return x / (norm_factor + eps)


def get_state_dict(net_type: str = 'alex', version: str = '0.1'):
    # build url
    url = 'https://raw.githubusercontent.com/richzhang/PerceptualSimilarity/' \
        + f'master/lpips/weights/v{version}/{net_type}.pth'

    # download
    old_state_dict = torch.hub.load_state_dict_from_url(
        url, progress=True,
        map_location=None if torch.cuda.is_available() else torch.device('cpu')
    )

    # rename keys
    new_state_dict = OrderedDict()
    for key, val in old_state_dict.items():
        new_key = key
        new_key = new_key.replace('lin', '')
        new_key = new_key.replace('model.', '')
        new_state_dict[new_key] = val

    return new_state_dict
requirements.txt
ADDED
@@ -0,0 +1,21 @@
torch==2.5.1
torchvision==0.20.1
roma
evo
gradio>=4,<5
matplotlib
tqdm
opencv-python
scipy
einops
trimesh
tensorboard
pyglet<2
imageio
gsplat
scikit-learn
hydra-submitit-launcher
huggingface-hub[torch]==0.24
plyfile
imageio[ffmpeg]
spaces
run_video.py
ADDED
@@ -0,0 +1,275 @@
#
# Copyright (C) 2023, Inria
# GRAPHDECO research group, https://team.inria.fr/graphdeco
# All rights reserved.
#
# This software is free for non-commercial, research and evaluation use
# under the terms of the LICENSE.md file.
#
# For inquiries contact [email protected]
#

import matplotlib
matplotlib.use('Agg')

import math
import copy
import torch
from scene import Scene
import os
from tqdm import tqdm
from gaussian_renderer import render_gsplat
from argparse import ArgumentParser
from arguments import ModelParams, PipelineParams, get_combined_args
from gaussian_renderer import GaussianModel
from utils.pose_utils import get_tensor_from_camera
import numpy as np
import imageio.v3 as iio
from utils.graphics_utils import resize_render, make_video_divisble

from utils.trajectories import (
    get_arc_w2cs,
    get_avg_w2c,
    get_lemniscate_w2cs,
    get_spiral_w2cs,
    get_wander_w2cs,
    get_lookat,
)

from utils.camera_utils import generate_interpolated_path, generate_ellipse_path
from utils.camera_traj_config import trajectory_configs


def save_interpolated_pose(model_path, iter, n_views):

    org_pose = np.load(model_path + f"pose/pose_{iter}.npy")
    # visualizer(org_pose, ["green" for _ in org_pose], model_path + "pose/poses_optimized.png")
    n_interp = int(10 * 30 / n_views)  # 10second, fps=30
    all_inter_pose = []
    for i in range(n_views-1):
        tmp_inter_pose = generate_interpolated_path(poses=org_pose[i:i+2], n_interp=n_interp)
        all_inter_pose.append(tmp_inter_pose)
    all_inter_pose = np.array(all_inter_pose).reshape(-1, 3, 4)

    inter_pose_list = []
    for p in all_inter_pose:
        tmp_view = np.eye(4)
        tmp_view[:3, :3] = p[:3, :3]
        tmp_view[:3, 3] = p[:3, 3]
        inter_pose_list.append(tmp_view)
    inter_pose = np.stack(inter_pose_list, 0)
    return inter_pose


def save_ellipse_pose(model_path, iter, n_views):

    org_pose = np.load(model_path + f"pose/pose_{iter}.npy")
    # visualizer(org_pose, ["green" for _ in org_pose], model_path + "pose/poses_optimized.png")
    n_interp = int(10 * 30 / n_views) * (n_views-1)  # 10second, fps=30
    all_inter_pose = generate_ellipse_path(org_pose, n_interp)

    inter_pose_list = []
    for p in all_inter_pose:
        c2w = np.eye(4)
        c2w[:3, :4] = p
        inter_pose_list.append(np.linalg.inv(c2w))
    inter_pose = np.stack(inter_pose_list, 0)

    return inter_pose

def save_traj_pose(dataset, iter, args):

    traj_up = trajectory_configs.get(args.dataset, {}).get(args.scene, {}).get('up', [-1, 1])  # Use -y axis in camera space as up vector
    traj_params = trajectory_configs.get(args.dataset, {}).get(args.scene, {}).get(args.cam_traj, {})

    # 1. Get training camera poses and calculate trajectory
    org_pose = np.load(dataset.model_path + f"pose/pose_{iter}.npy")
    train_w2cs = torch.from_numpy(org_pose).cuda()

    # Calculate reference camera pose
    avg_w2c = get_avg_w2c(train_w2cs)
    train_c2ws = torch.linalg.inv(train_w2cs)
    lookat = get_lookat(train_c2ws[:, :3, -1], train_c2ws[:, :3, 2])
    # up = torch.tensor([0.0, 0.0, 1.0], device="cuda")
    avg_c2w = torch.linalg.inv(avg_w2c)
    up = traj_up[0] * (avg_c2w[:3, traj_up[1]])
    # up = traj_up[0] * (avg_c2w[:3, 0]+avg_c2w[:3, 1])/2

    # Temporarily load a camera to get intrinsic parameters
    tmp_args = copy.deepcopy(args)
    tmp_args.get_video = False
    tmp_dataset = copy.deepcopy(dataset)
    tmp_dataset.eval = False
    with torch.no_grad():
        temp_gaussians = GaussianModel(dataset.sh_degree)
        temp_scene = Scene(tmp_dataset, temp_gaussians, load_iteration=iter, opt=tmp_args, shuffle=False)

        view = temp_scene.getTrainCameras()[0]
        tanfovx = math.tan(view.FoVx * 0.5)
        tanfovy = math.tan(view.FoVy * 0.5)
        focal_length_x = view.image_width / (2 * tanfovx)
        focal_length_y = view.image_height / (2 * tanfovy)

        K = torch.tensor([[focal_length_x, 0, view.image_width/2],
                          [0, focal_length_y, view.image_height/2],
                          [0, 0, 1]], device="cuda")
        img_wh = (view.image_width, view.image_height)

    del temp_scene  # Release temporary scene
    del temp_gaussians  # Release temporary gaussians

    # Calculate bounding sphere radius
    rc_train_c2ws = torch.einsum("ij,njk->nik", torch.linalg.inv(avg_w2c), train_c2ws)
    rc_pos = rc_train_c2ws[:, :3, -1]
    rads = (rc_pos.amax(0) - rc_pos.amin(0)) * 1.25

    num_frames = int(10 * 30 / args.n_views) * (args.n_views-1)

    # Generate camera poses based on trajectory type
    if args.cam_traj == 'arc':
        w2cs = get_arc_w2cs(
            ref_w2c=avg_w2c,
            lookat=lookat,
            up=up,
            focal_length=K[0, 0].item(),
            rads=rads,
            num_frames=num_frames,
            degree=traj_params.get('degree', 180.0)
        )
    elif args.cam_traj == 'spiral':
        w2cs = get_spiral_w2cs(
            ref_w2c=avg_w2c,
            lookat=lookat,
            up=up,
            focal_length=K[0, 0].item(),
            rads=rads,
            num_frames=num_frames,
            zrate=traj_params.get('zrate', 0.5),
            rots=traj_params.get('rots', 1)
        )
    elif args.cam_traj == 'lemniscate':
        w2cs = get_lemniscate_w2cs(
            ref_w2c=avg_w2c,
            lookat=lookat,
            up=up,
            focal_length=K[0, 0].item(),
            rads=rads,
            num_frames=num_frames,
            degree=traj_params.get('degree', 45.0)
        )
    elif args.cam_traj == 'wander':
        w2cs = get_wander_w2cs(
            ref_w2c=avg_w2c,
            focal_length=K[0, 0].item(),
            num_frames=num_frames,
            max_disp=traj_params.get('max_disp', 48.0)
        )
    else:
        raise ValueError(f"Unknown camera trajectory: {args.cam_traj}")

    return w2cs.cpu().numpy()

def render_sets(dataset: ModelParams, iteration: int, pipeline: PipelineParams, args):
    if args.cam_traj in ['interpolated', 'ellipse']:
        w2cs = globals().get(f'save_{args.cam_traj}_pose')(dataset.model_path, iteration, args.n_views)
    else:
        w2cs = save_traj_pose(dataset, iteration, args)

    # visualizer(org_pose, ["green" for _ in org_pose], dataset.model_path + f"pose/poses_optimized.png")
    # visualizer(w2cs, ["blue" for _ in w2cs], dataset.model_path + f"pose/poses_{args.cam_traj}.png")
    np.save(dataset.model_path + f"pose/pose_{args.cam_traj}.npy", w2cs)

    # 2. Load model and scene
    with torch.no_grad():
        gaussians = GaussianModel(dataset.sh_degree)
        scene = Scene(dataset, gaussians, load_iteration=iteration, opt=args, shuffle=False)

        # bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0]
        bg_color = [1, 1, 1]
        background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")

        # 3. Rendering
        # render_path = os.path.join(dataset.model_path, args.cam_traj, f"ours_{iteration}", "renders")
        # if os.path.exists(render_path):
        #     shutil.rmtree(render_path)
        # makedirs(render_path, exist_ok=True)

        video = []
        for idx, w2c in enumerate(tqdm(w2cs, desc="Rendering progress")):
            camera_pose = get_tensor_from_camera(w2c.transpose(0, 1))
            view = scene.getTrainCameras()[0]  # Use parameters from the first camera as template
            if args.resize:
                view = resize_render(view)

            rendering = render_gsplat(
                view, gaussians, pipeline, background, camera_pose=camera_pose
            )["render"]

            # # Save single frame image
            # torchvision.utils.save_image(
            #     rendering, os.path.join(render_path, "{0:05d}".format(idx) + ".png")
            # )

            # Add to video list
            # img = (rendering.detach().cpu().numpy() * 255.0).astype(np.uint8)
            img = (torch.clamp(rendering, 0, 1).detach().cpu().numpy() * 255.0).round().astype(np.uint8)
            video.append(img)

        video = np.stack(video, 0).transpose(0, 2, 3, 1)

    # Save video
    if args.get_video:
        video_dir = os.path.join(dataset.model_path, 'videos')
        os.makedirs(video_dir, exist_ok=True)
        output_video_file = os.path.join(video_dir, f'{args.scene}_{args.n_views}_view_{args.cam_traj}.mp4')
        # iio.imwrite(output_video_file, make_video_divisble(video), fps=30)
        iio.imwrite(
            output_video_file,
            make_video_divisble(video),
            fps=30,
            codec='libx264',
            quality=None,
            output_params=[
                '-crf', '28',  # Good quality range between 18-28
                '-preset', 'veryslow',
                '-pix_fmt', 'yuv420p',
                '-movflags', '+faststart'
            ]
        )

    # if args.get_video:
    #     image_folder = os.path.join(dataset.model_path, f'{args.cam_traj}/ours_{args.iteration}/renders')
    #     output_video_file = os.path.join(dataset.model_path, f'{args.scene}_{args.n_views}_view_{args.cam_traj}.mp4')
    #     images_to_video(image_folder, output_video_file, fps=30)


if __name__ == "__main__":
    # Set up command line argument parser
    parser = ArgumentParser(description="Testing script parameters")
    model = ModelParams(parser, sentinel=True)
    pipeline = PipelineParams(parser)
    parser.add_argument("--iteration", default=-1, type=int)
    parser.add_argument("--quiet", action="store_true")
    parser.add_argument("--get_video", action="store_true")
    parser.add_argument("--n_views", default=120, type=int)
    parser.add_argument("--dataset", default=None, type=str)
    parser.add_argument("--scene", default=None, type=str)
    parser.add_argument("--cam_traj", default='arc', type=str,
                        choices=['arc', 'spiral', 'lemniscate', 'wander', 'interpolated', 'ellipse'],
                        help="Camera trajectory type")
    parser.add_argument("--resize", action="store_true", default=True,
                        help="If True, resize rendering to square")
    parser.add_argument("--feat_type", type=str, nargs='*', default=None,
                        help="Feature type(s). Multiple types can be specified for combination.")
    parser.add_argument("--method", type=str, default='dust3r',
                        help="Method of Initialization, e.g., 'dust3r' or 'mast3r'")

    args = get_combined_args(parser)
    print("Rendering " + args.model_path)

    render_sets(
        model.extract(args),
        args.iteration,
        pipeline.extract(args),
        args,
    )
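
run_video.py reads the optimized poses from <model_path>/pose/pose_<iteration>.npy, builds a trajectory of the requested type, renders every frame through render_gsplat, and writes an H.264 video under <model_path>/videos/. A hedged example invocation in the style of the command file above (every path, scene name, and iteration number here is illustrative only, not taken from the repo's scripts):

python run_video.py \
    -s data/Tanks/Train/6_views \
    -m output/eval/Tanks/Train/6_views/feat2gs-G/dust3r \
    --n_views 6 --iteration 8000 \
    --dataset Tanks --scene Train --cam_traj spiral --get_video
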
scene/__init__.py
ADDED
@@ -0,0 +1,105 @@
#
# Copyright (C) 2023, Inria
# GRAPHDECO research group, https://team.inria.fr/graphdeco
# All rights reserved.
#
# This software is free for non-commercial, research and evaluation use
# under the terms of the LICENSE.md file.
#
# For inquiries contact [email protected]
#

import os
import random
import json
from utils.system_utils import searchForMaxIteration
from scene.dataset_readers import sceneLoadTypeCallbacks
from scene.gaussian_model import GaussianModel, Feat2GaussianModel
from arguments import ModelParams
from utils.camera_utils import cameraList_from_camInfos, camera_to_JSON

class Scene:

    gaussian : GaussianModel

    def __init__(self, args : ModelParams, gaussian : GaussianModel, load_iteration=None, opt=None, shuffle=True, resolution_scales=[1.0]):
        """b
        :param path: Path to colmap scene main folder.
        """
        self.model_path = args.model_path
        self.loaded_iter = None
        self.gaussians = gaussian

        if load_iteration:
            if load_iteration == -1:
                self.loaded_iter = searchForMaxIteration(os.path.join(self.model_path, "point_cloud"))
            else:
                self.loaded_iter = load_iteration
            print("Loading trained model at iteration {}".format(self.loaded_iter))

        self.train_cameras = {}
        self.test_cameras = {}
        # self.render_cameras = {}

        if os.path.exists(os.path.join(args.source_path, "sparse")):
            scene_info = sceneLoadTypeCallbacks["Colmap"](args.source_path, args.images, args.eval, args, opt)
        elif os.path.exists(os.path.join(args.source_path, "transforms_train.json")):
            print("Found transforms_train.json file, assuming Blender data set!")
            scene_info = sceneLoadTypeCallbacks["Blender"](args.source_path, args.white_background, args.eval)
        else:
            assert False, "Could not recognize scene type!"

        if not self.loaded_iter:
            with open(scene_info.ply_path, 'rb') as src_file, open(os.path.join(self.model_path, "input.ply") , 'wb') as dest_file:
                dest_file.write(src_file.read())
            json_cams = []
            camlist = []
            if scene_info.test_cameras:
                camlist.extend(scene_info.test_cameras)
            if scene_info.train_cameras:
                camlist.extend(scene_info.train_cameras)
            # if scene_info.render_cameras:
            #     camlist.extend(scene_info.render_cameras)
            for id, cam in enumerate(camlist):
                json_cams.append(camera_to_JSON(id, cam))
            with open(os.path.join(self.model_path, "cameras.json"), 'w') as file:
                json.dump(json_cams, file)

        if shuffle:
            random.shuffle(scene_info.train_cameras)  # Multi-res consistent random shuffling
            random.shuffle(scene_info.test_cameras)  # Multi-res consistent random shuffling

        self.cameras_extent = scene_info.nerf_normalization["radius"]

        for resolution_scale in resolution_scales:
            print("Loading Training Cameras")
            self.train_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.train_cameras, resolution_scale, args)
            print('train_camera_num: ', len(self.train_cameras[resolution_scale]))
            print("Loading Test Cameras")
            self.test_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.test_cameras, resolution_scale, args)
            print('test_camera_num: ', len(self.test_cameras[resolution_scale]))
            # print("Loading Render Cameras")
            # self.render_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.render_cameras, resolution_scale, args)
            # print('render_camera_num: ', len(self.render_cameras[resolution_scale]))

        if self.loaded_iter:
            self.gaussians.load_ply(os.path.join(self.model_path,
                                                 "point_cloud",
                                                 "iteration_" + str(self.loaded_iter),
                                                 "point_cloud.ply"))
        else:
            self.gaussians.create_from_pcd(scene_info.point_cloud, self.cameras_extent)
            self.gaussians.init_RT_seq(self.train_cameras)

    def save(self, iteration):
        point_cloud_path = os.path.join(self.model_path, "point_cloud/iteration_{}".format(iteration))
        self.gaussians.save_ply(os.path.join(point_cloud_path, "point_cloud.ply"))

    def getTrainCameras(self, scale=1.0):
        return self.train_cameras[scale]

    def getTestCameras(self, scale=1.0):
        return self.test_cameras[scale]

    # def getRenderCameras(self, scale=1.0):
    #     return self.render_cameras[scale]
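
Scene resolves the data format (a COLMAP sparse/ folder vs. a Blender transforms_train.json), builds per-resolution camera lists, and either loads a saved point_cloud.ply or initializes the Gaussians from the scene point cloud. A loading sketch in the same style as run_video.py, where args is assumed to be the merged namespace returned by get_combined_args():

# Sketch: loading the latest trained iteration of a scene; `args` is assumed.
import torch
from scene import Scene, GaussianModel

with torch.no_grad():
    gaussians = GaussianModel(args.sh_degree)
    scene = Scene(args, gaussians, load_iteration=-1, opt=args, shuffle=False)
    print(len(scene.getTrainCameras()), scene.cameras_extent)
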
scene/cameras.py
ADDED
@@ -0,0 +1,71 @@
#
# Copyright (C) 2023, Inria
# GRAPHDECO research group, https://team.inria.fr/graphdeco
# All rights reserved.
#
# This software is free for non-commercial, research and evaluation use
# under the terms of the LICENSE.md file.
#
# For inquiries contact [email protected]
#

import torch
from torch import nn
import numpy as np
from utils.graphics_utils import getWorld2View2, getProjectionMatrix

class Camera(nn.Module):
    def __init__(self, colmap_id, R, T, FoVx, FoVy, image, gt_alpha_mask,
                 image_name, uid,
                 trans=np.array([0.0, 0.0, 0.0]), scale=1.0, data_device = "cuda"
                 ):
        super(Camera, self).__init__()

        self.uid = uid
        self.colmap_id = colmap_id
        self.R = R
        self.T = T
        self.FoVx = FoVx
        self.FoVy = FoVy
        self.image_name = image_name

        try:
            self.data_device = torch.device(data_device)
        except Exception as e:
            print(e)
            print(f"[Warning] Custom device {data_device} failed, fallback to default cuda device" )
            self.data_device = torch.device("cuda")

        self.original_image = image.clamp(0.0, 1.0).to(self.data_device)
        self.image_width = self.original_image.shape[2]
        self.image_height = self.original_image.shape[1]

        if gt_alpha_mask is not None:
            self.original_image *= gt_alpha_mask.to(self.data_device)
        else:
            self.original_image *= torch.ones((1, self.image_height, self.image_width), device=self.data_device)

        self.zfar = 100.0
        self.znear = 0.01

        self.trans = trans
        self.scale = scale

        self.world_view_transform = torch.tensor(getWorld2View2(R, T, trans, scale)).transpose(0, 1).cuda()
        self.projection_matrix = getProjectionMatrix(znear=self.znear, zfar=self.zfar, fovX=self.FoVx, fovY=self.FoVy).transpose(0,1).cuda()
        self.full_proj_transform = (self.world_view_transform.unsqueeze(0).bmm(self.projection_matrix.unsqueeze(0))).squeeze(0)
        self.camera_center = self.world_view_transform.inverse()[3, :3]

class MiniCam:
    def __init__(self, width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform):
        self.image_width = width
        self.image_height = height
        self.FoVy = fovy
        self.FoVx = fovx
        self.znear = znear
        self.zfar = zfar
        self.world_view_transform = world_view_transform
        self.full_proj_transform = full_proj_transform
        view_inv = torch.inverse(self.world_view_transform)
        self.camera_center = view_inv[3][:3]
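
Camera stores its matrices transposed (row-vector convention), so full_proj_transform maps homogeneous world points to clip space by right-multiplication, and camera_center falls out of the last row of the inverted view matrix. A projection sketch, assuming a Camera cam obtained from a Scene as in the sketch above:

# Sketch: projecting one world-space point with the row-vector convention used here.
import torch

p_world = torch.tensor([0.0, 0.0, 1.0, 1.0], device="cuda")  # homogeneous point
p_clip = p_world @ cam.full_proj_transform                    # right-multiplication
ndc = p_clip[:3] / p_clip[3]                                  # normalized device coords
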
scene/colmap_loader.py
ADDED
@@ -0,0 +1,294 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
+#
+# Copyright (C) 2023, Inria
+# GRAPHDECO research group, https://team.inria.fr/graphdeco
+# All rights reserved.
+#
+# This software is free for non-commercial, research and evaluation use
+# under the terms of the LICENSE.md file.
+#
+# For inquiries contact [email protected]
+#
+
+import numpy as np
+import collections
+import struct
+
+CameraModel = collections.namedtuple(
+    "CameraModel", ["model_id", "model_name", "num_params"])
+Camera = collections.namedtuple(
+    "Camera", ["id", "model", "width", "height", "params"])
+BaseImage = collections.namedtuple(
+    "Image", ["id", "qvec", "tvec", "camera_id", "name", "xys", "point3D_ids"])
+Point3D = collections.namedtuple(
+    "Point3D", ["id", "xyz", "rgb", "error", "image_ids", "point2D_idxs"])
+CAMERA_MODELS = {
+    CameraModel(model_id=0, model_name="SIMPLE_PINHOLE", num_params=3),
+    CameraModel(model_id=1, model_name="PINHOLE", num_params=4),
+    CameraModel(model_id=2, model_name="SIMPLE_RADIAL", num_params=4),
+    CameraModel(model_id=3, model_name="RADIAL", num_params=5),
+    CameraModel(model_id=4, model_name="OPENCV", num_params=8),
+    CameraModel(model_id=5, model_name="OPENCV_FISHEYE", num_params=8),
+    CameraModel(model_id=6, model_name="FULL_OPENCV", num_params=12),
+    CameraModel(model_id=7, model_name="FOV", num_params=5),
+    CameraModel(model_id=8, model_name="SIMPLE_RADIAL_FISHEYE", num_params=4),
+    CameraModel(model_id=9, model_name="RADIAL_FISHEYE", num_params=5),
+    CameraModel(model_id=10, model_name="THIN_PRISM_FISHEYE", num_params=12)
+}
+CAMERA_MODEL_IDS = dict([(camera_model.model_id, camera_model)
+                         for camera_model in CAMERA_MODELS])
+CAMERA_MODEL_NAMES = dict([(camera_model.model_name, camera_model)
+                           for camera_model in CAMERA_MODELS])
+
+
+def qvec2rotmat(qvec):
+    return np.array([
+        [1 - 2 * qvec[2]**2 - 2 * qvec[3]**2,
+         2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3],
+         2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2]],
+        [2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3],
+         1 - 2 * qvec[1]**2 - 2 * qvec[3]**2,
+         2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1]],
+        [2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2],
+         2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1],
+         1 - 2 * qvec[1]**2 - 2 * qvec[2]**2]])
+
+def rotmat2qvec(R):
+    Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat
+    K = np.array([
+        [Rxx - Ryy - Rzz, 0, 0, 0],
+        [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0],
+        [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0],
+        [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz]]) / 3.0
+    eigvals, eigvecs = np.linalg.eigh(K)
+    qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)]
+    if qvec[0] < 0:
+        qvec *= -1
+    return qvec
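The two helpers above are inverses of each other for unit quaternions in COLMAP's [w, x, y, z] convention (rotmat2qvec normalises the sign so that w >= 0). A small sanity check, not part of the repo:

import numpy as np
q = np.array([0.9238795, 0.0, 0.3826834, 0.0])   # ~45 deg rotation about Y, [w, x, y, z]
R = qvec2rotmat(q)
q_back = rotmat2qvec(R)
assert np.allclose(R @ R.T, np.eye(3), atol=1e-6)   # output is orthonormal
assert np.allclose(q, q_back, atol=1e-6)             # round trip recovers the quaternion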
+
+class Image(BaseImage):
+    def qvec2rotmat(self):
+        return qvec2rotmat(self.qvec)
+
+def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"):
+    """Read and unpack the next bytes from a binary file.
+    :param fid:
+    :param num_bytes: Sum of combination of {2, 4, 8}, e.g. 2, 6, 16, 30, etc.
+    :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}.
+    :param endian_character: Any of {@, =, <, >, !}
+    :return: Tuple of read and unpacked values.
+    """
+    data = fid.read(num_bytes)
+    return struct.unpack(endian_character + format_char_sequence, data)
+
+def read_points3D_text(path):
+    """
+    see: src/base/reconstruction.cc
+        void Reconstruction::ReadPoints3DText(const std::string& path)
+        void Reconstruction::WritePoints3DText(const std::string& path)
+    """
+    xyzs = None
+    rgbs = None
+    errors = None
+    num_points = 0
+    with open(path, "r") as fid:
+        while True:
+            line = fid.readline()
+            if not line:
+                break
+            line = line.strip()
+            if len(line) > 0 and line[0] != "#":
+                num_points += 1
+
+
+    xyzs = np.empty((num_points, 3))
+    rgbs = np.empty((num_points, 3))
+    errors = np.empty((num_points, 1))
+    count = 0
+    with open(path, "r") as fid:
+        while True:
+            line = fid.readline()
+            if not line:
+                break
+            line = line.strip()
+            if len(line) > 0 and line[0] != "#":
+                elems = line.split()
+                xyz = np.array(tuple(map(float, elems[1:4])))
+                rgb = np.array(tuple(map(int, elems[4:7])))
+                error = np.array(float(elems[7]))
+                xyzs[count] = xyz
+                rgbs[count] = rgb
+                errors[count] = error
+                count += 1
+
+    return xyzs, rgbs, errors
+
+def read_points3D_binary(path_to_model_file):
+    """
+    see: src/base/reconstruction.cc
+        void Reconstruction::ReadPoints3DBinary(const std::string& path)
+        void Reconstruction::WritePoints3DBinary(const std::string& path)
+    """
+
+
+    with open(path_to_model_file, "rb") as fid:
+        num_points = read_next_bytes(fid, 8, "Q")[0]
+
+        xyzs = np.empty((num_points, 3))
+        rgbs = np.empty((num_points, 3))
+        errors = np.empty((num_points, 1))
+
+        for p_id in range(num_points):
+            binary_point_line_properties = read_next_bytes(
+                fid, num_bytes=43, format_char_sequence="QdddBBBd")
+            xyz = np.array(binary_point_line_properties[1:4])
+            rgb = np.array(binary_point_line_properties[4:7])
+            error = np.array(binary_point_line_properties[7])
+            track_length = read_next_bytes(
+                fid, num_bytes=8, format_char_sequence="Q")[0]
+            track_elems = read_next_bytes(
+                fid, num_bytes=8*track_length,
+                format_char_sequence="ii"*track_length)
+            xyzs[p_id] = xyz
+            rgbs[p_id] = rgb
+            errors[p_id] = error
+    return xyzs, rgbs, errors
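The 43-byte record read above is one packed points3D entry: Q (8-byte point id), d d d (xyz), B B B (rgb), d (reprojection error). An illustrative round trip with an in-memory buffer, not repo code:

import io, struct
record = struct.pack("<QdddBBBd", 7, 0.1, 0.2, 0.3, 255, 128, 0, 0.5)   # 43 bytes
fid = io.BytesIO(record)
values = read_next_bytes(fid, num_bytes=43, format_char_sequence="QdddBBBd")
# values == (7, 0.1, 0.2, 0.3, 255, 128, 0, 0.5)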
+
+def read_intrinsics_text(path):
+    """
+    Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_write_model.py
+    """
+    cameras = {}
+    with open(path, "r") as fid:
+        while True:
+            line = fid.readline()
+            if not line:
+                break
+            line = line.strip()
+            if len(line) > 0 and line[0] != "#":
+                elems = line.split()
+                camera_id = int(elems[0])
+                model = elems[1]
+                assert model == "PINHOLE", "While the loader support other types, the rest of the code assumes PINHOLE"
+                width = int(elems[2])
+                height = int(elems[3])
+                params = np.array(tuple(map(float, elems[4:])))
+                cameras[camera_id] = Camera(id=camera_id, model=model,
+                                            width=width, height=height,
+                                            params=params)
+    return cameras
+
+def read_extrinsics_binary(path_to_model_file):
+    """
+    see: src/base/reconstruction.cc
+        void Reconstruction::ReadImagesBinary(const std::string& path)
+        void Reconstruction::WriteImagesBinary(const std::string& path)
+    """
+    images = {}
+    with open(path_to_model_file, "rb") as fid:
+        num_reg_images = read_next_bytes(fid, 8, "Q")[0]
+        for _ in range(num_reg_images):
+            binary_image_properties = read_next_bytes(
+                fid, num_bytes=64, format_char_sequence="idddddddi")
+            image_id = binary_image_properties[0]
+            qvec = np.array(binary_image_properties[1:5])
+            tvec = np.array(binary_image_properties[5:8])
+            camera_id = binary_image_properties[8]
+            image_name = ""
+            current_char = read_next_bytes(fid, 1, "c")[0]
+            while current_char != b"\x00":   # look for the ASCII 0 entry
+                image_name += current_char.decode("utf-8")
+                current_char = read_next_bytes(fid, 1, "c")[0]
+            num_points2D = read_next_bytes(fid, num_bytes=8,
+                                           format_char_sequence="Q")[0]
+            x_y_id_s = read_next_bytes(fid, num_bytes=24*num_points2D,
+                                       format_char_sequence="ddq"*num_points2D)
+            xys = np.column_stack([tuple(map(float, x_y_id_s[0::3])),
+                                   tuple(map(float, x_y_id_s[1::3]))])
+            point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3])))
+            images[image_id] = Image(
+                id=image_id, qvec=qvec, tvec=tvec,
+                camera_id=camera_id, name=image_name,
+                xys=xys, point3D_ids=point3D_ids)
+    return images
+
+
+def read_intrinsics_binary(path_to_model_file):
+    """
+    see: src/base/reconstruction.cc
+        void Reconstruction::WriteCamerasBinary(const std::string& path)
+        void Reconstruction::ReadCamerasBinary(const std::string& path)
+    """
+    cameras = {}
+    with open(path_to_model_file, "rb") as fid:
+        num_cameras = read_next_bytes(fid, 8, "Q")[0]
+        for _ in range(num_cameras):
+            camera_properties = read_next_bytes(
+                fid, num_bytes=24, format_char_sequence="iiQQ")
+            camera_id = camera_properties[0]
+            model_id = camera_properties[1]
+            model_name = CAMERA_MODEL_IDS[camera_properties[1]].model_name
+            width = camera_properties[2]
+            height = camera_properties[3]
+            num_params = CAMERA_MODEL_IDS[model_id].num_params
+            params = read_next_bytes(fid, num_bytes=8*num_params,
+                                     format_char_sequence="d"*num_params)
+            cameras[camera_id] = Camera(id=camera_id,
+                                        model=model_name,
+                                        width=width,
+                                        height=height,
+                                        params=np.array(params))
+        assert len(cameras) == num_cameras
+    return cameras
+
+
+def read_extrinsics_text(path):
+    """
+    Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_write_model.py
+    """
+    images = {}
+    with open(path, "r") as fid:
+        while True:
+            line = fid.readline()
+            if not line:
+                break
+            line = line.strip()
+            if len(line) > 0 and line[0] != "#":
+                elems = line.split()
+                image_id = int(elems[0])
+                qvec = np.array(tuple(map(float, elems[1:5])))
+                tvec = np.array(tuple(map(float, elems[5:8])))
+                camera_id = int(elems[8])
+                image_name = elems[9]
+                elems = fid.readline().split()
+                xys = np.column_stack([tuple(map(float, elems[0::3])),
+                                       tuple(map(float, elems[1::3]))])
+                point3D_ids = np.array(tuple(map(int, elems[2::3])))
+                images[image_id] = Image(
+                    id=image_id, qvec=qvec, tvec=tvec,
+                    camera_id=camera_id, name=image_name,
+                    xys=xys, point3D_ids=point3D_ids)
+    return images
+
+
+def read_colmap_bin_array(path):
+    """
+    Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_dense.py
+
+    :param path: path to the colmap binary file.
+    :return: nd array with the floating point values in the value
+    """
+    with open(path, "rb") as fid:
+        width, height, channels = np.genfromtxt(fid, delimiter="&", max_rows=1,
+                                                usecols=(0, 1, 2), dtype=int)
+        fid.seek(0)
+        num_delimiter = 0
+        byte = fid.read(1)
+        while True:
+            if byte == b"&":
+                num_delimiter += 1
+                if num_delimiter >= 3:
+                    break
+            byte = fid.read(1)
+        array = np.fromfile(fid, np.float32)
+    array = array.reshape((width, height, channels), order="F")
+    return np.transpose(array, (1, 0, 2)).squeeze()
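A hedged usage sketch of the text loaders above (the paths are hypothetical): COLMAP stores the world-to-camera rotation as a quaternion and the world-to-camera translation as tvec, so the camera position in world space is -R^T t.

import numpy as np
cam_intrinsics = read_intrinsics_text("sparse/0/cameras.txt")   # hypothetical path
cam_extrinsics = read_extrinsics_text("sparse/0/images.txt")    # hypothetical path
img = next(iter(cam_extrinsics.values()))
R_w2c = qvec2rotmat(img.qvec)            # 3x3 rotation, world -> camera
w2c = np.eye(4)
w2c[:3, :3] = R_w2c
w2c[:3, 3] = img.tvec
cam_center = -R_w2c.T @ img.tvec         # camera position in world coordinates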
scene/dataset_readers.py
ADDED
@@ -0,0 +1,382 @@
+#
+# Copyright (C) 2023, Inria
+# GRAPHDECO research group, https://team.inria.fr/graphdeco
+# All rights reserved.
+#
+# This software is free for non-commercial, research and evaluation use
+# under the terms of the LICENSE.md file.
+#
+# For inquiries contact [email protected]
+#
+
+import os
+import sys
+from PIL import Image
+from typing import NamedTuple
+from scene.colmap_loader import read_extrinsics_text, read_intrinsics_text, qvec2rotmat, \
+    read_extrinsics_binary, read_intrinsics_binary, read_points3D_binary, read_points3D_text
+from utils.graphics_utils import getWorld2View2, focal2fov, fov2focal
+import numpy as np
+import json
+from pathlib import Path
+from plyfile import PlyData, PlyElement
+from utils.sh_utils import SH2RGB
+from scene.gaussian_model import BasicPointCloud
+# from utils.camera_utils import generate_ellipse_path_from_camera_infos
+
+class CameraInfo(NamedTuple):
+    uid: int
+    R: np.array
+    T: np.array
+    FovY: np.array
+    FovX: np.array
+    image: np.array
+    image_path: str
+    image_name: str
+    width: int
+    height: int
+
+
+class SceneInfo(NamedTuple):
+    point_cloud: BasicPointCloud
+    train_cameras: list
+    test_cameras: list
+    # render_cameras: list
+    nerf_normalization: dict
+    ply_path: str
+    train_poses: list
+    test_poses: list
+
+def getNerfppNorm(cam_info):
+    def get_center_and_diag(cam_centers):
+        cam_centers = np.hstack(cam_centers)
+        avg_cam_center = np.mean(cam_centers, axis=1, keepdims=True)
+        center = avg_cam_center
+        dist = np.linalg.norm(cam_centers - center, axis=0, keepdims=True)
+        diagonal = np.max(dist)
+        return center.flatten(), diagonal
+
+    cam_centers = []
+
+    for cam in cam_info:
+        W2C = getWorld2View2(cam.R, cam.T)
+        C2W = np.linalg.inv(W2C)
+        cam_centers.append(C2W[:3, 3:4])
+
+    center, diagonal = get_center_and_diag(cam_centers)
+    radius = diagonal * 1.1
+
+    translate = -center
+
+    return {"translate": translate, "radius": radius}
+
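For intuition, getNerfppNorm reports a radius of 1.1x the largest camera-centre distance from the mean centre, and a translate that moves that mean to the origin. A toy check (not in the repo):

import numpy as np
centers = [np.array([[0.0], [0.0], [0.0]]), np.array([[2.0], [0.0], [0.0]])]   # two camera centres
avg = np.mean(np.hstack(centers), axis=1)                                       # (1, 0, 0)
max_dist = max(np.linalg.norm(c.flatten() - avg) for c in centers)              # 1.0
# getNerfppNorm on matching cameras would give radius == 1.1 * max_dist and translate == -avg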
+def readColmapCameras(cam_extrinsics, cam_intrinsics, images_folder, eval):
+
+    cam_infos = []
+    poses=[]
+    for idx, key in enumerate(cam_extrinsics):
+        sys.stdout.write('\r')
+        # the exact output you're looking for:
+        sys.stdout.write("Reading camera {}/{}".format(idx+1, len(cam_extrinsics)))
+        sys.stdout.flush()
+
+        if eval:
+            extr = cam_extrinsics[key]
+            intr = cam_intrinsics[1]
+            uid = idx+1
+
+        else:
+            extr = cam_extrinsics[key]
+            intr = cam_intrinsics[extr.camera_id]
+            uid = intr.id
+
+        height = intr.height
+        width = intr.width
+        R = np.transpose(qvec2rotmat(extr.qvec))
+        T = np.array(extr.tvec)
+        pose = np.vstack((np.hstack((R, T.reshape(3,-1))),np.array([[0, 0, 0, 1]])))
+        poses.append(pose)
+        if intr.model=="SIMPLE_PINHOLE":
+            focal_length_x = intr.params[0]
+            FovY = focal2fov(focal_length_x, height)
+            FovX = focal2fov(focal_length_x, width)
+        elif intr.model=="PINHOLE":
+            focal_length_x = intr.params[0]
+            focal_length_y = intr.params[1]
+            FovY = focal2fov(focal_length_y, height)
+            FovX = focal2fov(focal_length_x, width)
+        else:
+            assert False, "Colmap camera model not handled: only undistorted datasets (PINHOLE or SIMPLE_PINHOLE cameras) supported!"
+
+
+        if eval:
+            tmp = os.path.dirname(os.path.dirname(os.path.join(images_folder)))
+            all_images_folder = os.path.join(tmp, 'images')
+            image_path = os.path.join(all_images_folder, os.path.basename(extr.name))
+        else:
+            image_path = os.path.join(images_folder, os.path.basename(extr.name))
+        image_name = os.path.basename(image_path).split(".")[0]
+        image = Image.open(image_path)
+
+
+        cam_info = CameraInfo(uid=uid, R=R, T=T, FovY=FovY, FovX=FovX, image=image,
+                              image_path=image_path, image_name=image_name, width=width, height=height)
+
+        cam_infos.append(cam_info)
+    sys.stdout.write('\n')
+    return cam_infos, poses
+
+# For interpolated video, open when only render interpolated video
+def readColmapCamerasInterp(cam_extrinsics, cam_intrinsics, images_folder, model_path, cam_traj):
+
+    # pose_interpolated_path = model_path + 'pose/pose_interpolated.npy'
+    pose_interpolated_path = model_path + f'pose/pose_{cam_traj}.npy'
+    pose_interpolated = np.load(pose_interpolated_path)
+    intr = cam_intrinsics[1]
+
+    cam_infos = []
+    poses=[]
+    for idx, pose_npy in enumerate(pose_interpolated):
+        sys.stdout.write('\r')
+        sys.stdout.write("Reading camera {}/{}".format(idx+1, pose_interpolated.shape[0]))
+        sys.stdout.flush()
+
+        extr = pose_npy
+        intr = intr
+        height = intr.height
+        width = intr.width
+
+        uid = idx
+        R = extr[:3, :3].transpose()
+        T = extr[:3, 3]
+        pose = np.vstack((np.hstack((R, T.reshape(3,-1))),np.array([[0, 0, 0, 1]])))
+        # print(uid)
+        # print(pose.shape)
+        # pose = np.linalg.inv(pose)
+        poses.append(pose)
+        if intr.model=="SIMPLE_PINHOLE":
+            focal_length_x = intr.params[0]
+            FovY = focal2fov(focal_length_x, height)
+            FovX = focal2fov(focal_length_x, width)
+        elif intr.model=="PINHOLE":
+            focal_length_x = intr.params[0]
+            focal_length_y = intr.params[1]
+            FovY = focal2fov(focal_length_y, height)
+            FovX = focal2fov(focal_length_x, width)
+        else:
+            assert False, "Colmap camera model not handled: only undistorted datasets (PINHOLE or SIMPLE_PINHOLE cameras) supported!"
+
+        images_list = os.listdir(os.path.join(images_folder))
+        image_name_0 = images_list[0]
+        image_name = str(idx).zfill(4)
+        image = Image.open(images_folder + '/' + image_name_0)
+
+        cam_info = CameraInfo(uid=uid, R=R, T=T, FovY=FovY, FovX=FovX, image=image,
+                              image_path=images_folder, image_name=image_name, width=width, height=height)
+        cam_infos.append(cam_info)
+
+    sys.stdout.write('\n')
+    return cam_infos, poses
+
+
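Both loaders convert intrinsics to fields of view via the pinhole relation assumed by focal2fov/fov2focal in utils/graphics_utils.py: FoV = 2*atan(size / (2*focal)). A quick illustrative check:

import math
def focal2fov_sketch(focal, pixels):
    return 2 * math.atan(pixels / (2 * focal))
fx, width = 1000.0, 1600
fov_x = focal2fov_sketch(fx, width)                         # ~1.35 rad (~77.3 degrees)
assert abs(width / (2 * math.tan(fov_x / 2)) - fx) < 1e-9   # inverse relation recovers the focal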
+def fetchPly(path):
+    plydata = PlyData.read(path)
+    vertices = plydata['vertex']
+    positions = np.vstack([vertices['x'], vertices['y'], vertices['z']]).T
+    colors = np.vstack([vertices['red'], vertices['green'], vertices['blue']]).T / 255.0
+    normals = np.vstack([vertices['nx'], vertices['ny'], vertices['nz']]).T
+    features = None
+    feat_keys = [key for key in vertices.data.dtype.names if key.startswith('feat_')]
+    if feat_keys:
+        features = np.vstack([vertices[key] for key in feat_keys]).T
+    return BasicPointCloud(points=positions, colors=colors, normals=normals, features=features)
+
+def storePly(path, xyz, rgb):
+    # Define the dtype for the structured array
+    dtype = [('x', 'f4'), ('y', 'f4'), ('z', 'f4'),
+             ('nx', 'f4'), ('ny', 'f4'), ('nz', 'f4'),
+             ('red', 'u1'), ('green', 'u1'), ('blue', 'u1')]
+
+    normals = np.zeros_like(xyz)
+
+    elements = np.empty(xyz.shape[0], dtype=dtype)
+    attributes = np.concatenate((xyz, normals, rgb), axis=1)
+    elements[:] = list(map(tuple, attributes))
+
+    # Create the PlyData object and write to file
+    vertex_element = PlyElement.describe(elements, 'vertex')
+    ply_data = PlyData([vertex_element])
+    ply_data.write(path)
+
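A round-trip sketch of the two helpers above (the temporary path is hypothetical): storePly writes xyz/rgb with zero normals, and fetchPly reads them back as a BasicPointCloud with colors scaled to [0, 1].

import numpy as np
xyz = np.random.rand(100, 3).astype(np.float32)
rgb = (np.random.rand(100, 3) * 255).astype(np.uint8)
storePly("/tmp/points3D_demo.ply", xyz, rgb)      # hypothetical scratch path
pcd = fetchPly("/tmp/points3D_demo.ply")
assert pcd.points.shape == (100, 3) and pcd.colors.max() <= 1.0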
+def readColmapSceneInfo(path, images, eval, args, opt, llffhold=2):
+    # try:
+    #     cameras_extrinsic_file = os.path.join(path, "sparse/0", "images.bin")
+    #     cameras_intrinsic_file = os.path.join(path, "sparse/0", "cameras.bin")
+    #     cam_extrinsics = read_extrinsics_binary(cameras_extrinsic_file)
+    #     cam_intrinsics = read_intrinsics_binary(cameras_intrinsic_file)
+    # except:
+
+    ##### For initializing test pose using PCD_Registration
+    if eval and opt.get_video==False:
+        print("Loading initial test pose for evaluation.")
+        cameras_extrinsic_file = os.path.join(path, f"test_view/sparse/0/{opt.method}", "images.txt")
+    else:
+        cameras_extrinsic_file = os.path.join(path, f"sparse/0/{opt.method}", "images.txt")
+
+    cameras_intrinsic_file = os.path.join(path, f"sparse/0/{opt.method}", "cameras.txt")
+    if hasattr(opt, 'feat_type') and opt.feat_type is not None:
+        feat_type_str = '-'.join(opt.feat_type)
+        if "test_view" not in cameras_extrinsic_file:
+            cameras_extrinsic_file = cameras_extrinsic_file.replace("images.txt", f"{feat_type_str}/images.txt")
+        cameras_intrinsic_file = cameras_intrinsic_file.replace("cameras.txt", f"{feat_type_str}/cameras.txt")
+    cam_extrinsics = read_extrinsics_text(cameras_extrinsic_file)
+    cam_intrinsics = read_intrinsics_text(cameras_intrinsic_file)
+
+    reading_dir = "images" if images == None else images
+
+    if opt.get_video:
+        cam_infos_unsorted, poses = readColmapCamerasInterp(cam_extrinsics=cam_extrinsics, cam_intrinsics=cam_intrinsics,
+                                                            images_folder=os.path.join(path, reading_dir),
+                                                            model_path=args.model_path, cam_traj=opt.cam_traj)
+    else:
+        cam_infos_unsorted, poses = readColmapCameras(cam_extrinsics=cam_extrinsics, cam_intrinsics=cam_intrinsics, images_folder=os.path.join(path, reading_dir), eval=eval)
+    sorting_indices = sorted(range(len(cam_infos_unsorted)), key=lambda x: cam_infos_unsorted[x].image_name)
+    cam_infos = [cam_infos_unsorted[i] for i in sorting_indices]
+    sorted_poses = [poses[i] for i in sorting_indices]
+    cam_infos = sorted(cam_infos_unsorted.copy(), key = lambda x : x.image_name)
+
+    if eval:
+        # train_cam_infos = [c for idx, c in enumerate(cam_infos) if (idx+1) % llffhold != 0]
+        # test_cam_infos = [c for idx, c in enumerate(cam_infos) if (idx+1) % llffhold == 0]
+        # train_poses = [c for idx, c in enumerate(sorted_poses) if (idx+1) % llffhold != 0]
+        # test_poses = [c for idx, c in enumerate(sorted_poses) if (idx+1) % llffhold == 0]
+
+        train_cam_infos = cam_infos
+        test_cam_infos = cam_infos
+        train_poses = sorted_poses
+        test_poses = sorted_poses
+
+    else:
+        train_cam_infos = cam_infos
+        test_cam_infos = []
+        train_poses = sorted_poses
+        test_poses = []
+
+    # render_cam_infos = generate_ellipse_path_from_camera_infos(cam_infos)
+
+    nerf_normalization = getNerfppNorm(train_cam_infos)
+
+    ply_path = os.path.join(path, f"sparse/0/{opt.method}/points3D.ply")
+    if hasattr(opt, 'feat_type') and opt.feat_type is not None:
+        ply_path = ply_path.replace("points3D.ply", f"{feat_type_str}/points3D.ply")
+    bin_path = os.path.join(path, "sparse/0/points3D.bin")
+    txt_path = os.path.join(path, "sparse/0/points3D.txt")
+    if not os.path.exists(ply_path):
+        print("Converting point3d.bin to .ply, will happen only the first time you open the scene.")
+        try:
+            xyz, rgb, _ = read_points3D_binary(bin_path)
+        except:
+            xyz, rgb, _ = read_points3D_text(txt_path)
+        storePly(ply_path, xyz, rgb)
+    try:
+        pcd = fetchPly(ply_path)
+    except:
+        pcd = None
+
+    # np.save("poses_family.npy", sorted_poses)
+    # breakpoint()
+    # np.save("3dpoints.npy", pcd.points)
+    # np.save("3dcolors.npy", pcd.colors)
+
+    scene_info = SceneInfo(point_cloud=pcd,
+                           train_cameras=train_cam_infos,
+                           test_cameras=test_cam_infos,
+                           # render_cameras=render_cam_infos,
+                           nerf_normalization=nerf_normalization,
+                           ply_path=ply_path,
+                           train_poses=train_poses,
+                           test_poses=test_poses)
+    return scene_info
+
+def readCamerasFromTransforms(path, transformsfile, white_background, extension=".png"):
+    cam_infos = []
+
+    with open(os.path.join(path, transformsfile)) as json_file:
+        contents = json.load(json_file)
+        fovx = contents["camera_angle_x"]
+
+        frames = contents["frames"]
+        for idx, frame in enumerate(frames):
+            cam_name = os.path.join(path, frame["file_path"] + extension)
+
+            # NeRF 'transform_matrix' is a camera-to-world transform
+            c2w = np.array(frame["transform_matrix"])
+            # change from OpenGL/Blender camera axes (Y up, Z back) to COLMAP (Y down, Z forward)
+            c2w[:3, 1:3] *= -1
+
+            # get the world-to-camera transform and set R, T
+            w2c = np.linalg.inv(c2w)
+            R = np.transpose(w2c[:3,:3])  # R is stored transposed due to 'glm' in CUDA code
+            T = w2c[:3, 3]
+
+            image_path = os.path.join(path, cam_name)
+            image_name = Path(cam_name).stem
+            image = Image.open(image_path)
+
+            im_data = np.array(image.convert("RGBA"))
+
+            bg = np.array([1,1,1]) if white_background else np.array([0, 0, 0])
+
+            norm_data = im_data / 255.0
+            arr = norm_data[:,:,:3] * norm_data[:, :, 3:4] + bg * (1 - norm_data[:, :, 3:4])
+            image = Image.fromarray(np.array(arr*255.0, dtype=np.byte), "RGB")
+
+            fovy = focal2fov(fov2focal(fovx, image.size[0]), image.size[1])
+            FovY = fovy
+            FovX = fovx
+
+            cam_infos.append(CameraInfo(uid=idx, R=R, T=T, FovY=FovY, FovX=FovX, image=image,
+                            image_path=image_path, image_name=image_name, width=image.size[0], height=image.size[1]))
+
+    return cam_infos
+
+
def readNerfSyntheticInfo(path, white_background, eval, extension=".png"):
|
344 |
+
print("Reading Training Transforms")
|
345 |
+
train_cam_infos = readCamerasFromTransforms(path, "transforms_train.json", white_background, extension)
|
346 |
+
print("Reading Test Transforms")
|
347 |
+
test_cam_infos = readCamerasFromTransforms(path, "transforms_test.json", white_background, extension)
|
348 |
+
|
349 |
+
if not eval:
|
350 |
+
train_cam_infos.extend(test_cam_infos)
|
351 |
+
test_cam_infos = []
|
352 |
+
|
353 |
+
nerf_normalization = getNerfppNorm(train_cam_infos)
|
354 |
+
|
355 |
+
ply_path = os.path.join(path, "points3d.ply")
|
356 |
+
if not os.path.exists(ply_path):
|
357 |
+
# Since this data set has no colmap data, we start with random points
|
358 |
+
num_pts = 100_000
|
359 |
+
print(f"Generating random point cloud ({num_pts})...")
|
360 |
+
|
361 |
+
# We create random points inside the bounds of the synthetic Blender scenes
|
362 |
+
xyz = np.random.random((num_pts, 3)) * 2.6 - 1.3
|
363 |
+
shs = np.random.random((num_pts, 3)) / 255.0
|
364 |
+
pcd = BasicPointCloud(points=xyz, colors=SH2RGB(shs), normals=np.zeros((num_pts, 3)))
|
365 |
+
|
366 |
+
storePly(ply_path, xyz, SH2RGB(shs) * 255)
|
367 |
+
try:
|
368 |
+
pcd = fetchPly(ply_path)
|
369 |
+
except:
|
370 |
+
pcd = None
|
371 |
+
|
372 |
+
scene_info = SceneInfo(point_cloud=pcd,
|
373 |
+
train_cameras=train_cam_infos,
|
374 |
+
test_cameras=test_cam_infos,
|
375 |
+
nerf_normalization=nerf_normalization,
|
376 |
+
ply_path=ply_path)
|
377 |
+
return scene_info
|
378 |
+
|
379 |
+
sceneLoadTypeCallbacks = {
|
380 |
+
"Colmap": readColmapSceneInfo,
|
381 |
+
"Blender" : readNerfSyntheticInfo
|
382 |
+
}
|
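A hedged usage sketch of the dispatch dict above: scene/__init__.py chooses the loader from the files present in the source path; the dataset path here is illustrative, not a path shipped with the repo.

scene_info = sceneLoadTypeCallbacks["Blender"](
    "data/nerf_synthetic/lego",     # hypothetical dataset path
    white_background=True,
    eval=True,
)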
scene/gaussian_model.py
ADDED
@@ -0,0 +1,830 @@
+#
+# Copyright (C) 2023, Inria
+# GRAPHDECO research group, https://team.inria.fr/graphdeco
+# All rights reserved.
+#
+# This software is free for non-commercial, research and evaluation use
+# under the terms of the LICENSE.md file.
+#
+# For inquiries contact [email protected]
+#
+
+import torch
+# from lietorch import SO3, SE3, Sim3, LieGroupParameter
+import numpy as np
+from utils.general_utils import inverse_sigmoid, get_expon_lr_func, build_rotation
+from torch import nn
+import os
+from utils.system_utils import mkdir_p
+from plyfile import PlyData, PlyElement
+from utils.sh_utils import RGB2SH
+from simple_knn._C import distCUDA2
+from utils.graphics_utils import BasicPointCloud
+from utils.general_utils import strip_symmetric, build_scaling_rotation
+from scipy.spatial.transform import Rotation as R
+from utils.pose_utils import rotation2quad, get_tensor_from_camera
+from utils.graphics_utils import getWorld2View2
+
+import torch.nn.functional as F
+
+def quaternion_to_rotation_matrix(quaternion):
+    """
+    Convert a quaternion to a rotation matrix.
+
+    Parameters:
+    - quaternion: A tensor of shape (..., 4) representing quaternions.
+
+    Returns:
+    - A tensor of shape (..., 3, 3) representing rotation matrices.
+    """
+    # Ensure quaternion is of float type for computation
+    quaternion = quaternion.float()
+
+    # Normalize the quaternion to unit length
+    quaternion = quaternion / quaternion.norm(p=2, dim=-1, keepdim=True)
+
+    # Extract components
+    w, x, y, z = quaternion[..., 0], quaternion[..., 1], quaternion[..., 2], quaternion[..., 3]
+
+    # Compute rotation matrix components
+    xx, yy, zz = x * x, y * y, z * z
+    xy, xz, yz = x * y, x * z, y * z
+    xw, yw, zw = x * w, y * w, z * w
+
+    # Assemble the rotation matrix
+    R = torch.stack([
+        torch.stack([1 - 2 * (yy + zz), 2 * (xy - zw), 2 * (xz + yw)], dim=-1),
+        torch.stack([    2 * (xy + zw), 1 - 2 * (xx + zz), 2 * (yz - xw)], dim=-1),
+        torch.stack([    2 * (xz - yw), 2 * (yz + xw), 1 - 2 * (xx + yy)], dim=-1)
+    ], dim=-2)
+
+    return R
+
+
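A quick sanity check for the helper above (illustrative only): the identity quaternion [1, 0, 0, 0] must map to the identity matrix.

import torch
q = torch.tensor([1.0, 0.0, 0.0, 0.0])
R_id = quaternion_to_rotation_matrix(q)
assert torch.allclose(R_id, torch.eye(3), atol=1e-6)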
+class GaussianModel:
+
+    def setup_functions(self):
+        def build_covariance_from_scaling_rotation(scaling, scaling_modifier, rotation):
+            L = build_scaling_rotation(scaling_modifier * scaling, rotation)
+            actual_covariance = L @ L.transpose(1, 2)
+            symm = strip_symmetric(actual_covariance)
+            return symm
+
+        self.scaling_activation = torch.exp
+        self.scaling_inverse_activation = torch.log
+
+        self.covariance_activation = build_covariance_from_scaling_rotation
+
+        self.opacity_activation = torch.sigmoid
+        self.inverse_opacity_activation = inverse_sigmoid
+
+        self.rotation_activation = torch.nn.functional.normalize
+
+
+    def __init__(self, sh_degree : int):
+        # self.active_sh_degree = 0
+        self.active_sh_degree = sh_degree
+        self.max_sh_degree = sh_degree
+        self._xyz = torch.empty(0)
+        self._features_dc = torch.empty(0)
+        self._features_rest = torch.empty(0)
+        self._scaling = torch.empty(0)
+        self._rotation = torch.empty(0)
+        self._opacity = torch.empty(0)
+        self.max_radii2D = torch.empty(0)
+        self.xyz_gradient_accum = torch.empty(0)
+        self.denom = torch.empty(0)
+        self.optimizer = None
+        self.percent_dense = 0
+        self.spatial_lr_scale = 0
+        self.param_init = {}
+        self.setup_functions()
+
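The covariance assembled by build_covariance_from_scaling_rotation is Sigma = (R S)(R S)^T with S = diag(scales) and R from the unit quaternion, of which only the symmetric upper triangle is kept. A numpy sketch of the same quantity for a single Gaussian (assumed shapes, not repo code):

import numpy as np
def covariance_from_scaling_rotation_np(scales, R):
    L = R @ np.diag(scales)      # 3x3 "square root" of the covariance
    return L @ L.T               # symmetric positive semi-definite 3x3
Sigma = covariance_from_scaling_rotation_np(np.array([0.1, 0.2, 0.05]), np.eye(3))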
104 |
+
return (
|
105 |
+
self.active_sh_degree,
|
106 |
+
self._xyz,
|
107 |
+
self._features_dc,
|
108 |
+
self._features_rest,
|
109 |
+
self._scaling,
|
110 |
+
self._rotation,
|
111 |
+
self._opacity,
|
112 |
+
self.max_radii2D,
|
113 |
+
self.xyz_gradient_accum,
|
114 |
+
self.denom,
|
115 |
+
self.optimizer.state_dict(),
|
116 |
+
self.spatial_lr_scale,
|
117 |
+
self.P,
|
118 |
+
)
|
119 |
+
|
120 |
+
def restore(self, model_args, training_args):
|
121 |
+
(self.active_sh_degree,
|
122 |
+
self._xyz,
|
123 |
+
self._features_dc,
|
124 |
+
self._features_rest,
|
125 |
+
self._scaling,
|
126 |
+
self._rotation,
|
127 |
+
self._opacity,
|
128 |
+
self.max_radii2D,
|
129 |
+
xyz_gradient_accum,
|
130 |
+
denom,
|
131 |
+
opt_dict,
|
132 |
+
self.spatial_lr_scale,
|
133 |
+
self.P) = model_args
|
134 |
+
self.training_setup(training_args)
|
135 |
+
self.xyz_gradient_accum = xyz_gradient_accum
|
136 |
+
self.denom = denom
|
137 |
+
self.optimizer.load_state_dict(opt_dict)
|
138 |
+
|
139 |
+
@property
|
140 |
+
def get_scaling(self):
|
141 |
+
return self.scaling_activation(self._scaling)
|
142 |
+
|
143 |
+
@property
|
144 |
+
def get_rotation(self):
|
145 |
+
return self.rotation_activation(self._rotation)
|
146 |
+
|
147 |
+
@property
|
148 |
+
def get_xyz(self):
|
149 |
+
return self._xyz
|
150 |
+
|
151 |
+
def compute_relative_world_to_camera(self, R1, t1, R2, t2):
|
152 |
+
# Create a row of zeros with a one at the end, for homogeneous coordinates
|
153 |
+
zero_row = np.array([[0, 0, 0, 1]], dtype=np.float32)
|
154 |
+
|
155 |
+
# Compute the inverse of the first extrinsic matrix
|
156 |
+
E1_inv = np.hstack([R1.T, -R1.T @ t1.reshape(-1, 1)]) # Transpose and reshape for correct dimensions
|
157 |
+
E1_inv = np.vstack([E1_inv, zero_row]) # Append the zero_row to make it a 4x4 matrix
|
158 |
+
|
159 |
+
# Compute the second extrinsic matrix
|
160 |
+
E2 = np.hstack([R2, -R2 @ t2.reshape(-1, 1)]) # No need to transpose R2
|
161 |
+
E2 = np.vstack([E2, zero_row]) # Append the zero_row to make it a 4x4 matrix
|
162 |
+
|
163 |
+
# Compute the relative transformation
|
164 |
+
E_rel = E2 @ E1_inv
|
165 |
+
|
166 |
+
return E_rel
|
167 |
+
|
168 |
+
def init_RT_seq(self, cam_list):
|
169 |
+
poses =[]
|
170 |
+
for cam in cam_list[1.0]:
|
171 |
+
p = get_tensor_from_camera(cam.world_view_transform.transpose(0, 1)) # R T -> quat t
|
172 |
+
poses.append(p)
|
173 |
+
poses = torch.stack(poses)
|
174 |
+
self.P = poses.cuda().requires_grad_(True)
|
175 |
+
# poses_ = torch.randn(poses.detach().clone().shape, device='cuda')
|
176 |
+
# self.P = poses_.cuda().requires_grad_(True)
|
177 |
+
self.param_init['pose'] = poses.detach().clone()
|
178 |
+
|
+    def get_RT(self, idx):
+        pose = self.P[idx]
+        return pose
+
+    def get_RT_test(self, idx):
+        pose = self.test_P[idx]
+        return pose
+
+    @property
+    def get_features(self):
+        features_dc = self._features_dc
+        features_rest = self._features_rest
+        return torch.cat((features_dc, features_rest), dim=1)
+
+    @property
+    def get_opacity(self):
+        return self.opacity_activation(self._opacity)
+
+    def get_covariance(self, scaling_modifier = 1):
+        return self.covariance_activation(self.get_scaling, scaling_modifier, self._rotation)
+
+    def oneupSHdegree(self):
+        if self.active_sh_degree < self.max_sh_degree:
+            self.active_sh_degree += 1
+
+    def create_from_pcd(self, pcd : BasicPointCloud, spatial_lr_scale : float):
+        self.spatial_lr_scale = spatial_lr_scale
+        fused_point_cloud = torch.tensor(np.asarray(pcd.points)).float().cuda()
+        fused_color = RGB2SH(torch.tensor(np.asarray(pcd.colors)).float().cuda())
+        features = torch.zeros((fused_color.shape[0], 3, (self.max_sh_degree + 1) ** 2)).float().cuda()
+        features[:, :3, 0 ] = fused_color
+        features[:, 3:, 1:] = 0.0
+
+        print("Number of points at initialisation : ", fused_point_cloud.shape[0])
+
+        dist2 = torch.clamp_min(distCUDA2(torch.from_numpy(np.asarray(pcd.points)).float().cuda()), 0.0000001)
+        scales = torch.log(torch.sqrt(dist2))[...,None].repeat(1, 3)
+        rots = torch.zeros((fused_point_cloud.shape[0], 4), device="cuda")
+        rots[:, 0] = 1
+
+        opacities = inverse_sigmoid(0.1 * torch.ones((fused_point_cloud.shape[0], 1), dtype=torch.float, device="cuda"))
+
+        self._xyz = nn.Parameter(fused_point_cloud.requires_grad_(True))
+        self._features_dc = nn.Parameter(features[:,:,0:1].transpose(1, 2).contiguous().requires_grad_(True))
+        self._features_rest = nn.Parameter(features[:,:,1:].transpose(1, 2).contiguous().requires_grad_(True))
+        self._scaling = nn.Parameter(scales.requires_grad_(True))
+        self._rotation = nn.Parameter(rots.requires_grad_(True))
+        self._opacity = nn.Parameter(opacities.requires_grad_(True))
+        self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda")
+
+        self.param_init.update({
+            'xyz': fused_point_cloud.detach().clone(),
+            'f_dc': features[:,:,0:1].transpose(1, 2).contiguous().detach().clone(),
+            'f_rest': features[:,:,1:].transpose(1, 2).contiguous().detach().clone(),
+            'opacity': opacities.detach().clone(),
+            'scaling': scales.detach().clone(),
+            'rotation': rots.detach().clone(),
+        })
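For readers without the simple_knn CUDA extension: distCUDA2 used above returns, roughly, the mean squared distance of each point to its nearest neighbours, which then seeds the isotropic log-scales. A CPU stand-in sketch using scipy (an assumption for illustration only, not what the repo ships):

import numpy as np
from scipy.spatial import cKDTree

def mean_sq_knn_dist(points, k=3):
    d, _ = cKDTree(points).query(points, k=k + 1)   # first neighbour is the point itself
    return (d[:, 1:] ** 2).mean(axis=1)

pts = np.random.rand(1000, 3)
dist2 = np.clip(mean_sq_knn_dist(pts), 1e-7, None)
scales = np.repeat(np.log(np.sqrt(dist2))[:, None], 3, axis=1)   # isotropic log-scales, as above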
238 |
+
self.percent_dense = training_args.percent_dense
|
239 |
+
self.xyz_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")
|
240 |
+
self.denom = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")
|
241 |
+
|
242 |
+
l = [
|
243 |
+
{'params': [self._xyz], 'lr': training_args.position_lr_init * self.spatial_lr_scale, "name": "xyz"},
|
244 |
+
{'params': [self._features_dc], 'lr': training_args.feature_lr, "name": "f_dc"},
|
245 |
+
{'params': [self._features_rest], 'lr': training_args.feature_lr / 20.0, "name": "f_rest"},
|
246 |
+
{'params': [self._opacity], 'lr': training_args.opacity_lr, "name": "opacity"},
|
247 |
+
{'params': [self._scaling], 'lr': training_args.scaling_lr, "name": "scaling"},
|
248 |
+
{'params': [self._rotation], 'lr': training_args.rotation_lr, "name": "rotation"},
|
249 |
+
]
|
250 |
+
|
251 |
+
l_cam = [{'params': [self.P],'lr': training_args.rotation_lr*0.1, "name": "pose"},]
|
252 |
+
# l_cam = [{'params': [self.P],'lr': training_args.rotation_lr, "name": "pose"},]
|
253 |
+
|
254 |
+
l += l_cam
|
255 |
+
|
256 |
+
self.optimizer = torch.optim.Adam(l, lr=0.0, eps=1e-15)
|
257 |
+
self.xyz_scheduler_args = get_expon_lr_func(lr_init=training_args.position_lr_init*self.spatial_lr_scale,
|
258 |
+
lr_final=training_args.position_lr_final*self.spatial_lr_scale,
|
259 |
+
lr_delay_mult=training_args.position_lr_delay_mult,
|
260 |
+
max_steps=training_args.position_lr_max_steps)
|
261 |
+
self.cam_scheduler_args = get_expon_lr_func(
|
262 |
+
# lr_init=0,
|
263 |
+
# lr_final=0,
|
264 |
+
lr_init=training_args.rotation_lr*0.1,
|
265 |
+
lr_final=training_args.rotation_lr*0.001,
|
266 |
+
# lr_init=training_args.position_lr_init*self.spatial_lr_scale*10,
|
267 |
+
# lr_final=training_args.position_lr_final*self.spatial_lr_scale*10,
|
268 |
+
lr_delay_mult=training_args.position_lr_delay_mult,
|
269 |
+
max_steps=1000)
|
270 |
+
|
271 |
+
def update_learning_rate(self, iteration):
|
272 |
+
''' Learning rate scheduling per step '''
|
273 |
+
for param_group in self.optimizer.param_groups:
|
274 |
+
if param_group["name"] == "pose":
|
275 |
+
lr = self.cam_scheduler_args(iteration)
|
276 |
+
# print("pose learning rate", iteration, lr)
|
277 |
+
param_group['lr'] = lr
|
278 |
+
if param_group["name"] == "xyz":
|
279 |
+
lr = self.xyz_scheduler_args(iteration)
|
280 |
+
param_group['lr'] = lr
|
281 |
+
# return lr
|
282 |
+
|
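The two schedulers above follow the usual log-linear interpolation of get_expon_lr_func (utils/general_utils.py); ignoring the optional warm-up/delay term it is essentially lr(step) = lr_init * (lr_final / lr_init) ** (step / max_steps). A minimal sketch of that core behaviour:

def expon_lr_sketch(step, lr_init, lr_final, max_steps):
    t = min(max(step / max_steps, 0.0), 1.0)
    return lr_init * (lr_final / lr_init) ** t

# e.g. the pose learning rate decays from rotation_lr*0.1 to rotation_lr*0.001 over 1000 steps.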
+    def construct_list_of_attributes(self):
+        l = ['x', 'y', 'z', 'nx', 'ny', 'nz']
+        # All channels except the 3 DC
+        for i in range(self._features_dc.shape[1]*self._features_dc.shape[2]):
+            l.append('f_dc_{}'.format(i))
+        for i in range(self._features_rest.shape[1]*self._features_rest.shape[2]):
+            l.append('f_rest_{}'.format(i))
+        l.append('opacity')
+        for i in range(self._scaling.shape[1]):
+            l.append('scale_{}'.format(i))
+        for i in range(self._rotation.shape[1]):
+            l.append('rot_{}'.format(i))
+        return l
+
+    def save_ply(self, path):
+        mkdir_p(os.path.dirname(path))
+
+        xyz = self._xyz.detach().cpu().numpy()
+        normals = np.zeros_like(xyz)
+        f_dc = self._features_dc.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy()
+        f_rest = self._features_rest.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy()
+        opacities = self._opacity.detach().cpu().numpy()
+        scale = self._scaling.detach().cpu().numpy()
+        rotation = self._rotation.detach().cpu().numpy()
+
+        dtype_full = [(attribute, 'f4') for attribute in self.construct_list_of_attributes()]
+
+        elements = np.empty(xyz.shape[0], dtype=dtype_full)
+        attributes = np.concatenate((xyz, normals, f_dc, f_rest, opacities, scale, rotation), axis=1)
+        elements[:] = list(map(tuple, attributes))
+        el = PlyElement.describe(elements, 'vertex')
+        PlyData([el]).write(path)
+
+    def reset_opacity(self):
+        opacities_new = inverse_sigmoid(torch.min(self.get_opacity, torch.ones_like(self.get_opacity)*0.01))
+        optimizable_tensors = self.replace_tensor_to_optimizer(opacities_new, "opacity")
+        self._opacity = optimizable_tensors["opacity"]
+
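Column-count arithmetic for the PLY written above (illustrative, assuming sh_degree = 3): xyz + normals + f_dc + f_rest + opacity + scale + rotation.

sh_degree = 3
n_cols = 3 + 3 + 3 + 3 * ((sh_degree + 1) ** 2 - 1) + 1 + 3 + 4
assert n_cols == 62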
+    def load_ply(self, path):
+        plydata = PlyData.read(path)
+
+        xyz = np.stack((np.asarray(plydata.elements[0]["x"]),
+                        np.asarray(plydata.elements[0]["y"]),
+                        np.asarray(plydata.elements[0]["z"])), axis=1)
+        opacities = np.asarray(plydata.elements[0]["opacity"])[..., np.newaxis]
+
+        features_dc = np.zeros((xyz.shape[0], 3, 1))
+        features_dc[:, 0, 0] = np.asarray(plydata.elements[0]["f_dc_0"])
+        features_dc[:, 1, 0] = np.asarray(plydata.elements[0]["f_dc_1"])
+        features_dc[:, 2, 0] = np.asarray(plydata.elements[0]["f_dc_2"])
+
+        extra_f_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("f_rest_")]
+        extra_f_names = sorted(extra_f_names, key = lambda x: int(x.split('_')[-1]))
+        assert len(extra_f_names)==3*(self.max_sh_degree + 1) ** 2 - 3
+        features_extra = np.zeros((xyz.shape[0], len(extra_f_names)))
+        for idx, attr_name in enumerate(extra_f_names):
+            features_extra[:, idx] = np.asarray(plydata.elements[0][attr_name])
+        # Reshape (P,F*SH_coeffs) to (P, F, SH_coeffs except DC)
+        features_extra = features_extra.reshape((features_extra.shape[0], 3, (self.max_sh_degree + 1) ** 2 - 1))
+
+        scale_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("scale_")]
+        scale_names = sorted(scale_names, key = lambda x: int(x.split('_')[-1]))
+        scales = np.zeros((xyz.shape[0], len(scale_names)))
+        for idx, attr_name in enumerate(scale_names):
+            scales[:, idx] = np.asarray(plydata.elements[0][attr_name])
+
+        rot_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("rot")]
+        rot_names = sorted(rot_names, key = lambda x: int(x.split('_')[-1]))
+        rots = np.zeros((xyz.shape[0], len(rot_names)))
+        for idx, attr_name in enumerate(rot_names):
+            rots[:, idx] = np.asarray(plydata.elements[0][attr_name])
+
+        self._xyz = nn.Parameter(torch.tensor(xyz, dtype=torch.float, device="cuda").requires_grad_(True))
+        self._features_dc = nn.Parameter(torch.tensor(features_dc, dtype=torch.float, device="cuda").transpose(1, 2).contiguous().requires_grad_(True))
+        self._features_rest = nn.Parameter(torch.tensor(features_extra, dtype=torch.float, device="cuda").transpose(1, 2).contiguous().requires_grad_(True))
+        self._opacity = nn.Parameter(torch.tensor(opacities, dtype=torch.float, device="cuda").requires_grad_(True))
+        self._scaling = nn.Parameter(torch.tensor(scales, dtype=torch.float, device="cuda").requires_grad_(True))
+        self._rotation = nn.Parameter(torch.tensor(rots, dtype=torch.float, device="cuda").requires_grad_(True))
+
+        self.active_sh_degree = self.max_sh_degree
+
+    def replace_tensor_to_optimizer(self, tensor, name):
+        optimizable_tensors = {}
+        for group in self.optimizer.param_groups:
+            if group["name"] == name:
+                # breakpoint()
+                stored_state = self.optimizer.state.get(group['params'][0], None)
+                stored_state["exp_avg"] = torch.zeros_like(tensor)
+                stored_state["exp_avg_sq"] = torch.zeros_like(tensor)
+
+                del self.optimizer.state[group['params'][0]]
+                group["params"][0] = nn.Parameter(tensor.requires_grad_(True))
+                self.optimizer.state[group['params'][0]] = stored_state
+
+                optimizable_tensors[group["name"]] = group["params"][0]
+        return optimizable_tensors
+
+    def _prune_optimizer(self, mask):
+        optimizable_tensors = {}
+        for group in self.optimizer.param_groups:
+            stored_state = self.optimizer.state.get(group['params'][0], None)
+            if stored_state is not None:
+                stored_state["exp_avg"] = stored_state["exp_avg"][mask]
+                stored_state["exp_avg_sq"] = stored_state["exp_avg_sq"][mask]
+
+                del self.optimizer.state[group['params'][0]]
+                group["params"][0] = nn.Parameter((group["params"][0][mask].requires_grad_(True)))
+                self.optimizer.state[group['params'][0]] = stored_state
+
+                optimizable_tensors[group["name"]] = group["params"][0]
+            else:
+                group["params"][0] = nn.Parameter(group["params"][0][mask].requires_grad_(True))
+                optimizable_tensors[group["name"]] = group["params"][0]
+        return optimizable_tensors
+
+    def prune_points(self, mask):
+        valid_points_mask = ~mask
+        optimizable_tensors = self._prune_optimizer(valid_points_mask)
+
+        self._xyz = optimizable_tensors["xyz"]
+        self._features_dc = optimizable_tensors["f_dc"]
+        self._features_rest = optimizable_tensors["f_rest"]
+        self._opacity = optimizable_tensors["opacity"]
+        self._scaling = optimizable_tensors["scaling"]
+        self._rotation = optimizable_tensors["rotation"]
+
+        self.xyz_gradient_accum = self.xyz_gradient_accum[valid_points_mask]
+
+        self.denom = self.denom[valid_points_mask]
+        self.max_radii2D = self.max_radii2D[valid_points_mask]
+
+    def cat_tensors_to_optimizer(self, tensors_dict):
+        optimizable_tensors = {}
+        for group in self.optimizer.param_groups:
+            assert len(group["params"]) == 1
+            extension_tensor = tensors_dict[group["name"]]
+            stored_state = self.optimizer.state.get(group['params'][0], None)
+            if stored_state is not None:
+
+                stored_state["exp_avg"] = torch.cat((stored_state["exp_avg"], torch.zeros_like(extension_tensor)), dim=0)
+                stored_state["exp_avg_sq"] = torch.cat((stored_state["exp_avg_sq"], torch.zeros_like(extension_tensor)), dim=0)
+
+                del self.optimizer.state[group['params'][0]]
+                group["params"][0] = nn.Parameter(torch.cat((group["params"][0], extension_tensor), dim=0).requires_grad_(True))
+                self.optimizer.state[group['params'][0]] = stored_state
+
+                optimizable_tensors[group["name"]] = group["params"][0]
+            else:
+                group["params"][0] = nn.Parameter(torch.cat((group["params"][0], extension_tensor), dim=0).requires_grad_(True))
+                optimizable_tensors[group["name"]] = group["params"][0]
+
+        return optimizable_tensors
+
+    def densification_postfix(self, new_xyz, new_features_dc, new_features_rest, new_opacities, new_scaling, new_rotation):
+        d = {"xyz": new_xyz,
+        "f_dc": new_features_dc,
+        "f_rest": new_features_rest,
+        "opacity": new_opacities,
+        "scaling" : new_scaling,
+        "rotation" : new_rotation}
+
+        optimizable_tensors = self.cat_tensors_to_optimizer(d)
+        self._xyz = optimizable_tensors["xyz"]
+        self._features_dc = optimizable_tensors["f_dc"]
+        self._features_rest = optimizable_tensors["f_rest"]
+        self._opacity = optimizable_tensors["opacity"]
+        self._scaling = optimizable_tensors["scaling"]
+        self._rotation = optimizable_tensors["rotation"]
+
+        self.xyz_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")
+        self.denom = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")
+        self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda")
+
+    def densify_and_split(self, grads, grad_threshold, scene_extent, N=2):
+        n_init_points = self.get_xyz.shape[0]
+        # Extract points that satisfy the gradient condition
+        padded_grad = torch.zeros((n_init_points), device="cuda")
+        padded_grad[:grads.shape[0]] = grads.squeeze()
+        selected_pts_mask = torch.where(padded_grad >= grad_threshold, True, False)
+        selected_pts_mask = torch.logical_and(selected_pts_mask,
+                                              torch.max(self.get_scaling, dim=1).values > self.percent_dense*scene_extent)
+
+        stds = self.get_scaling[selected_pts_mask].repeat(N,1)
+        means =torch.zeros((stds.size(0), 3),device="cuda")
+        samples = torch.normal(mean=means, std=stds)
+        rots = build_rotation(self._rotation[selected_pts_mask]).repeat(N,1,1)
+        new_xyz = torch.bmm(rots, samples.unsqueeze(-1)).squeeze(-1) + self.get_xyz[selected_pts_mask].repeat(N, 1)
+        new_scaling = self.scaling_inverse_activation(self.get_scaling[selected_pts_mask].repeat(N,1) / (0.8*N))
+        new_rotation = self._rotation[selected_pts_mask].repeat(N,1)
+        new_features_dc = self._features_dc[selected_pts_mask].repeat(N,1,1)
+        new_features_rest = self._features_rest[selected_pts_mask].repeat(N,1,1)
+        new_opacity = self._opacity[selected_pts_mask].repeat(N,1)
+
+        self.densification_postfix(new_xyz, new_features_dc, new_features_rest, new_opacity, new_scaling, new_rotation)
+
+        prune_filter = torch.cat((selected_pts_mask, torch.zeros(N * selected_pts_mask.sum(), device="cuda", dtype=bool)))
+        self.prune_points(prune_filter)
+
+    def densify_and_clone(self, grads, grad_threshold, scene_extent):
+        # Extract points that satisfy the gradient condition
+        selected_pts_mask = torch.where(torch.norm(grads, dim=-1) >= grad_threshold, True, False)
+        selected_pts_mask = torch.logical_and(selected_pts_mask,
+                                              torch.max(self.get_scaling, dim=1).values <= self.percent_dense*scene_extent)
+
+        new_xyz = self._xyz[selected_pts_mask]
+        new_features_dc = self._features_dc[selected_pts_mask]
+        new_features_rest = self._features_rest[selected_pts_mask]
+        new_opacities = self._opacity[selected_pts_mask]
+        new_scaling = self._scaling[selected_pts_mask]
+        new_rotation = self._rotation[selected_pts_mask]
+
+        self.densification_postfix(new_xyz, new_features_dc, new_features_rest, new_opacities, new_scaling, new_rotation)
+
+    def densify_and_prune(self, max_grad, min_opacity, extent, max_screen_size):
+        grads = self.xyz_gradient_accum / self.denom
+        grads[grads.isnan()] = 0.0
+
+        # self.densify_and_clone(grads, max_grad, extent)
+        # self.densify_and_split(grads, max_grad, extent)
+
+        prune_mask = (self.get_opacity < min_opacity).squeeze()
+        if max_screen_size:
+            big_points_vs = self.max_radii2D > max_screen_size
+            big_points_ws = self.get_scaling.max(dim=1).values > 0.1 * extent
+            prune_mask = torch.logical_or(torch.logical_or(prune_mask, big_points_vs), big_points_ws)
+        self.prune_points(prune_mask)
+
+        torch.cuda.empty_cache()
+
+    def add_densification_stats(self, viewspace_point_tensor, update_filter):
+        self.xyz_gradient_accum[update_filter] += torch.norm(viewspace_point_tensor.grad[update_filter,:2], dim=-1, keepdim=True)
+        self.denom[update_filter] += 1
+
+
+
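Before the Feat2GaussianModel subclass below, one aside on the split branch above: new centres are drawn from a Gaussian whose standard deviation is the point's own scale, rotated into the point's local frame and offset by the original centre. A standalone sketch of just that sampling step, using a hypothetical helper that takes precomputed 3x3 rotation matrices rather than calling build_rotation:

import torch

def sample_split_centres(xyz, scaling, rot_mats, N=2):
    # Draw N perturbed copies of every centre: offsets ~ N(0, scale^2) in the
    # Gaussian's local frame, rotated to world space and added to the centre.
    stds = scaling.repeat(N, 1)                                  # (N*P, 3)
    samples = torch.normal(mean=torch.zeros_like(stds), std=stds)
    rots = rot_mats.repeat(N, 1, 1)                              # (N*P, 3, 3)
    return torch.bmm(rots, samples.unsqueeze(-1)).squeeze(-1) + xyz.repeat(N, 1)
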
class Feat2GaussianModel(GaussianModel):

    def __init__(self, sh_degree : int, feat_dim : int, gs_params_group : dict, noise_std=0):
        super().__init__(sh_degree)
        self.noise_std = noise_std
        self.pc_feat = torch.empty(0)
        self.param_init = {}
        self.feat_dim = feat_dim
        self.gs_params_group = gs_params_group
        self.active_sh_degree = sh_degree
        self.sh_coeffs = ((sh_degree + 1) ** 2) * 3 - 3
        net_width = feat_dim
        out_dim = {'xyz': 3, 'scaling': 3, 'rotation': 4, 'opacity': 1, 'f_dc': 3, 'f_rest': self.sh_coeffs}
        for key in gs_params_group.get('head', []):
            setattr(self, f'head_{key}', conditionalWarp(layers=[feat_dim, net_width, out_dim[key]], skip=[]).cuda())

        self.param_key = {
            'xyz': '_xyz',
            'scaling': '_scaling',
            'rotation': '_rotation',
            'opacity': '_opacity',
            'f_dc': '_features_dc',
            'f_rest': '_features_rest',
            'pc_feat': 'pc_feat',
        }

        # ## FOR DEBUGGING
        # self.head_xyz = conditionalWarp(layers=[self.feat_dim, net_width, 3], skip=[]).cuda()
        # self.head_scaling = conditionalWarp(layers=[self.feat_dim, net_width, 3], skip=[]).cuda()
        # self.head_rotation = conditionalWarp(layers=[self.feat_dim, net_width, 4], skip=[]).cuda()
        # self.head_opacity = conditionalWarp(layers=[self.feat_dim, net_width, 1], skip=[]).cuda()
        # self.head_f_dc = conditionalWarp(layers=[feat_dim, net_width, 3], skip=[]).cuda()
        # self.head_f_rest = conditionalWarp(layers=[feat_dim, net_width, self.sh_coeffs], skip=[]).cuda()

    def capture(self):
        head_state_dicts = {f'head_{key}': getattr(self, f'head_{key}').state_dict() for key in self.gs_params_group.get('head', [])}
        return (
            self.active_sh_degree,
            self._xyz,
            self._features_dc,
            self._features_rest,
            self._scaling,
            self._rotation,
            self._opacity,
            self.max_radii2D,
            self.xyz_gradient_accum,
            self.denom,
            self.optimizer.state_dict(),
            self.spatial_lr_scale,
            self.P,
            head_state_dicts
        )

    def restore(self, model_args, training_args):
        (self.active_sh_degree,
         self._xyz,
         self._features_dc,
         self._features_rest,
         self._scaling,
         self._rotation,
         self._opacity,
         self.max_radii2D,
         xyz_gradient_accum,
         denom,
         opt_dict,
         self.spatial_lr_scale,
         self.P,
         head_state_dicts
         ) = model_args

        self.training_setup(training_args)
        self.xyz_gradient_accum = xyz_gradient_accum
        self.denom = denom
        self.optimizer.load_state_dict(opt_dict)

        for key, state_dict in head_state_dicts.items():
            getattr(self, key).load_state_dict(state_dict)

    def inference(self):
        feat_in = self.pc_feat
        for key in self.gs_params_group.get('head', []):

            if key == 'f_dc':
                self._features_dc = getattr(self, f'head_{key}')(feat_in, self.param_init[key].view(-1, 3)).reshape(-1, 1, 3)
            elif key == 'f_rest':
                self._features_rest = getattr(self, f'head_{key}')(feat_in.detach(), self.param_init[key].view(-1, self.sh_coeffs)).reshape(-1, self.sh_coeffs // 3, 3)
            else:
                setattr(self, f'_{key}', getattr(self, f'head_{key}')(feat_in, self.param_init[key]))

            # if key == 'f_dc':
            #     self._features_dc = getattr(self, f'head_{key}')(feat_in, self.param_init[key].view(-1, 3)).reshape(-1, 1, 3)
            #     self._features_dc += self.param_init[key].view(-1, 1, 3).mean(dim=0, keepdim=True)
            # elif key == 'f_rest':
            #     self._features_rest = getattr(self, f'head_{key}')(feat_in.detach(), self.param_init[key].view(-1, self.sh_coeffs)).reshape(-1, self.sh_coeffs // 3, 3)
            #     self._features_rest += self.param_init[key].view(-1, self.sh_coeffs // 3, 3).mean(dim=0, keepdim=True)
            # else:
            #     pred = getattr(self, f'head_{key}')(feat_in, self.param_init[key])
            #     setattr(self, f'_{key}', pred + self.param_init[key].mean(dim=0, keepdim=True))

        # ## FOR DEBUGGING
        # self._xyz = self.head_xyz(pred, self.param_init['xyz'])
        # self._opacity = self.head_opacity(pred, self.param_init['opacity'])
        # self._scaling = self.head_scaling(pred, self.param_init['scaling'])
        # self._rotation = self.head_rotation(pred, self.param_init['rotation'])
        # self._features_dc = self.head_f_dc(pred, self.param_init['f_dc'].view(-1,3)).reshape(-1, 1, 3)
        # self._features_rest = self.head_f_rest(pred, self.param_init['f_rest'].view(-1,self.sh_coeffs)).reshape(-1, self.sh_coeffs//3, 3)

    def create_from_pcd(self, pcd : BasicPointCloud, spatial_lr_scale : float):
        self.spatial_lr_scale = spatial_lr_scale
        fused_point_cloud = torch.tensor(np.asarray(pcd.points)).float().cuda()
        fused_point_feat = torch.tensor(np.asarray(pcd.features)).float().cuda()  # get features from .PLY file
        assert fused_point_feat.shape[-1] == self.feat_dim, f"Expected feature dimension {self.feat_dim}, but got {fused_point_feat.shape[-1]}"
        fused_color = RGB2SH(torch.tensor(np.asarray(pcd.colors)).float().cuda())
        features = torch.zeros((fused_color.shape[0], 3, (self.max_sh_degree + 1) ** 2)).float().cuda()
        features[:, :3, 0] = fused_color
        features[:, 3:, 1:] = 0.0

        print("Number of points at initialisation : ", fused_point_cloud.shape[0])

        dist2 = torch.clamp_min(distCUDA2(torch.from_numpy(np.asarray(pcd.points)).float().cuda()), 0.0000001)
        scales = torch.log(torch.sqrt(dist2))[..., None].repeat(1, 3)
        rots = torch.zeros((fused_point_cloud.shape[0], 4), device="cuda")
        rots[:, 0] = 1
        opacities = inverse_sigmoid(0.1 * torch.ones((fused_point_cloud.shape[0], 1), dtype=torch.float, device="cuda"))
        self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda")

        self.pc_feat = fused_point_feat  # .requires_grad_(True)

        # fused_point_feat = torch.randn_like(fused_point_feat)
        # self.pc_feat = fused_point_feat.requires_grad_(True)

        self.gt_xyz = fused_point_cloud.clone()
        if self.noise_std != 0:
            self.noise_std /= 1000.0
            torch.manual_seed(0)
            torch.cuda.manual_seed(0)
            noise = torch.randn_like(fused_point_cloud) * self.noise_std
            fused_point_cloud += noise
            # fused_point_cloud = noise + fused_point_cloud.mean(dim=0, keepdim=True)
            # fused_point_cloud = torch.zeros_like(fused_point_cloud) + fused_point_cloud.mean(dim=0, keepdim=True)

        param_init = {
            'xyz': fused_point_cloud,
            'scaling': scales,
            'rotation': rots,
            'opacity': opacities,
            'f_dc': features[:, :, 0:1].transpose(1, 2).contiguous(),
            'f_rest': features[:, :, 1:].transpose(1, 2).contiguous(),
            'pc_feat': fused_point_feat,
        }

        for key in self.gs_params_group.get('opt', []):
            setattr(self, self.param_key[key], nn.Parameter(param_init[key].requires_grad_(True)))

        self.param_init.update({key: value.detach().clone() for key, value in param_init.items()})

        # ## FOR DEBUGGING
        # self._xyz = nn.Parameter(fused_point_cloud.requires_grad_(True))
        # self._scaling = nn.Parameter(scales.requires_grad_(True))
        # self._rotation = nn.Parameter(rots.requires_grad_(True))
        # self._opacity = nn.Parameter(opacities.requires_grad_(True))
        # self._features_dc = nn.Parameter(features[:,:,0:1].transpose(1, 2).contiguous().requires_grad_(True))
        # self._features_rest = nn.Parameter(features[:,:,1:].transpose(1, 2).contiguous().requires_grad_(True))

        # self.param_init.update({
        #     'xyz': fused_point_cloud.detach().clone(),
        #     'f_dc': features[:,:,0:1].transpose(1, 2).contiguous().detach().clone(),
        #     'f_rest': features[:,:,1:].transpose(1, 2).contiguous().detach().clone(),
        #     'opacity': opacities.detach().clone(),
        #     'scaling': scales.detach().clone(),
        #     'rotation': rots.detach().clone(),
        #     'pc_feat': fused_point_feat.detach().clone(),
        # })

    def training_setup(self, training_args):
        self.percent_dense = training_args.percent_dense
        self.xyz_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")
        self.denom = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")

        self.param_lr = {
            "xyz": training_args.position_lr_init * self.spatial_lr_scale,
            "f_dc": training_args.feature_lr,
            "f_rest": training_args.feature_sh_lr,
            "opacity": training_args.opacity_lr,
            "scaling": training_args.scaling_lr,
            "rotation": training_args.rotation_lr
        }

        warm_start_lr = 0.01
        l = []
        for key in self.gs_params_group.get('head', []):
            l.append({
                'params': getattr(self, f'head_{key}').parameters(),
                'lr': warm_start_lr,
                'name': key
            })

        for key in self.gs_params_group.get('opt', []):
            l.append({
                'params': [getattr(self, self.param_key[key])],
                'lr': warm_start_lr,
                'name': key
            })

        # ## FOR DEBUGGING
        # l += [
        #     {'params': self.head_f_dc.parameters(), 'lr': warm_start_lr, "name": "warm_start_f_dc"},
        #     {'params': self.head_f_rest.parameters(), 'lr': warm_start_lr, "name": "warm_start_f_rest"},
        # ]

        # l = [
        #     {'params': self.head_xyz.parameters(), 'lr': warm_start_lr, "name": "xyz"},
        #     # {'params': [self._xyz], 'lr': warm_start_lr, "name": "xyz"},
        #     {'params': self.head_scaling.parameters(), 'lr': warm_start_lr, "name": "scaling"},
        #     # {'params': [self._scaling], 'lr': warm_start_lr, "name": "scaling"},
        #     {'params': self.head_rotation.parameters(), 'lr': warm_start_lr, "name": "rotation"},
        #     # {'params': [self._rotation], 'lr': warm_start_lr, "name": "rotation"},
        #     {'params': self.head_opacity.parameters(), 'lr': warm_start_lr, "name": "opacity"},
        #     # {'params': [self._opacity], 'lr': warm_start_lr, "name": "opacity"},
        #     # {'params': self.head_f_dc.parameters(), 'lr': warm_start_lr, "name": "f_dc"},
        #     {'params': [self._features_dc], 'lr': warm_start_lr, "name": "f_dc"},
        #     # {'params': self.head_f_rest.parameters(), 'lr': warm_start_lr, "name": "f_rest"},
        #     {'params': [self._features_rest], 'lr': warm_start_lr, "name": "f_rest"},
        #     # {'params': [self.pc_feat], 'lr': warm_start_lr, "name": "feat"},
        # ]

        l_cam = [{'params': [self.P], 'lr': training_args.pose_lr_init, "name": "pose"},]

        l += l_cam

        self.optimizer = torch.optim.Adam(l, lr=0.0, eps=1e-15)
        self.xyz_scheduler_args = get_expon_lr_func(lr_init=training_args.position_lr_init * self.spatial_lr_scale,
                                                    lr_final=training_args.position_lr_final * self.spatial_lr_scale,
                                                    lr_delay_mult=training_args.position_lr_delay_mult,
                                                    max_steps=training_args.position_lr_max_steps)
        self.cam_scheduler_args = get_expon_lr_func(lr_init=training_args.pose_lr_init,
                                                    lr_final=training_args.pose_lr_final,
                                                    lr_delay_mult=training_args.position_lr_delay_mult,
                                                    max_steps=1000)

        self.warm_start_scheduler_args = get_expon_lr_func(lr_init=warm_start_lr,
                                                           lr_final=warm_start_lr * 0.01,
                                                           max_steps=1000)

    def setup_rendering_learning_rate(self):
        ''' Setup learning rate scheduling '''
        for param_group in self.optimizer.param_groups:
            if param_group["name"] in self.param_lr:
                param_group['lr'] = self.param_lr[param_group["name"]]
            # elif param_group["name"] == "feat":
            #     param_group['lr'] = 1e-6

    def update_warm_start_learning_rate(self, iteration):
        ''' Warm start learning rate scheduling per step '''
        for param_group in self.optimizer.param_groups:
            lr = self.warm_start_scheduler_args(iteration)
            param_group['lr'] = lr

    def update_learning_rate(self, iteration):
        ''' Learning rate scheduling per step '''
        for param_group in self.optimizer.param_groups:
            if param_group["name"] == "pose":
                lr = self.cam_scheduler_args(iteration)
                param_group['lr'] = lr
            if param_group["name"] == "xyz":
                lr = self.xyz_scheduler_args(iteration)
                param_group['lr'] = lr
                # return lr

class conditionalWarp(torch.nn.Module):
    def __init__(self, layers, skip, skip_dim=None, res=[], freq=None, zero_init=False):
        super().__init__()
        self.skip = skip
        self.res = res
        self.freq = freq
        self.mlp_warp = torch.nn.ModuleList()
        L = self.get_layer_dims(layers)
        for li, (k_in, k_out) in enumerate(L):
            if li in self.skip: k_in += layers[-1] if skip_dim is None else skip_dim
            linear = torch.nn.Linear(k_in, k_out)

            # Init network output as 0
            if zero_init:
                if li == (len(L) - 1):
                    torch.nn.init.constant_(linear.weight, 0)
                    torch.nn.init.constant_(linear.bias, 0)

            self.mlp_warp.append(linear)

    def get_layer_dims(self, layers):
        # return a list of tuples (k_in, k_out)
        return list(zip(layers[:-1], layers[1:]))

    def positional_encoding(self, input):  # [B,...,N]
        shape = input.shape
        freq = 2 ** torch.arange(self.freq, dtype=torch.float32, device=input.device) * np.pi  # [L]
        spectrum = input[..., None] * freq  # [B,...,N,L]
        sin, cos = spectrum.sin(), spectrum.cos()  # [B,...,N,L]
        input_enc = torch.stack([sin, cos], dim=-2)  # [B,...,N,2,L]
        input_enc = input_enc.view(*shape[:-1], -1)  # [B,...,2NL]
        return input_enc

    def forward(self, feat_in, color):
        if self.freq != None:
            feat_in = torch.cat([feat_in, self.positional_encoding(feat_in)], dim=-1)
        feat = feat_in
        for li, layer in enumerate(self.mlp_warp):
            if li in self.skip: feat = torch.cat([feat, color], dim=-1)
            if li in self.res: feat = feat + feat_in
            feat = layer(feat)
            if li != len(self.mlp_warp) - 1:
                feat = nn.functional.relu(feat)
        warp = feat
        return warp
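A minimal sketch of the readout that Feat2GaussianModel.inference() performs: every Gaussian attribute listed in gs_params_group['head'] gets its own conditionalWarp MLP that maps per-point features to that attribute, with the initial parameter value passed alongside (it is only concatenated in when skip layers are configured). The feature dimension, point count, and attribute list below are illustrative placeholders, not values taken from the repository's configs, and the sketch runs on CPU.

import torch

feat_dim, n_pts = 256, 10_000
out_dims = {'xyz': 3, 'scaling': 3, 'rotation': 4, 'opacity': 1}
heads = {k: conditionalWarp(layers=[feat_dim, feat_dim, d], skip=[])
         for k, d in out_dims.items()}
feats = torch.randn(n_pts, feat_dim)                 # stand-in per-point features
init = {k: torch.zeros(n_pts, d) for k, d in out_dims.items()}
gs_params = {k: head(feats, init[k]) for k, head in heads.items()}
# gs_params['xyz'] -> (n_pts, 3), gs_params['opacity'] -> (n_pts, 1), etc.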
train_feat2gs.py
ADDED
@@ -0,0 +1,243 @@
#
# Copyright (C) 2023, Inria
# GRAPHDECO research group, https://team.inria.fr/graphdeco
# All rights reserved.
#
# This software is free for non-commercial, research and evaluation use
# under the terms of the LICENSE.md file.
#
# For inquiries contact [email protected]
#

import os
import numpy as np
import torch
from random import randint
from utils.loss_utils import l1_loss, ssim
from gaussian_renderer import render_gsplat
import sys
from scene import Scene, Feat2GaussianModel
from argparse import ArgumentParser
from arguments import ModelParams, PipelineParams, OptimizationParams
from utils.pose_utils import get_camera_from_tensor
from tqdm import tqdm

from time import perf_counter

def save_pose(path, quat_pose, train_cams, llffhold=2):
    output_poses = []
    index_colmap = [cam.colmap_id for cam in train_cams]
    for quat_t in quat_pose:
        w2c = get_camera_from_tensor(quat_t)
        output_poses.append(w2c)
    colmap_poses = []
    for i in range(len(index_colmap)):
        ind = index_colmap.index(i + 1)
        bb = output_poses[ind]
        bb = bb  # .inverse()
        colmap_poses.append(bb)
    colmap_poses = torch.stack(colmap_poses).detach().cpu().numpy()
    np.save(path, colmap_poses)


def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoint_iterations, checkpoint, debug_from, args):
    first_iter = 0
    # tb_writer = prepare_output_and_logger(dataset, opt.iterations)
    feat_type = '-'.join(args.feat_type)
    feat_dim = args.feat_dim if feat_type not in ['iuv', 'iuvrgb'] else dataset.feat_default_dim[feat_type]
    gs_params_group = dataset.gs_params_group[args.model]
    gaussians = Feat2GaussianModel(dataset.sh_degree, feat_dim, gs_params_group)
    scene = Scene(dataset, gaussians, opt=args, shuffle=True)
    gaussians.training_setup(opt)
    # if checkpoint:
    #     (model_params, first_iter) = torch.load(checkpoint)
    #     gaussians.restore(model_params, opt)
    train_cams_init = scene.getTrainCameras().copy()
    os.makedirs(scene.model_path + 'pose', exist_ok=True)
    # save_pose(scene.model_path + 'pose' + "/pose_org.npy", gaussians.P, train_cams_init)
    bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0]
    background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")

    iter_start = torch.cuda.Event(enable_timing=True)
    iter_end = torch.cuda.Event(enable_timing=True)

    viewpoint_stack = None
    ema_loss_for_log = 0.0
    progress_bar = tqdm(range(first_iter, opt.iterations), desc="Training progress")
    first_iter += 1

    warm_iter = 1000

    start = perf_counter()
    for iteration in range(first_iter, opt.iterations + 1):
        # if network_gui.conn == None:
        #     network_gui.try_connect()
        # while network_gui.conn != None:
        #     try:
        #         net_image_bytes = None
        #         custom_cam, do_training, pipe.convert_SHs_python, pipe.compute_cov3D_python, keep_alive, scaling_modifer = network_gui.receive()
        #         if custom_cam != None:
        #             net_image = render(custom_cam, gaussians, pipe, background, scaling_modifer)["render"]
        #             net_image_bytes = memoryview((torch.clamp(net_image, min=0, max=1.0) * 255).byte().permute(1, 2, 0).contiguous().cpu().numpy())
        #         network_gui.send(net_image_bytes, dataset.source_path)
        #         if do_training and ((iteration < int(opt.iterations)) or not keep_alive):
        #             break
        #     except Exception as e:
        #         network_gui.conn = None

        iter_start.record()

        if iteration > warm_iter:
            if iteration == warm_iter + 1:
                gaussians.pc_feat.requires_grad_(False)
                gaussians.setup_rendering_learning_rate()
            gaussians.update_learning_rate(iteration - warm_iter)
        else:
            gaussians.update_warm_start_learning_rate(iteration)

        if args.optim_pose == False:
            gaussians.P.requires_grad_(False)

        # (DISABLED) Every 1000 its we increase the levels of SH up to a maximum degree
        # if iteration % 1000 == 0:
        #     gaussians.oneupSHdegree()

        # Pick a random Camera
        if not viewpoint_stack:
            viewpoint_stack = scene.getTrainCameras().copy()
        viewpoint_cam = viewpoint_stack.pop(randint(0, len(viewpoint_stack) - 1))
        pose = gaussians.get_RT(viewpoint_cam.uid)

        # Render
        if (iteration - 1) == debug_from:
            pipe.debug = True

        bg = torch.rand((3), device="cuda") if opt.random_background else background

        gaussians.inference()

        pretrained_loss_dict = {
            'xyz': l1_loss(gaussians._xyz, gaussians.param_init['xyz']),
            # 'f_dc': l1_loss(gaussians._features_dc, gaussians.param_init['f_dc']),
            # 'f_rest': l1_loss(gaussians._features_rest, gaussians.param_init['f_rest']),
            'opacity': l1_loss(gaussians._opacity, gaussians.param_init['opacity']),
            'scaling': l1_loss(gaussians._scaling, gaussians.param_init['scaling']),
            'rotation': l1_loss(gaussians._rotation, gaussians.param_init['rotation']),
            # 'pose': l1_loss(gaussians.P, gaussians.param_init['pose']),
            # 'focal': l1_loss(gaussians._focal_params, gaussians.param_init['focal']),
            # 'pc_feat': l1_loss(gaussians.pc_feat, gaussians.param_init['pc_feat']),
        }

        if iteration <= warm_iter:
            loss = sum(loss for key, loss in pretrained_loss_dict.items() if key in gs_params_group['head'])
            Ll1 = torch.tensor(0)

        if iteration > warm_iter:
            render_pkg = render_gsplat(viewpoint_cam, gaussians, pipe, bg, camera_pose=pose)
            image, viewspace_point_tensor, visibility_filter, radii = render_pkg["render"], render_pkg["viewspace_points"], render_pkg["visibility_filter"], render_pkg["radii"]

            # Loss
            gt_image = viewpoint_cam.original_image.cuda()
            Ll1 = l1_loss(image, gt_image)
            loss = (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * (1.0 - ssim(image, gt_image))

            # if feat_type in ['iuv', 'iuvrgb']:
            #     # Add scaling regularization for 'iuv' and 'iuvrgb' features.
            #     # Prevents their Gaussian scales from growing large enough to cause CUDA out-of-memory errors.
            #     loss += l1_loss(gaussians._scaling, gaussians.param_init['scaling']) * 0.1

        loss.backward()
        iter_end.record()

        with torch.no_grad():

            # Progress bar
            ema_loss_for_log = 0.4 * loss.item() + 0.6 * ema_loss_for_log
            if iteration % 10 == 0:
                progress_bar.set_postfix({"Loss": f"{ema_loss_for_log:.{7}f}"})
                progress_bar.update(10)
            if iteration == opt.iterations:
                progress_bar.close()

            # Log and save
            # training_report(tb_writer, iteration, Ll1, loss, l1_loss, iter_start.elapsed_time(iter_end), testing_iterations, scene, render_gsplat, (pipe, background), pretrained_loss_dict)
            if (iteration in saving_iterations):
                print("\n[ITER {}] Saving Gaussians".format(iteration))
                scene.save(iteration)
                save_pose(scene.model_path + 'pose' + f"/pose_{iteration}.npy", gaussians.P, train_cams_init)

            # (DISABLED) Densification
            # if iteration < opt.densify_until_iter:
            #     # Keep track of max radii in image-space for pruning
            #     gaussians.max_radii2D[visibility_filter] = torch.max(gaussians.max_radii2D[visibility_filter], radii[visibility_filter])
            #     gaussians.add_densification_stats(viewspace_point_tensor, visibility_filter)

            #     if iteration > opt.densify_from_iter and iteration % opt.densification_interval == 0:
            #         size_threshold = 20 if iteration > opt.opacity_reset_interval else None
            #         gaussians.densify_and_prune(opt.densify_grad_threshold, 0.005, scene.cameras_extent, size_threshold)

            #     if iteration % opt.opacity_reset_interval == 0 or (dataset.white_background and iteration == opt.densify_from_iter):
            #         gaussians.reset_opacity()

            # Optimizer step
            if iteration < opt.iterations:
                gaussians.optimizer.step()
                gaussians.optimizer.zero_grad(set_to_none=True)

            # if (iteration in checkpoint_iterations):
            #     print("\n[ITER {}] Saving Checkpoint".format(iteration))
            #     torch.save((gaussians.capture(), iteration), scene.model_path + "/chkpnt" + str(iteration) + ".pth")

    end = perf_counter()
    train_time = end - start

    # Log & save operations above are commented out when measuring training time.
    # train_time = np.array(train_time)
    # print("total_test_time_epoch: ", 1)
    # print("train_time_mean: ", train_time.mean())
    # print("train_time_median: ", np.median(train_time))


if __name__ == "__main__":
    # Set up command line argument parser
    parser = ArgumentParser(description="Training script parameters")
    lp = ModelParams(parser)
    op = OptimizationParams(parser)
    pp = PipelineParams(parser)
    parser.add_argument('--ip', type=str, default="127.0.0.1")
    parser.add_argument('--port', type=int, default=6009)
    parser.add_argument('--debug_from', type=int, default=-1)
    parser.add_argument('--detect_anomaly', action='store_true', default=False)
    parser.add_argument("--test_iterations", nargs="+", type=int,
                        default=[500, 800, 1000, 1500, 2000, 3000, 4000, 5000, 6000, 7_000, \
                                 8_000, 9_000, 10_000, 11_000, 12_000, 13_000, 14_000, 30_000])
    parser.add_argument("--save_iterations", nargs="+", type=int, default=[])
    parser.add_argument("--quiet", action="store_true")
    parser.add_argument("--checkpoint_iterations", nargs="+", type=int, default=[])
    parser.add_argument("--start_checkpoint", type=str, default=None)
    parser.add_argument("--scene", type=str, default=None)
    parser.add_argument("--n_views", type=int, default=None)
    parser.add_argument("--get_video", action="store_true")
    parser.add_argument("--optim_pose", action="store_true")
    parser.add_argument("--feat_type", type=str, nargs='*', default=None, help="Feature type(s). Multiple types can be specified for combination.")
    parser.add_argument("--method", type=str, default='dust3r', help="Initialization method, e.g. 'dust3r' or 'mast3r'.")
    parser.add_argument("--feat_dim", type=int, default=None, help="Feature dimension after PCA. If None, PCA is not applied.")
    parser.add_argument("--model", type=str, default='G', help="Feat2GS model variant: 'G'=geometry, 'T'=texture, 'A'=all.")

    args = parser.parse_args(sys.argv[1:])
    args.save_iterations.append(args.iterations)

    os.makedirs(args.model_path, exist_ok=True)

    print("Optimizing " + args.model_path)

    # Initialize system state (RNG)
    # safe_state(args.quiet)

    # Start GUI server, configure and run training
    # network_gui.init(args.ip, args.port)
    torch.autograd.set_detect_anomaly(args.detect_anomaly)
    training(lp.extract(args), op.extract(args), pp.extract(args), args.test_iterations, args.save_iterations, args.checkpoint_iterations, args.start_checkpoint, args.debug_from, args)

    # All done
    print("\nTraining complete.")
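For reference, a minimal invocation of the script above might look like the sketch below. The data path, output path, scene, feature type, and feature dimension are placeholders; only --feat_type, --method, --feat_dim, --model and --optim_pose are defined in the parser shown here, and the sketch assumes ModelParams/OptimizationParams expose the usual 3DGS options such as --source_path, --model_path and --iterations.

python train_feat2gs.py \
    --source_path data/<scene> \
    --model_path output/<scene>/dust3r_dino_G \
    --feat_type dino \
    --method dust3r \
    --feat_dim 256 \
    --model G \
    --optim_pose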
utils/camera_traj_config.py
ADDED
@@ -0,0 +1,655 @@
# Per-dataset, per-scene camera trajectory settings.
# Keys inside each scene entry:
#   'arc':        {'degree'}         - default arc degree
#   'spiral':     {'zrate', 'rots'}  - spiral rise rate, number of rotations
#   'lemniscate': {'degree'}         - lemniscate curve angle
#   'wander':     {'max_disp'}       - maximum displacement
trajectory_configs = {
    'Infer': {
        'cy':       {'up': [-1, 1], 'arc': {'degree': 180.0}, 'spiral': {'zrate': 0.5, 'rots': 1}, 'lemniscate': {'degree': 45.0}, 'wander': {'max_disp': 48.0}},
        'paper4':   {'up': [-1, 1], 'arc': {'degree': 200.0}, 'spiral': {'zrate': 0.5, 'rots': 1}, 'lemniscate': {'degree': 80.0}, 'wander': {'max_disp': 100.0}},
        'cy3':      {'up': [-1, 1], 'arc': {'degree': 180.0}, 'spiral': {'zrate': 0.5, 'rots': 1}, 'lemniscate': {'degree': 45.0}, 'wander': {'max_disp': 48.0}},
        'cy4':      {'up': [-1, 1], 'arc': {'degree': 180.0}, 'spiral': {'zrate': 0.5, 'rots': 1}, 'lemniscate': {'degree': 45.0}, 'wander': {'max_disp': 48.0}},
        'coffee':   {'up': [-1, 0], 'arc': {'degree': 180.0}, 'spiral': {'zrate': 0.5, 'rots': 1}, 'lemniscate': {'degree': 80.0}, 'wander': {'max_disp': 48.0}},
        'plant':    {'up': [-1, 1], 'arc': {'degree': 180.0}, 'spiral': {'zrate': 0.5, 'rots': 1}, 'lemniscate': {'degree': 80.0}, 'wander': {'max_disp': 48.0}},
        'desk':     {'up': [-1, 1], 'arc': {'degree': 180.0}},
        'bread':    {'up': [-1, 0], 'arc': {'degree': 180.0}, 'spiral': {'zrate': 0.5, 'rots': 1}, 'lemniscate': {'degree': 80.0}, 'wander': {'max_disp': 48.0}},
        'brunch':   {'up': [1, 1],  'arc': {'degree': 180.0}, 'spiral': {'zrate': 0.5, 'rots': 1}, 'lemniscate': {'degree': 80.0}, 'wander': {'max_disp': 48.0}},
        'stuff':    {'up': [-1, 0], 'arc': {'degree': 180.0}, 'spiral': {'zrate': 0.5, 'rots': 1}, 'lemniscate': {'degree': 80.0}, 'wander': {'max_disp': 48.0}},
        'xbox':     {'up': [-1, 1], 'arc': {'degree': 180.0}, 'spiral': {'zrate': 0.5, 'rots': 1}, 'lemniscate': {'degree': 80.0}, 'wander': {'max_disp': 48.0}},
        'plushies': {'up': [-1, 1], 'arc': {'degree': 120.0}, 'spiral': {'zrate': 0.5, 'rots': 1}, 'lemniscate': {'degree': 60.0}, 'wander': {'max_disp': 48.0}},
        'erhai':    {'up': [-1, 1], 'arc': {'degree': 180.0}, 'lemniscate': {'degree': 80.0}, 'wander': {'max_disp': 200.0}},
        'cy_crop1': {'up': [-1, 0], 'lemniscate': {'degree': 80.0}, 'wander': {'max_disp': 60.0}},
        'cy_crop':  {'up': [-1, 1], 'lemniscate': {'degree': 60.0}, 'wander': {'max_disp': 48.0}},
        'paper':    {'up': [-1, 1], 'arc': {'degree': 240.0}, 'spiral': {'zrate': 0.5, 'rots': 1}, 'lemniscate': {'degree': 60.0}, 'wander': {'max_disp': 48.0}},
        'house':    {'up': [-1, 1], 'arc': {'degree': 240.0}, 'spiral': {'zrate': 0.5, 'rots': 1}, 'lemniscate': {'degree': 60.0}},
        'home':     {'up': [1, 0],  'arc': {'degree': 90.0},  'spiral': {'zrate': 0.5, 'rots': 1}, 'lemniscate': {'degree': 80.0}},
        'paper3':   {'up': [1, 0],  'arc': {'degree': 180.0}, 'spiral': {'zrate': 0.5, 'rots': 1}, 'lemniscate': {'degree': 80.0}, 'wander': {'max_disp': 48.0}},
        'castle':   {'up': [-1, 1], 'spiral': {'zrate': 2, 'rots': 2}},
        'hogwarts': {'up': [-1, 1], 'spiral': {'zrate': 0.5, 'rots': 1}, 'wander': {'max_disp': 100.0}},
    },
    'Tanks': {
        'Auditorium':  {'up': [-1, 1], 'arc': {'degree': 180.0}, 'spiral': {'zrate': 0.5, 'rots': 1}, 'lemniscate': {'degree': 30.0}, 'wander': {'max_disp': 80.0}},
        'Caterpillar': {'up': [-1, 1], 'arc': {'degree': 240.0}, 'spiral': {'zrate': 0.5, 'rots': 1}, 'lemniscate': {'degree': 60.0}, 'wander': {'max_disp': 48.0}},
        'Family':      {'up': [-1, 1], 'arc': {'degree': 180.0}, 'spiral': {'zrate': 0.5, 'rots': 1}, 'lemniscate': {'degree': 60.0}, 'wander': {'max_disp': 48.0}},
        'Ignatius':    {'up': [-1, 1], 'arc': {'degree': 330.0}, 'spiral': {'zrate': 0.5, 'rots': 1}, 'lemniscate': {'degree': 80.0}, 'wander': {'max_disp': 48.0}},
        'Train':       {'up': [-1, 1], 'arc': {'degree': 180.0}, 'spiral': {'zrate': 0.5, 'rots': 1}, 'lemniscate': {'degree': 60.0}, 'wander': {'max_disp': 48.0}},
    },
    'DL3DV': {
        'Center':       {'up': [-1, 1], 'arc': {'degree': 180.0}, 'spiral': {'zrate': 0.5, 'rots': 1}, 'lemniscate': {'degree': 60.0}, 'wander': {'max_disp': 48.0}},
        'Electrical':   {'up': [-1, 1], 'arc': {'degree': 180.0}, 'spiral': {'zrate': 0.5, 'rots': 1}, 'lemniscate': {'degree': 60.0}, 'wander': {'max_disp': 48.0}},
        'Museum':       {'up': [-1, 1], 'arc': {'degree': 180.0}, 'spiral': {'zrate': 0.5, 'rots': 1}, 'lemniscate': {'degree': 60.0}, 'wander': {'max_disp': 48.0}},
        'Supermarket2': {'up': [-1, 1], 'arc': {'degree': 180.0}, 'spiral': {'zrate': 0.5, 'rots': 1}, 'lemniscate': {'degree': 60.0}, 'wander': {'max_disp': 48.0}},
        'Temple':       {'up': [-1, 1], 'arc': {'degree': 180.0}, 'spiral': {'zrate': 0.5, 'rots': 1}, 'lemniscate': {'degree': 60.0}, 'wander': {'max_disp': 48.0}},
    },
    'MipNeRF360': {
        'garden':  {'up': [-1, 1], 'arc': {'degree': 270.0}, 'spiral': {'zrate': 0.5, 'rots': 1}, 'lemniscate': {'degree': 80.0}, 'wander': {'max_disp': 48.0}},
        'kitchen': {'up': [-1, 1], 'arc': {'degree': 180.0}, 'spiral': {'zrate': 0.5, 'rots': 1}, 'lemniscate': {'degree': 80.0}, 'wander': {'max_disp': 48.0}},
        'room':    {'up': [-1, 1], 'arc': {'degree': 180.0}, 'spiral': {'zrate': 0.5, 'rots': 1}, 'lemniscate': {'degree': 60.0}, 'wander': {'max_disp': 48.0}},
    },
    'MVimgNet': {
        'bench': {'up': [-1, 1], 'arc': {'degree': 180.0}, 'spiral': {'zrate': 0.5, 'rots': 1}, 'lemniscate': {'degree': 80.0}, 'wander': {'max_disp': 48.0}},
        'car':   {'up': [-1, 1], 'arc': {'degree': 180.0}, 'spiral': {'zrate': 0.5, 'rots': 1}, 'lemniscate': {'degree': 60.0}, 'wander': {'max_disp': 48.0}},
        'suv':   {'up': [-1, 1], 'arc': {'degree': 180.0}, 'spiral': {'zrate': 0.5, 'rots': 1}, 'lemniscate': {'degree': 60.0}, 'wander': {'max_disp': 48.0}},
    },
    'LLFF': {
        'fortress': {'up': [-1, 1], 'arc': {'degree': 180.0}, 'spiral': {'zrate': 0.5, 'rots': 1}, 'lemniscate': {'degree': 30.0}, 'wander': {'max_disp': 48.0}},
        'horns':    {'up': [-1, 1], 'arc': {'degree': 180.0}, 'spiral': {'zrate': 0.5, 'rots': 1}, 'lemniscate': {'degree': 60.0}, 'wander': {'max_disp': 48.0}},
        'orchids':  {'up': [-1, 1], 'arc': {'degree': 180.0}, 'spiral': {'zrate': 0.5, 'rots': 1}, 'lemniscate': {'degree': 30.0}, 'wander': {'max_disp': 48.0}},
        'room':     {'up': [-1, 1], 'arc': {'degree': 180.0}, 'spiral': {'zrate': 0.5, 'rots': 1}, 'lemniscate': {'degree': 30.0}, 'wander': {'max_disp': 48.0}},
        'trex':     {'up': [-1, 1], 'arc': {'degree': 180.0}, 'spiral': {'zrate': 1, 'rots': 1},   'lemniscate': {'degree': 30.0}, 'wander': {'max_disp': 48.0}},
    },
}
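The nested dictionary above is keyed first by dataset and then by scene, and not every scene defines every trajectory type (e.g. 'desk' only has 'arc', 'castle' only has 'spiral'). A consumer might therefore read it with per-trajectory fallbacks along these lines; this is only a sketch, not the lookup used by the repository's trajectory code.

from utils.camera_traj_config import trajectory_configs

def get_traj_params(dataset, scene, traj='arc'):
    # Fall back to empty dicts when the dataset, scene, or trajectory type is missing.
    scene_cfg = trajectory_configs.get(dataset, {}).get(scene, {})
    return scene_cfg.get('up'), scene_cfg.get(traj, {})

up, arc = get_traj_params('Infer', 'plushies', 'arc')  # ([-1, 1], {'degree': 120.0})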
utils/camera_utils.py
ADDED
@@ -0,0 +1,481 @@
1 |
+
#
|
2 |
+
# Copyright (C) 2023, Inria
|
3 |
+
# GRAPHDECO research group, https://team.inria.fr/graphdeco
|
4 |
+
# All rights reserved.
|
5 |
+
#
|
6 |
+
# This software is free for non-commercial, research and evaluation use
|
7 |
+
# under the terms of the LICENSE.md file.
|
8 |
+
#
|
9 |
+
# For inquiries contact [email protected]
|
10 |
+
#
|
11 |
+
|
12 |
+
from scene.cameras import Camera
|
13 |
+
import numpy as np
|
14 |
+
from utils.general_utils import PILtoTorch
|
15 |
+
from utils.graphics_utils import fov2focal, getWorld2View2
|
16 |
+
import scipy
|
17 |
+
import matplotlib.pyplot as plt
|
18 |
+
from scipy.special import softmax
|
19 |
+
from typing import NamedTuple, List
|
20 |
+
|
21 |
+
WARNED = False
|
22 |
+
|
23 |
+
class CameraInfo(NamedTuple):
|
24 |
+
uid: int
|
25 |
+
R: np.array
|
26 |
+
T: np.array
|
27 |
+
FovY: np.array
|
28 |
+
FovX: np.array
|
29 |
+
image: np.array
|
30 |
+
image_path: str
|
31 |
+
image_name: str
|
32 |
+
width: int
|
33 |
+
height: int
|
34 |
+
|
35 |
+
|
36 |
+
def loadCam(args, id, cam_info, resolution_scale):
|
37 |
+
orig_w, orig_h = cam_info.image.size
|
38 |
+
|
39 |
+
if args.resolution in [1, 2, 4, 8]:
|
40 |
+
resolution = round(orig_w/(resolution_scale * args.resolution)), round(orig_h/(resolution_scale * args.resolution))
|
41 |
+
else: # should be a type that converts to float
|
42 |
+
if args.resolution == -1:
|
43 |
+
if orig_w > 1600:
|
44 |
+
global WARNED
|
45 |
+
if not WARNED:
|
46 |
+
print("[ INFO ] Encountered quite large input images (>1.6K pixels width), rescaling to 1.6K.\n "
|
47 |
+
"If this is not desired, please explicitly specify '--resolution/-r' as 1")
|
48 |
+
WARNED = True
|
49 |
+
global_down = orig_w / 1600
|
50 |
+
else:
|
51 |
+
global_down = 1
|
52 |
+
else:
|
53 |
+
global_down = orig_w / args.resolution
|
54 |
+
|
55 |
+
scale = float(global_down) * float(resolution_scale)
|
56 |
+
resolution = (int(orig_w / scale), int(orig_h / scale))
|
57 |
+
|
58 |
+
resized_image_rgb = PILtoTorch(cam_info.image, resolution)
|
59 |
+
|
60 |
+
gt_image = resized_image_rgb[:3, ...]
|
61 |
+
loaded_mask = None
|
62 |
+
|
63 |
+
if resized_image_rgb.shape[1] == 4:
|
64 |
+
loaded_mask = resized_image_rgb[3:4, ...]
|
65 |
+
|
66 |
+
return Camera(colmap_id=cam_info.uid, R=cam_info.R, T=cam_info.T,
|
67 |
+
FoVx=cam_info.FovX, FoVy=cam_info.FovY,
|
68 |
+
image=gt_image, gt_alpha_mask=loaded_mask,
|
69 |
+
image_name=cam_info.image_name, uid=id, data_device=args.data_device)
|
70 |
+
|
71 |
+
|
72 |
+
def cameraList_from_camInfos(cam_infos, resolution_scale, args):
|
73 |
+
camera_list = []
|
74 |
+
|
75 |
+
for id, c in enumerate(cam_infos):
|
76 |
+
camera_list.append(loadCam(args, id, c, resolution_scale))
|
77 |
+
|
78 |
+
return camera_list
|
79 |
+
|
80 |
+
|
81 |
+
def camera_to_JSON(id, camera : Camera):
|
82 |
+
Rt = np.zeros((4, 4))
|
83 |
+
Rt[:3, :3] = camera.R.transpose()
|
84 |
+
Rt[:3, 3] = camera.T
|
85 |
+
Rt[3, 3] = 1.0
|
86 |
+
|
87 |
+
W2C = np.linalg.inv(Rt)
|
88 |
+
pos = W2C[:3, 3]
|
89 |
+
rot = W2C[:3, :3]
|
90 |
+
serializable_array_2d = [x.tolist() for x in rot]
|
91 |
+
camera_entry = {
|
92 |
+
'id' : id,
|
93 |
+
'img_name' : camera.image_name,
|
94 |
+
'width' : camera.width,
|
95 |
+
'height' : camera.height,
|
96 |
+
'position': pos.tolist(),
|
97 |
+
'rotation': serializable_array_2d,
|
98 |
+
'fy' : fov2focal(camera.FovY, camera.height),
|
99 |
+
'fx' : fov2focal(camera.FovX, camera.width)
|
100 |
+
}
|
101 |
+
return camera_entry
|
102 |
+
|
103 |
+
|
104 |
+
def transform_poses_pca(poses):
|
105 |
+
"""Transforms poses so principal components lie on XYZ axes.
|
106 |
+
|
107 |
+
Args:
|
108 |
+
poses: a (N, 3, 4) array containing the cameras' camera to world transforms.
|
109 |
+
|
110 |
+
Returns:
|
111 |
+
A tuple (poses, transform), with the transformed poses and the applied
|
112 |
+
camera_to_world transforms.
|
113 |
+
"""
|
114 |
+
t = poses[:, :3, 3]
|
115 |
+
t_mean = t.mean(axis=0)
|
116 |
+
t = t - t_mean
|
117 |
+
|
118 |
+
eigval, eigvec = np.linalg.eig(t.T @ t)
|
119 |
+
# Sort eigenvectors in order of largest to smallest eigenvalue.
|
120 |
+
inds = np.argsort(eigval)[::-1]
|
121 |
+
eigvec = eigvec[:, inds]
|
122 |
+
rot = eigvec.T
|
123 |
+
if np.linalg.det(rot) < 0:
|
124 |
+
rot = np.diag(np.array([1, 1, -1])) @ rot
|
125 |
+
|
126 |
+
transform = np.concatenate([rot, rot @ -t_mean[:, None]], -1)
|
127 |
+
poses_recentered = unpad_poses(transform @ pad_poses(poses))
|
128 |
+
transform = np.concatenate([transform, np.eye(4)[3:]], axis=0)
|
129 |
+
|
130 |
+
# Flip coordinate system if z component of y-axis is negative
|
131 |
+
if poses_recentered.mean(axis=0)[2, 1] < 0:
|
132 |
+
poses_recentered = np.diag(np.array([1, -1, -1])) @ poses_recentered
|
133 |
+
transform = np.diag(np.array([1, -1, -1, 1])) @ transform
|
134 |
+
|
135 |
+
# Just make sure it's it in the [-1, 1]^3 cube
|
136 |
+
scale_factor = 1. / np.max(np.abs(poses_recentered[:, :3, 3]))
|
137 |
+
poses_recentered[:, :3, 3] *= scale_factor
|
138 |
+
# transform = np.diag(np.array([scale_factor] * 3 + [1])) @ transform
|
139 |
+
|
140 |
+
return poses_recentered, transform, scale_factor
|
141 |
+
|
142 |
+
def generate_interpolated_path(poses, n_interp, spline_degree=5,
                               smoothness=.03, rot_weight=.1):
    """Creates a smooth spline path between input keyframe camera poses.

    Spline is calculated with poses in format (position, lookat-point, up-point).

    Args:
      poses: (n, 3, 4) array of input pose keyframes.
      n_interp: returned path will have n_interp * (n - 1) total poses.
      spline_degree: polynomial degree of B-spline.
      smoothness: parameter for spline smoothing, 0 forces exact interpolation.
      rot_weight: relative weighting of rotation/translation in spline solve.

    Returns:
      Array of new camera poses with shape (n_interp * (n - 1), 3, 4).
    """

    def poses_to_points(poses, dist):
        """Converts from pose matrices to (position, lookat, up) format."""
        pos = poses[:, :3, -1]
        lookat = poses[:, :3, -1] - dist * poses[:, :3, 2]
        up = poses[:, :3, -1] + dist * poses[:, :3, 1]
        return np.stack([pos, lookat, up], 1)

    def points_to_poses(points):
        """Converts from (position, lookat, up) format to pose matrices."""
        return np.array([viewmatrix(p - l, u - p, p) for p, l, u in points])

    def interp(points, n, k, s):
        """Runs multidimensional B-spline interpolation on the input points."""
        sh = points.shape
        pts = np.reshape(points, (sh[0], -1))
        k = min(k, sh[0] - 1)
        tck, _ = scipy.interpolate.splprep(pts.T, k=k, s=s)
        u = np.linspace(0, 1, n, endpoint=False)
        new_points = np.array(scipy.interpolate.splev(u, tck))
        new_points = np.reshape(new_points.T, (n, sh[1], sh[2]))
        return new_points

    ### Additional operation
    # inter_poses = []
    # for pose in poses:
    #     tmp_pose = np.eye(4)
    #     tmp_pose[:3] = np.concatenate([pose.R.T, pose.T[:, None]], 1)
    #     tmp_pose = np.linalg.inv(tmp_pose)
    #     tmp_pose[:, 1:3] *= -1
    #     inter_poses.append(tmp_pose)
    # inter_poses = np.stack(inter_poses, 0)
    # poses, transform = transform_poses_pca(inter_poses)

    points = poses_to_points(poses, dist=rot_weight)
    new_points = interp(points,
                        n_interp * (points.shape[0] - 1),
                        k=spline_degree,
                        s=smoothness)
    return points_to_poses(new_points)

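# Example (added, not in the original file): a hypothetical call that expands a
# handful of keyframe camera-to-world poses (N, 3, 4) into a smooth render path
# with 30 interpolated poses per keyframe segment.
def _demo_interpolated_path(keyframe_poses):
    assert keyframe_poses.ndim == 3 and keyframe_poses.shape[1:] == (3, 4)
    path = generate_interpolated_path(keyframe_poses, n_interp=30,
                                      spline_degree=5, smoothness=0.03)
    return path  # shape (30 * (N - 1), 3, 4)
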
def viewmatrix(lookdir, up, position):
    """Construct lookat view matrix."""
    vec2 = normalize(lookdir)
    vec0 = normalize(np.cross(up, vec2))
    vec1 = normalize(np.cross(vec2, vec0))
    m = np.stack([vec0, vec1, vec2, position], axis=1)
    return m

def normalize(x):
    """Normalization helper function."""
    return x / np.linalg.norm(x)

def pad_poses(p):
    """Pad [..., 3, 4] pose matrices with a homogeneous bottom row [0,0,0,1]."""
    bottom = np.broadcast_to([0, 0, 0, 1.], p[..., :1, :4].shape)
    return np.concatenate([p[..., :3, :4], bottom], axis=-2)


def unpad_poses(p):
    """Remove the homogeneous bottom row from [..., 4, 4] pose matrices."""
    return p[..., :3, :4]

def invert_transform_poses_pca(poses_recentered, transform, scale_factor):
    poses_recentered[:, :3, 3] /= scale_factor
    transform_inv = np.linalg.inv(transform)
    poses_original = unpad_poses(transform_inv @ pad_poses(poses_recentered))
    return poses_original

def visualizer(camera_poses, colors, save_path="/mnt/data/1.png"):
    fig = plt.figure()
    ax = fig.add_subplot(111, projection="3d")

    for pose, color in zip(camera_poses, colors):
        rotation = pose[:3, :3]
        translation = pose[:3, 3]  # Corrected to use 3D translation component
        camera_positions = np.einsum(
            "...ij,...j->...i", np.linalg.inv(rotation), -translation
        )

        ax.scatter(
            camera_positions[0],
            camera_positions[1],
            camera_positions[2],
            c=color,
            marker="o",
        )

    ax.set_xlabel("X")
    ax.set_ylabel("Y")
    ax.set_zlabel("Z")
    ax.set_title("Camera Poses")

    plt.savefig(save_path)
    plt.close()

    return save_path


def focus_point_fn(poses: np.ndarray) -> np.ndarray:
    """Calculate nearest point to all focal axes in poses."""
    directions, origins = poses[:, :3, 2:3], poses[:, :3, 3:4]
    m = np.eye(3) - directions * np.transpose(directions, [0, 2, 1])
    mt_m = np.transpose(m, [0, 2, 1]) @ m
    focus_pt = np.linalg.inv(mt_m.mean(0)) @ (mt_m @ origins).mean(0)[:, 0]
    return focus_pt

def interp(x, xp, fp):
    # Flatten the input arrays
    x_flat = x.reshape(-1, x.shape[-1])
    xp_flat = xp.reshape(-1, xp.shape[-1])
    fp_flat = fp.reshape(-1, fp.shape[-1])

    # Perform interpolation for each set of flattened arrays
    ret_flat = np.array([np.interp(xf, xpf, fpf) for xf, xpf, fpf in zip(x_flat, xp_flat, fp_flat)])

    # Reshape the result to match the input shape
    ret = ret_flat.reshape(x.shape)
    return ret

def sorted_interp(x, xp, fp):
    # Identify the location in `xp` that corresponds to each `x`.
    # The final `True` index in `mask` is the start of the matching interval.
    mask = x[..., None, :] >= xp[..., :, None]

    def find_interval(x):
        # Grab the value where `mask` switches from True to False, and vice versa.
        # This approach takes advantage of the fact that `x` is sorted.
        x0 = np.max(np.where(mask, x[..., None], x[..., :1, None]), -2)
        x1 = np.min(np.where(~mask, x[..., None], x[..., -1:, None]), -2)
        return x0, x1

    fp0, fp1 = find_interval(fp)
    xp0, xp1 = find_interval(xp)
    with np.errstate(divide='ignore', invalid='ignore'):
        offset = np.clip(np.nan_to_num((x - xp0) / (xp1 - xp0), nan=0.0), 0, 1)
    ret = fp0 + offset * (fp1 - fp0)
    return ret

def integrate_weights(w):
    """Compute the cumulative sum of w, assuming all weight vectors sum to 1.

    The output's size on the last dimension is one greater than that of the input,
    because we're computing the integral corresponding to the endpoints of a step
    function, not the integral of the interior/bin values.

    Args:
      w: Tensor, which will be integrated along the last axis. This is assumed to
        sum to 1 along the last axis, and this function will (silently) break if
        that is not the case.

    Returns:
      cw0: Tensor, the integral of w, where cw0[..., 0] = 0 and cw0[..., -1] = 1
    """
    cw = np.minimum(1, np.cumsum(w[..., :-1], axis=-1))
    shape = cw.shape[:-1] + (1,)
    # Ensure that the CDF starts with exactly 0 and ends with exactly 1.
    cw0 = np.concatenate([np.zeros(shape), cw, np.ones(shape)], axis=-1)
    return cw0

def invert_cdf(u, t, w_logits, use_gpu_resampling=False):
    """Invert the CDF defined by (t, w) at the points specified by u in [0, 1)."""
    # Compute the PDF and CDF for each weight vector.
    w = softmax(w_logits, axis=-1)
    cw = integrate_weights(w)

    # Interpolate into the inverse CDF.
    interp_fn = interp if use_gpu_resampling else sorted_interp  # Assuming these are defined using NumPy
    t_new = interp_fn(u, cw, t)
    return t_new

def sample(rng,
           t,
           w_logits,
           num_samples,
           single_jitter=False,
           deterministic_center=False,
           use_gpu_resampling=False):
    """Piecewise-Constant PDF sampling from a step function.

    Args:
      rng: random number generator (or None for `linspace` sampling).
      t: [..., num_bins + 1], bin endpoint coordinates (must be sorted)
      w_logits: [..., num_bins], logits corresponding to bin weights
      num_samples: int, the number of samples.
      single_jitter: bool, if True, jitter every sample along each ray by the same
        amount in the inverse CDF. Otherwise, jitter each sample independently.
      deterministic_center: bool, if False, when `rng` is None return samples that
        linspace the entire PDF. If True, skip the front and back of the linspace
        so that the centers of each PDF interval are returned.
      use_gpu_resampling: bool, If True this resamples the rays based on a
        "gather" instruction, which is fast on GPUs but slow on TPUs. If False,
        this resamples the rays based on brute-force searches, which is fast on
        TPUs, but slow on GPUs.

    Returns:
      t_samples: np.ndarray(float32), [batch_size, num_samples].
    """
    eps = np.finfo(np.float32).eps

    # Draw uniform samples.
    if rng is None:
        # Match the behavior of jax.random.uniform() by spanning [0, 1-eps].
        if deterministic_center:
            pad = 1 / (2 * num_samples)
            u = np.linspace(pad, 1. - pad - eps, num_samples)
        else:
            u = np.linspace(0, 1. - eps, num_samples)
        u = np.broadcast_to(u, t.shape[:-1] + (num_samples,))
    else:
        # `u` is in [0, 1) --- it can be zero, but it can never be 1.
        u_max = eps + (1 - eps) / num_samples
        max_jitter = (1 - u_max) / (num_samples - 1) - eps
        d = 1 if single_jitter else num_samples
        u = (
            np.linspace(0, 1 - u_max, num_samples) +
            rng.uniform(size=t.shape[:-1] + (d,), high=max_jitter))

    return invert_cdf(u, t, w_logits, use_gpu_resampling=use_gpu_resampling)

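# Example (added, not part of the original file): sample() draws points from the
# piecewise-constant PDF defined by bin edges `t` and per-bin logits `w_logits`.
# With rng=None it walks the inverse CDF deterministically, which is how
# generate_ellipse_path_from_poses below equalizes speed along the path.
# The numbers here are hypothetical.
def _demo_pdf_sampling():
    t = np.linspace(0.0, 1.0, 6)                         # 5 bins
    w_logits = np.log(np.array([1., 1., 4., 1., 1.]))    # favor the middle bin
    samples = sample(None, t, w_logits, num_samples=16,
                     deterministic_center=True)
    return samples  # shape (16,), denser around the heavy bin
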
def generate_ellipse_path_from_poses(poses: np.ndarray,
                                     n_frames: int = 120,
                                     const_speed: bool = True,
                                     z_variation: float = 0.,
                                     z_phase: float = 0.) -> np.ndarray:
    """Generate an elliptical render path based on the given poses."""
    # Calculate the focal point for the path (cameras point toward this).
    center = focus_point_fn(poses)
    # Path height sits at z=0 (in middle of zero-mean capture pattern).
    offset = np.array([center[0], center[1], 0])

    # Calculate scaling for ellipse axes based on input camera positions.
    sc = np.percentile(np.abs(poses[:, :3, 3] - offset), 100, axis=0)
    # Use ellipse that is symmetric about the focal point in xy.
    low = -sc + offset
    high = sc + offset
    # Optional height variation need not be symmetric
    z_low = np.percentile((poses[:, :3, 3]), 0, axis=0)
    z_high = np.percentile((poses[:, :3, 3]), 100, axis=0)

    def get_positions(theta):
        # Interpolate between bounds with trig functions to get ellipse in x-y.
        # Optionally also interpolate in z to change camera height along path.
        return np.stack([
            low[0] + (high - low)[0] * (np.cos(theta) * .5 + .5),
            low[1] + (high - low)[1] * (np.sin(theta) * .5 + .5),
            z_variation * (z_low[2] + (z_high - z_low)[2] *
                           (np.cos(theta + 2 * np.pi * z_phase) * .5 + .5)),
        ], -1)

    theta = np.linspace(0, 2. * np.pi, n_frames + 1, endpoint=True)
    positions = get_positions(theta)

    if const_speed:
        # Resample theta angles so that the velocity is closer to constant.
        lengths = np.linalg.norm(positions[1:] - positions[:-1], axis=-1)
        theta = sample(None, theta, np.log(lengths), n_frames + 1)
        positions = get_positions(theta)

    # Throw away duplicated last position.
    positions = positions[:-1]

    # Set path's up vector to axis closest to average of input pose up vectors.
    avg_up = poses[:, :3, 1].mean(0)
    avg_up = avg_up / np.linalg.norm(avg_up)
    ind_up = np.argmax(np.abs(avg_up))
    up = np.eye(3)[ind_up] * np.sign(avg_up[ind_up])

    return np.stack([viewmatrix(p - center, up, p) for p in positions])

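# Example (added, not part of the original file): generate_ellipse_path_from_poses
# expects PCA-normalized poses, so it is normally wrapped as below. This sketch
# omits the OpenGL/COLMAP axis flip that the wrappers apply before and after.
def _demo_ellipse_path(c2w_poses, n_frames=120):
    poses_pca, transform, scale = transform_poses_pca(c2w_poses)
    render_poses = generate_ellipse_path_from_poses(poses_pca, n_frames,
                                                    const_speed=False,
                                                    z_variation=0.0)
    return invert_transform_poses_pca(render_poses, transform, scale)
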
def generate_ellipse_path_from_camera_infos(
    cam_infos,
    n_frames,
    const_speed=False,
    z_variation=0.,
    z_phase=0.
):
    print(f'Generating ellipse path from {len(cam_infos)} camera infos ...')
    poses = np.array([np.linalg.inv(getWorld2View2(cam_info.R, cam_info.T))[:3, :4] for cam_info in cam_infos])
    poses[:, :, 1:3] *= -1
    poses, transform, scale_factor = transform_poses_pca(poses)
    render_poses = generate_ellipse_path_from_poses(poses, n_frames, const_speed, z_variation, z_phase)
    render_poses = invert_transform_poses_pca(render_poses, transform, scale_factor)
    render_poses[:, :, 1:3] *= -1
    ret_cam_infos = []
    for uid, pose in enumerate(render_poses):
        R = pose[:3, :3]
        c2w = np.eye(4)
        c2w[:3, :4] = pose
        T = np.linalg.inv(c2w)[:3, 3]
        cam_info = CameraInfo(
            uid = uid,
            R = R,
            T = T,
            FovY = cam_infos[0].FovY,
            FovX = cam_infos[0].FovX,
            # image = np.zeros_like(cam_infos[0].image),
            image = cam_infos[0].image,
            image_path = '',
            image_name = f'{uid:05d}.png',
            width = cam_infos[0].width,
            height = cam_infos[0].height
        )
        ret_cam_infos.append(cam_info)
    return ret_cam_infos


def generate_ellipse_path(
    org_pose,
    n_interp,
    const_speed=False,
    z_variation=0.,
    z_phase=0.
):
    print(f'Generating ellipse path from {len(org_pose)} camera infos ...')
    poses = np.array([np.linalg.inv(p)[:3, :4] for p in org_pose])  # w2c >>> c2w
    poses[:, :, 1:3] *= -1
    poses, transform, scale_factor = transform_poses_pca(poses)
    render_poses = generate_ellipse_path_from_poses(poses, n_interp, const_speed, z_variation, z_phase)
    render_poses = invert_transform_poses_pca(render_poses, transform, scale_factor)
    render_poses[:, :, 1:3] *= -1  # c2w
    return render_poses

utils/dust3r_utils.py
ADDED
@@ -0,0 +1,432 @@
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
import cv2
|
4 |
+
import numpy as np
|
5 |
+
import PIL.Image
|
6 |
+
from PIL.ImageOps import exif_transpose
|
7 |
+
from plyfile import PlyData, PlyElement
|
8 |
+
import torchvision.transforms as tvf
|
9 |
+
import roma
|
10 |
+
import dust3r.cloud_opt.init_im_poses as init_fun
|
11 |
+
from dust3r.cloud_opt.base_opt import global_alignment_loop
|
12 |
+
from dust3r.utils.geometry import geotrf, inv, depthmap_to_absolute_camera_coordinates
|
13 |
+
from dust3r.cloud_opt.commons import edge_str
|
14 |
+
from dust3r.utils.image import _resize_pil_image, imread_cv2
|
15 |
+
import dust3r.datasets.utils.cropping as cropping
|
16 |
+
import torch.nn.functional as F
|
17 |
+
|
18 |
+
def get_known_poses(scene):
|
19 |
+
if scene.has_im_poses:
|
20 |
+
known_poses_msk = torch.tensor([not (p.requires_grad) for p in scene.im_poses])
|
21 |
+
known_poses = scene.get_im_poses()
|
22 |
+
return known_poses_msk.sum(), known_poses_msk, known_poses
|
23 |
+
else:
|
24 |
+
return 0, None, None
|
25 |
+
|
26 |
+
def init_from_pts3d(scene, pts3d, im_focals, im_poses):
|
27 |
+
# init poses
|
28 |
+
nkp, known_poses_msk, known_poses = get_known_poses(scene)
|
29 |
+
if nkp == 1:
|
30 |
+
raise NotImplementedError("Would be simpler to just align everything afterwards on the single known pose")
|
31 |
+
elif nkp > 1:
|
32 |
+
# global rigid SE3 alignment
|
33 |
+
s, R, T = init_fun.align_multiple_poses(im_poses[known_poses_msk], known_poses[known_poses_msk])
|
34 |
+
trf = init_fun.sRT_to_4x4(s, R, T, device=known_poses.device)
|
35 |
+
|
36 |
+
# rotate everything
|
37 |
+
im_poses = trf @ im_poses
|
38 |
+
im_poses[:, :3, :3] /= s # undo scaling on the rotation part
|
39 |
+
for img_pts3d in pts3d:
|
40 |
+
img_pts3d[:] = geotrf(trf, img_pts3d)
|
41 |
+
|
42 |
+
# set all pairwise poses
|
43 |
+
for e, (i, j) in enumerate(scene.edges):
|
44 |
+
i_j = edge_str(i, j)
|
45 |
+
# compute transform that goes from cam to world
|
46 |
+
s, R, T = init_fun.rigid_points_registration(scene.pred_i[i_j], pts3d[i], conf=scene.conf_i[i_j])
|
47 |
+
scene._set_pose(scene.pw_poses, e, R, T, scale=s)
|
48 |
+
|
49 |
+
# take into account the scale normalization
|
50 |
+
s_factor = scene.get_pw_norm_scale_factor()
|
51 |
+
im_poses[:, :3, 3] *= s_factor # apply downscaling factor
|
52 |
+
for img_pts3d in pts3d:
|
53 |
+
img_pts3d *= s_factor
|
54 |
+
|
55 |
+
# init all image poses
|
56 |
+
if scene.has_im_poses:
|
57 |
+
for i in range(scene.n_imgs):
|
58 |
+
cam2world = im_poses[i]
|
59 |
+
depth = geotrf(inv(cam2world), pts3d[i])[..., 2]
|
60 |
+
scene._set_depthmap(i, depth)
|
61 |
+
scene._set_pose(scene.im_poses, i, cam2world)
|
62 |
+
if im_focals[i] is not None:
|
63 |
+
scene._set_focal(i, im_focals[i])
|
64 |
+
|
65 |
+
if scene.verbose:
|
66 |
+
print(' init loss =', float(scene()))
|
67 |
+
|
68 |
+
@torch.no_grad()
|
69 |
+
def init_minimum_spanning_tree(scene, focal_avg=False, known_focal=None, **kw):
|
70 |
+
""" Init all camera poses (image-wise and pairwise poses) given
|
71 |
+
an initial set of pairwise estimations.
|
72 |
+
"""
|
73 |
+
device = scene.device
|
74 |
+
pts3d, _, im_focals, im_poses = init_fun.minimum_spanning_tree(scene.imshapes, scene.edges,
|
75 |
+
scene.pred_i, scene.pred_j, scene.conf_i, scene.conf_j, scene.im_conf, scene.min_conf_thr,
|
76 |
+
device, has_im_poses=scene.has_im_poses, verbose=scene.verbose,
|
77 |
+
**kw)
|
78 |
+
|
79 |
+
if known_focal is not None:
|
80 |
+
repeat_focal = np.repeat(known_focal, len(im_focals))
|
81 |
+
for i in range(len(im_focals)):
|
82 |
+
im_focals[i] = known_focal
|
83 |
+
scene.preset_focal(known_focals=repeat_focal)
|
84 |
+
elif focal_avg:
|
85 |
+
im_focals_avg = np.array(im_focals).mean()
|
86 |
+
for i in range(len(im_focals)):
|
87 |
+
im_focals[i] = im_focals_avg
|
88 |
+
repeat_focal = np.array(im_focals)#.cpu().numpy()
|
89 |
+
scene.preset_focal(known_focals=repeat_focal)
|
90 |
+
|
91 |
+
return init_from_pts3d(scene, pts3d, im_focals, im_poses)
|
92 |
+
|
93 |
+
@torch.cuda.amp.autocast(enabled=False)
|
94 |
+
def compute_global_alignment(scene, init=None, niter_PnP=10, focal_avg=False, known_focal=None, **kw):
|
95 |
+
if init is None:
|
96 |
+
pass
|
97 |
+
elif init == 'msp' or init == 'mst':
|
98 |
+
init_minimum_spanning_tree(scene, niter_PnP=niter_PnP, focal_avg=focal_avg, known_focal=known_focal)
|
99 |
+
elif init == 'known_poses':
|
100 |
+
init_fun.init_from_known_poses(scene, min_conf_thr=scene.min_conf_thr,
|
101 |
+
niter_PnP=niter_PnP)
|
102 |
+
else:
|
103 |
+
raise ValueError(f'bad value for {init=}')
|
104 |
+
|
105 |
+
return global_alignment_loop(scene, **kw)
|
106 |
+
|
107 |
+
|
108 |
+
|
109 |
+
def load_images(folder_or_list, size, square_ok=False):
|
110 |
+
""" open and convert all images in a list or folder to proper input format for DUSt3R
|
111 |
+
"""
|
112 |
+
if isinstance(folder_or_list, str):
|
113 |
+
print(f'>> Loading images from {folder_or_list}')
|
114 |
+
root, folder_content = folder_or_list, sorted(os.listdir(folder_or_list))
|
115 |
+
|
116 |
+
elif isinstance(folder_or_list, list):
|
117 |
+
print(f'>> Loading a list of {len(folder_or_list)} images')
|
118 |
+
root, folder_content = '', folder_or_list
|
119 |
+
|
120 |
+
else:
|
121 |
+
raise ValueError(f'bad {folder_or_list=} ({type(folder_or_list)})')
|
122 |
+
|
123 |
+
imgs = []
|
124 |
+
for path in folder_content:
|
125 |
+
if not path.endswith(('.jpg', '.jpeg', '.png', '.JPG')):
|
126 |
+
continue
|
127 |
+
img = exif_transpose(PIL.Image.open(os.path.join(root, path))).convert('RGB')
|
128 |
+
W1, H1 = img.size
|
129 |
+
if size == 224:
|
130 |
+
# resize short side to 224 (then crop)
|
131 |
+
img = _resize_pil_image(img, round(size * max(W1/H1, H1/W1)))
|
132 |
+
else:
|
133 |
+
# resize long side to 512
|
134 |
+
img = _resize_pil_image(img, size)
|
135 |
+
W, H = img.size
|
136 |
+
W2 = W//16*16
|
137 |
+
H2 = H//16*16
|
138 |
+
img = np.array(img)
|
139 |
+
img = cv2.resize(img, (W2,H2), interpolation=cv2.INTER_LINEAR)
|
140 |
+
img = PIL.Image.fromarray(img)
|
141 |
+
|
142 |
+
print(f' - adding {path} with resolution {W1}x{H1} --> {W2}x{H2}')
|
143 |
+
ImgNorm = tvf.Compose([tvf.ToTensor(), tvf.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
|
144 |
+
imgs.append(dict(img=ImgNorm(img)[None], true_shape=np.int32(
|
145 |
+
[img.size[::-1]]), idx=len(imgs), instance=str(len(imgs))))
|
146 |
+
|
147 |
+
assert imgs, 'no images found at ' + root
|
148 |
+
print(f' (Found {len(imgs)} images)')
|
149 |
+
return imgs, (W1,H1)
|
150 |
+
|
151 |
+
|
152 |
+
def load_cam_mvsnet(file, interval_scale=1):
|
153 |
+
""" read camera txt file """
|
154 |
+
cam = np.zeros((2, 4, 4))
|
155 |
+
words = file.read().split()
|
156 |
+
# read extrinsic
|
157 |
+
for i in range(0, 4):
|
158 |
+
for j in range(0, 4):
|
159 |
+
extrinsic_index = 4 * i + j + 1
|
160 |
+
cam[0][i][j] = words[extrinsic_index]
|
161 |
+
|
162 |
+
# read intrinsic
|
163 |
+
for i in range(0, 3):
|
164 |
+
for j in range(0, 3):
|
165 |
+
intrinsic_index = 3 * i + j + 18
|
166 |
+
cam[1][i][j] = words[intrinsic_index]
|
167 |
+
|
168 |
+
if len(words) == 29:
|
169 |
+
cam[1][3][0] = words[27]
|
170 |
+
cam[1][3][1] = float(words[28]) * interval_scale
|
171 |
+
cam[1][3][2] = 192
|
172 |
+
cam[1][3][3] = cam[1][3][0] + cam[1][3][1] * cam[1][3][2]
|
173 |
+
elif len(words) == 30:
|
174 |
+
cam[1][3][0] = words[27]
|
175 |
+
cam[1][3][1] = float(words[28]) * interval_scale
|
176 |
+
cam[1][3][2] = words[29]
|
177 |
+
cam[1][3][3] = cam[1][3][0] + cam[1][3][1] * cam[1][3][2]
|
178 |
+
elif len(words) == 31:
|
179 |
+
cam[1][3][0] = words[27]
|
180 |
+
cam[1][3][1] = float(words[28]) * interval_scale
|
181 |
+
cam[1][3][2] = words[29]
|
182 |
+
cam[1][3][3] = words[30]
|
183 |
+
else:
|
184 |
+
cam[1][3][0] = 0
|
185 |
+
cam[1][3][1] = 0
|
186 |
+
cam[1][3][2] = 0
|
187 |
+
cam[1][3][3] = 0
|
188 |
+
|
189 |
+
|
190 |
+
extrinsic = cam[0].astype(np.float32)
|
191 |
+
intrinsic = cam[1].astype(np.float32)
|
192 |
+
|
193 |
+
return intrinsic, extrinsic
|
194 |
+
|
195 |
+
|
196 |
+
def _crop_resize_if_necessary(image, depthmap, intrinsics, resolution, rng=None, info=None):
|
197 |
+
""" This function:
|
198 |
+
- first downsizes the image with LANCZOS interpolation,
|
199 |
+
which is better than bilinear interpolation for downscaling.
|
200 |
+
"""
|
201 |
+
if not isinstance(image, PIL.Image.Image):
|
202 |
+
image = PIL.Image.fromarray(image)
|
203 |
+
|
204 |
+
# downscale with lanczos interpolation so that image.size == resolution
|
205 |
+
# cropping centered on the principal point
|
206 |
+
W, H = image.size
|
207 |
+
cx, cy = intrinsics[:2, 2].round().astype(int)
|
208 |
+
|
209 |
+
# calculate min distance to margin
|
210 |
+
min_margin_x = min(cx, W-cx)
|
211 |
+
min_margin_y = min(cy, H-cy)
|
212 |
+
assert min_margin_x > W/5, f'Bad principal point in view={info}'
|
213 |
+
assert min_margin_y > H/5, f'Bad principal point in view={info}'
|
214 |
+
|
215 |
+
## Center crop
|
216 |
+
# Crop on the principal point, make it always centered
|
217 |
+
# the new window will be a rectangle of size (2*min_margin_x, 2*min_margin_y) centered on (cx,cy)
|
218 |
+
l, t = cx - min_margin_x, cy - min_margin_y
|
219 |
+
r, b = cx + min_margin_x, cy + min_margin_y
|
220 |
+
crop_bbox = (l, t, r, b)
|
221 |
+
image, depthmap, intrinsics = cropping.crop_image_depthmap(image, depthmap, intrinsics, crop_bbox)
|
222 |
+
|
223 |
+
# transpose the resolution if necessary
|
224 |
+
W, H = image.size # new size
|
225 |
+
assert resolution[0] >= resolution[1]
|
226 |
+
if H > 1.1*W:
|
227 |
+
# image is portrait mode
|
228 |
+
resolution = resolution[::-1]
|
229 |
+
elif 0.9 < H/W < 1.1 and resolution[0] != resolution[1]:
|
230 |
+
# image is square, so we chose (portrait, landscape) randomly
|
231 |
+
if rng.integers(2):
|
232 |
+
resolution = resolution[::-1]
|
233 |
+
|
234 |
+
# high-quality Lanczos down-scaling
|
235 |
+
target_resolution = np.array(resolution)
|
236 |
+
|
237 |
+
## Rescale with max factor, so one of width or height might be larger than target_resolution
|
238 |
+
image, depthmap, intrinsics = cropping.rescale_image_depthmap(image, depthmap, intrinsics, target_resolution)
|
239 |
+
|
240 |
+
# actual cropping (if necessary) with bilinear interpolation
|
241 |
+
intrinsics2 = cropping.camera_matrix_of_crop(intrinsics, image.size, resolution, offset_factor=0.5)
|
242 |
+
crop_bbox = cropping.bbox_from_intrinsics_in_out(intrinsics, intrinsics2, resolution)
|
243 |
+
image, depthmap, intrinsics2 = cropping.crop_image_depthmap(image, depthmap, intrinsics, crop_bbox)
|
244 |
+
|
245 |
+
return image, depthmap, intrinsics2
|
246 |
+
|
247 |
+
|
248 |
+
def load_images_dtu(folder_or_list, size, scene_folder):
|
249 |
+
"""
|
250 |
+
Preprocessing DTU requires depth, camera param and mask.
|
251 |
+
We follow Splatt3R to compute valid_mask.
|
252 |
+
"""
|
253 |
+
if isinstance(folder_or_list, str):
|
254 |
+
print(f'>> Loading images from {folder_or_list}')
|
255 |
+
root, folder_content = folder_or_list, sorted(os.listdir(folder_or_list))
|
256 |
+
|
257 |
+
elif isinstance(folder_or_list, list):
|
258 |
+
print(f'>> Loading a list of {len(folder_or_list)} images')
|
259 |
+
root = os.path.dirname(folder_or_list[0]) if folder_or_list else ''
|
260 |
+
folder_content = [os.path.basename(p) for p in folder_or_list]
|
261 |
+
|
262 |
+
else:
|
263 |
+
raise ValueError(f'bad {folder_or_list=} ({type(folder_or_list)})')
|
264 |
+
|
265 |
+
depth_root = os.path.join(scene_folder, 'depths')
|
266 |
+
mask_root = os.path.join(scene_folder, 'binary_masks')
|
267 |
+
cam_root = os.path.join(scene_folder, 'cams')
|
268 |
+
|
269 |
+
imgs = []
|
270 |
+
for path in folder_content:
|
271 |
+
if not path.endswith(('.jpg', '.jpeg', '.png', '.JPG')):
|
272 |
+
continue
|
273 |
+
|
274 |
+
impath = os.path.join(root, path)
|
275 |
+
depthpath = os.path.join(depth_root, path.replace('.jpg', '.npy'))
|
276 |
+
campath = os.path.join(cam_root, path.replace('.jpg', '_cam.txt'))
|
277 |
+
maskpath = os.path.join(mask_root, path.replace('.jpg', '.png'))
|
278 |
+
|
279 |
+
rgb_image = imread_cv2(impath)
|
280 |
+
H1, W1 = rgb_image.shape[:2]
|
281 |
+
depthmap = np.load(depthpath)
|
282 |
+
depthmap = np.nan_to_num(depthmap.astype(np.float32), 0.0)
|
283 |
+
|
284 |
+
mask = imread_cv2(maskpath, cv2.IMREAD_UNCHANGED)/255.0
|
285 |
+
mask = mask.astype(np.float32)
|
286 |
+
|
287 |
+
mask[mask>0.5] = 1.0
|
288 |
+
mask[mask<0.5] = 0.0
|
289 |
+
|
290 |
+
mask = cv2.resize(mask, (depthmap.shape[1], depthmap.shape[0]), interpolation=cv2.INTER_NEAREST)
|
291 |
+
kernel = np.ones((10, 10), np.uint8) # Define the erosion kernel
|
292 |
+
mask = cv2.erode(mask, kernel, iterations=1)
|
293 |
+
depthmap = depthmap * mask
|
294 |
+
|
295 |
+
cur_intrinsics, camera_pose = load_cam_mvsnet(open(campath, 'r'))
|
296 |
+
intrinsics = cur_intrinsics[:3, :3]
|
297 |
+
camera_pose = np.linalg.inv(camera_pose)
|
298 |
+
|
299 |
+
new_size = tuple(int(round(x*size/max(W1, H1))) for x in (W1, H1))
|
300 |
+
W, H = new_size
|
301 |
+
W2 = W//16*16
|
302 |
+
H2 = H//16*16
|
303 |
+
|
304 |
+
rgb_image, depthmap, intrinsics = _crop_resize_if_necessary(
|
305 |
+
rgb_image, depthmap, intrinsics, (W2, H2), info=impath)
|
306 |
+
|
307 |
+
print(f' - adding {path} with resolution {W1}x{H1} --> {W2}x{H2}')
|
308 |
+
ImgNorm = tvf.Compose([tvf.ToTensor(), tvf.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
|
309 |
+
|
310 |
+
img = dict(
|
311 |
+
img=ImgNorm(rgb_image)[None],
|
312 |
+
true_shape=np.int32([rgb_image.size[::-1]]),
|
313 |
+
idx=len(imgs),
|
314 |
+
instance=str(len(imgs)),
|
315 |
+
depthmap=depthmap,
|
316 |
+
camera_pose=camera_pose,
|
317 |
+
camera_intrinsics=intrinsics
|
318 |
+
)
|
319 |
+
|
320 |
+
pts3d, valid_mask = depthmap_to_absolute_camera_coordinates(**img)
|
321 |
+
img['pts3d'] = pts3d
|
322 |
+
img['valid_mask'] = valid_mask & np.isfinite(pts3d).all(axis=-1)
|
323 |
+
|
324 |
+
imgs.append(img)
|
325 |
+
|
326 |
+
|
327 |
+
assert imgs, 'no images found at ' + root
|
328 |
+
print(f' (Found {len(imgs)} images)')
|
329 |
+
return imgs, (W1,H1)
|
330 |
+
|
331 |
+
|
332 |
+
def storePly(path, xyz, rgb, feat=None):
|
333 |
+
# Define the dtype for the structured array
|
334 |
+
dtype = [('x', 'f4'), ('y', 'f4'), ('z', 'f4'),
|
335 |
+
('nx', 'f4'), ('ny', 'f4'), ('nz', 'f4'),
|
336 |
+
('red', 'u1'), ('green', 'u1'), ('blue', 'u1')]
|
337 |
+
|
338 |
+
if feat is not None:
|
339 |
+
for i in range(feat.shape[1]):
|
340 |
+
dtype.append((f'feat_{i}', 'f4'))
|
341 |
+
|
342 |
+
normals = np.zeros_like(xyz)
|
343 |
+
|
344 |
+
elements = np.empty(xyz.shape[0], dtype=dtype)
|
345 |
+
attributes = np.concatenate((xyz, normals, rgb), axis=1)
|
346 |
+
|
347 |
+
if feat is not None:
|
348 |
+
attributes = np.concatenate((attributes, feat), axis=1)
|
349 |
+
|
350 |
+
elements[:] = list(map(tuple, attributes))
|
351 |
+
|
352 |
+
# Create the PlyData object and write to file
|
353 |
+
vertex_element = PlyElement.describe(elements, 'vertex')
|
354 |
+
ply_data = PlyData([vertex_element])
|
355 |
+
ply_data.write(path)
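
# Example (added for illustration, not part of the original file): storePly
# writes an (N, 3) point cloud with uint8 colors, plus optional per-point
# feature channels, to a PLY file. The shapes and path below are hypothetical.
def _demo_store_ply(path="points3d.ply"):
    xyz = np.random.rand(1000, 3).astype(np.float32)
    rgb = (np.random.rand(1000, 3) * 255).astype(np.uint8)
    feat = np.random.rand(1000, 16).astype(np.float32)  # e.g. PCA-reduced VFM features
    storePly(path, xyz, rgb, feat=feat)
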
|
356 |
+
|
357 |
+
def R_to_quaternion(R):
|
358 |
+
"""
|
359 |
+
Convert a rotation matrix to a quaternion.
|
360 |
+
|
361 |
+
Parameters:
|
362 |
+
- R: A 3x3 numpy array representing a rotation matrix.
|
363 |
+
|
364 |
+
Returns:
|
365 |
+
- A numpy array representing the quaternion [w, x, y, z].
|
366 |
+
"""
|
367 |
+
m00, m01, m02 = R[0, 0], R[0, 1], R[0, 2]
|
368 |
+
m10, m11, m12 = R[1, 0], R[1, 1], R[1, 2]
|
369 |
+
m20, m21, m22 = R[2, 0], R[2, 1], R[2, 2]
|
370 |
+
trace = m00 + m11 + m22
|
371 |
+
|
372 |
+
if trace > 0:
|
373 |
+
s = 0.5 / np.sqrt(trace + 1.0)
|
374 |
+
w = 0.25 / s
|
375 |
+
x = (m21 - m12) * s
|
376 |
+
y = (m02 - m20) * s
|
377 |
+
z = (m10 - m01) * s
|
378 |
+
elif (m00 > m11) and (m00 > m22):
|
379 |
+
s = np.sqrt(1.0 + m00 - m11 - m22) * 2
|
380 |
+
w = (m21 - m12) / s
|
381 |
+
x = 0.25 * s
|
382 |
+
y = (m01 + m10) / s
|
383 |
+
z = (m02 + m20) / s
|
384 |
+
elif m11 > m22:
|
385 |
+
s = np.sqrt(1.0 + m11 - m00 - m22) * 2
|
386 |
+
w = (m02 - m20) / s
|
387 |
+
x = (m01 + m10) / s
|
388 |
+
y = 0.25 * s
|
389 |
+
z = (m12 + m21) / s
|
390 |
+
else:
|
391 |
+
s = np.sqrt(1.0 + m22 - m00 - m11) * 2
|
392 |
+
w = (m10 - m01) / s
|
393 |
+
x = (m02 + m20) / s
|
394 |
+
y = (m12 + m21) / s
|
395 |
+
z = 0.25 * s
|
396 |
+
|
397 |
+
return np.array([w, x, y, z])
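
# Example (added, not part of the original file): a quick sanity check for
# R_to_quaternion using a 90-degree rotation about Z, whose quaternion
# [w, x, y, z] = [cos(45deg), 0, 0, sin(45deg)] is known in closed form.
def _check_R_to_quaternion():
    Rz90 = np.array([[0., -1., 0.],
                     [1.,  0., 0.],
                     [0.,  0., 1.]])
    q = R_to_quaternion(Rz90)
    expected = np.array([np.cos(np.pi / 4), 0., 0., np.sin(np.pi / 4)])
    return np.allclose(q, expected, atol=1e-6)
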
|
398 |
+
|
399 |
+
def save_colmap_cameras(ori_size, intrinsics, camera_file):
|
400 |
+
with open(camera_file, 'w') as f:
|
401 |
+
for i, K in enumerate(intrinsics, 1): # Starting index at 1
|
402 |
+
width, height = ori_size
|
403 |
+
scale_factor_x = width/2 / K[0, 2]
|
404 |
+
scale_factor_y = height/2 / K[1, 2]
|
405 |
+
# assert scale_factor_x==scale_factor_y, "scale factor is not same for x and y"
|
406 |
+
# print(f'scale factor is not same for x {scale_factor_x} and y {scale_factor_y}')
|
407 |
+
f.write(f"{i} PINHOLE {width} {height} {K[0, 0]*scale_factor_x} {K[1, 1]*scale_factor_x} {width/2} {height/2}\n") # scale focal
|
408 |
+
# f.write(f"{i} PINHOLE {width} {height} {K[0, 0]} {K[1, 1]} {K[0, 2]} {K[1, 2]}\n")
|
409 |
+
|
410 |
+
def save_colmap_images(poses, images_file, train_img_list):
|
411 |
+
with open(images_file, 'w') as f:
|
412 |
+
for i, pose in enumerate(poses, 1): # Starting index at 1
|
413 |
+
# breakpoint()
|
414 |
+
pose = np.linalg.inv(pose)
|
415 |
+
R = pose[:3, :3]
|
416 |
+
t = pose[:3, 3]
|
417 |
+
q = R_to_quaternion(R) # Convert rotation matrix to quaternion
|
418 |
+
f.write(f"{i} {q[0]} {q[1]} {q[2]} {q[3]} {t[0]} {t[1]} {t[2]} {i} {os.path.basename(train_img_list[i-1])}\n")
|
419 |
+
f.write(f"\n")
|
420 |
+
|
421 |
+
|
422 |
+
def round_python3(number):
|
423 |
+
rounded = round(number)
|
424 |
+
if abs(number - rounded) == 0.5:
|
425 |
+
return 2.0 * round(number / 2.0)
|
426 |
+
return rounded
|
427 |
+
|
428 |
+
|
429 |
+
def rigid_points_registration(pts1, pts2, conf=None):
|
430 |
+
R, T, s = roma.rigid_points_registration(
|
431 |
+
pts1.reshape(-1, 3), pts2.reshape(-1, 3), weights=conf, compute_scaling=True)
|
432 |
+
return s, R, T # return un-scaled (R, T)
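
# Example (added, not part of the original file): the wrapper above estimates a
# similarity transform (scale s, rotation R, translation T) mapping pts1 onto
# pts2 via RoMa. A hypothetical self-test with a known scale and translation
# (identity rotation):
def _check_rigid_registration():
    pts1 = torch.randn(500, 3)
    s_gt, t_gt = 2.0, torch.tensor([0.5, -1.0, 3.0])
    pts2 = s_gt * pts1 + t_gt
    s, R, T = rigid_points_registration(pts1, pts2)
    return abs(float(s) - s_gt) < 1e-3 and torch.allclose(T, t_gt, atol=1e-3)
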
|
utils/feat_utils.py
ADDED
@@ -0,0 +1,827 @@
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
import torchvision.transforms as tvf
|
4 |
+
import torch.nn.functional as F
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
from dust3r.utils.device import to_numpy
|
8 |
+
|
9 |
+
from dust3r.inference import inference
|
10 |
+
from dust3r.model import AsymmetricCroCo3DStereo
|
11 |
+
from dust3r.cloud_opt import global_aligner, GlobalAlignerMode
|
12 |
+
from utils.dust3r_utils import compute_global_alignment
|
13 |
+
|
14 |
+
from mast3r.model import AsymmetricMASt3R
|
15 |
+
from mast3r.cloud_opt.sparse_ga import sparse_global_alignment
|
16 |
+
from mast3r.cloud_opt.tsdf_optimizer import TSDFPostProcess
|
17 |
+
|
18 |
+
from hydra.utils import instantiate
|
19 |
+
from omegaconf import OmegaConf
|
20 |
+
|
21 |
+
|
22 |
+
class TorchPCA(object):
|
23 |
+
|
24 |
+
def __init__(self, n_components):
|
25 |
+
self.n_components = n_components
|
26 |
+
|
27 |
+
def fit(self, X):
|
28 |
+
self.mean_ = X.mean(dim=0)
|
29 |
+
unbiased = X - self.mean_.unsqueeze(0)
|
30 |
+
U, S, V = torch.pca_lowrank(unbiased, q=self.n_components, center=False, niter=50)
|
31 |
+
self.components_ = V.T
|
32 |
+
self.singular_values_ = S
|
33 |
+
return self
|
34 |
+
|
35 |
+
def transform(self, X):
|
36 |
+
t0 = X - self.mean_.unsqueeze(0)
|
37 |
+
projected = t0 @ self.components_.T
|
38 |
+
return projected
|
39 |
+
|
40 |
+
def pca(stacked_feat, dim):
|
41 |
+
flattened_feats = []
|
42 |
+
for feat in stacked_feat:
|
43 |
+
H, W, C = feat.shape
|
44 |
+
feat = feat.reshape(H * W, C).detach()
|
45 |
+
flattened_feats.append(feat)
|
46 |
+
x = torch.cat(flattened_feats, dim=0)
|
47 |
+
fit_pca = TorchPCA(n_components=dim).fit(x)
|
48 |
+
|
49 |
+
projected_feats = []
|
50 |
+
for feat in stacked_feat:
|
51 |
+
H, W, C = feat.shape
|
52 |
+
feat = feat.reshape(H * W, C).detach()
|
53 |
+
x_red = fit_pca.transform(feat)
|
54 |
+
projected_feats.append(x_red.reshape(H, W, dim))
|
55 |
+
projected_feats = torch.stack(projected_feats)
|
56 |
+
return projected_feats
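
# Example (added for illustration, not part of the original file): pca() fits one
# PCA over the features of all views and projects each [H, W, C] map to a lower
# dimension, e.g. 3 channels for a quick RGB-style look at a VFM feature.
# The shapes below are hypothetical.
def _demo_feature_pca():
    feats = torch.randn(4, 32, 32, 768)   # 4 views of DINO-like patch features
    feats_rgb = pca(feats, dim=3)          # -> (4, 32, 32, 3)
    return feats_rgb
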
|
57 |
+
|
58 |
+
|
59 |
+
def upsampler(feature, upsampled_height, upsampled_width, max_chunk=None):
|
60 |
+
"""
|
61 |
+
Upsample the feature tensor to the specified height and width.
|
62 |
+
|
63 |
+
Args:
|
64 |
+
- feature (torch.Tensor): The input tensor with size [B, H, W, C].
|
65 |
+
- upsampled_height (int): The target height after upsampling.
|
66 |
+
- upsampled_width (int): The target width after upsampling.
|
67 |
+
|
68 |
+
Returns:
|
69 |
+
- upsampled_feature (torch.Tensor): The upsampled tensor with size [B, upsampled_height, upsampled_width, C].
|
70 |
+
"""
|
71 |
+
# Permute the tensor to [B, C, H, W] for interpolation
|
72 |
+
feature = feature.permute(0, 3, 1, 2)
|
73 |
+
|
74 |
+
# Perform the upsampling
|
75 |
+
if max_chunk:
|
76 |
+
upsampled_chunks = []
|
77 |
+
|
78 |
+
for i in range(0, len(feature), max_chunk):
|
79 |
+
chunk = feature[i:i+max_chunk]
|
80 |
+
|
81 |
+
upsampled_chunk = F.interpolate(chunk, size=(upsampled_height, upsampled_width), mode='bilinear', align_corners=False)
|
82 |
+
upsampled_chunks.append(upsampled_chunk)
|
83 |
+
|
84 |
+
upsampled_feature = torch.cat(upsampled_chunks, dim=0)
|
85 |
+
else:
|
86 |
+
upsampled_feature = F.interpolate(feature, size=(upsampled_height, upsampled_width), mode='bilinear', align_corners=False)
|
87 |
+
|
88 |
+
# Permute back to [B, H, W, C]
|
89 |
+
upsampled_feature = upsampled_feature.permute(0, 2, 3, 1)
|
90 |
+
|
91 |
+
return upsampled_feature
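
# Example (added, not part of the original file): upsampler bilinearly resizes
# per-pixel feature maps to the target resolution; max_chunk bounds peak memory
# by interpolating the batch in slices. The shapes below are hypothetical.
def _demo_upsample():
    feats = torch.randn(8, 24, 32, 64)            # [B, H, W, C]
    up = upsampler(feats, 384, 512, max_chunk=2)  # -> (8, 384, 512, 64)
    return up.shape
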
|
92 |
+
|
93 |
+
def visualizer(features, images, save_dir, dim=9, feat_type=None, file_name=None):
|
94 |
+
"""
|
95 |
+
Visualize features and corresponding images, and save the result.
|
96 |
+
|
97 |
+
Args:
|
98 |
+
features (torch.Tensor): Feature tensor with shape [B, H, W, C].
|
99 |
+
images (list): List of dictionaries containing images with keys 'img'. Each image tensor has shape [1, 3, H, W]
|
100 |
+
and values in the range [-1, 1].
|
101 |
+
save_dir (str): Directory to save the resulting visualization.
|
102 |
+
feat_type (list): List of feature types.
|
103 |
+
file_name (str): Name of the file to save.
|
104 |
+
"""
|
105 |
+
import matplotlib
|
106 |
+
matplotlib.use('Agg')
|
107 |
+
from matplotlib import pyplot as plt
|
108 |
+
import torchvision.utils as vutils
|
109 |
+
|
110 |
+
assert features.dim() == 4, "Input tensor must have 4 dimensions (B, H, W, C)"
|
111 |
+
|
112 |
+
B, H, W, C = features.size()
|
113 |
+
|
114 |
+
features = features[..., dim-9:]
|
115 |
+
# Normalize the 3-dimensional feature to range [0, 1]
|
116 |
+
features_min = features.min(dim=0, keepdim=True).values.min(dim=1, keepdim=True).values.min(dim=2, keepdim=True).values
|
117 |
+
features_max = features.max(dim=0, keepdim=True).values.max(dim=1, keepdim=True).values.max(dim=2, keepdim=True).values
|
118 |
+
features = (features - features_min) / (features_max - features_min)
|
119 |
+
|
120 |
+
##### Save individual feature maps
|
121 |
+
# # Create subdirectory for feature visualizations
|
122 |
+
# feat_dir = os.path.join(save_dir, 'feature_maps')
|
123 |
+
# if feat_type:
|
124 |
+
# feat_dir = os.path.join(feat_dir, '-'.join(feat_type))
|
125 |
+
# os.makedirs(feat_dir, exist_ok=True)
|
126 |
+
|
127 |
+
# for i in range(B):
|
128 |
+
# # Extract and save the feature map (channels 3-6)
|
129 |
+
# feat_map = features[i, :, :, 3:6].permute(2, 0, 1) # [3, H, W]
|
130 |
+
# save_path = os.path.join(feat_dir, f'{i}_feat.png')
|
131 |
+
# vutils.save_image(feat_map, save_path, normalize=False)
|
132 |
+
|
133 |
+
# return feat_dir
|
134 |
+
|
135 |
+
##### Save feature maps in a single image
|
136 |
+
# Set the size of the plot
|
137 |
+
fig, axes = plt.subplots(B, 4, figsize=(W*4*0.01, H*B*0.01))
|
138 |
+
|
139 |
+
for i in range(B):
|
140 |
+
# Get the original image
|
141 |
+
image_tensor = images[i]['img']
|
142 |
+
assert image_tensor.dim() == 4 and image_tensor.size(0) == 1 and image_tensor.size(1) == 3, "Image tensor must have shape [1, 3, H, W]"
|
143 |
+
image = image_tensor.squeeze(0).permute(1, 2, 0).numpy() # Convert to (H, W, 3)
|
144 |
+
|
145 |
+
# Scale image values from [-1, 1] to [0, 1]
|
146 |
+
image = (image + 1) / 2
|
147 |
+
|
148 |
+
ax = axes[i, 0] if B > 1 else axes[0]
|
149 |
+
ax.imshow(image)
|
150 |
+
ax.axis('off')
|
151 |
+
|
152 |
+
# Visualize each 3-dimensional feature
|
153 |
+
for j in range(3):
|
154 |
+
ax = axes[i, j+1] if B > 1 else axes[j+1]
|
155 |
+
if j * 3 < min(C, dim): # Check if the feature channels are available
|
156 |
+
feature_to_plot = features[i, :, :, j*3:(j+1)*3].cpu().numpy()
|
157 |
+
ax.imshow(feature_to_plot)
|
158 |
+
else: # Plot white image if features are not available
|
159 |
+
ax.imshow(torch.ones(H, W, 3).numpy())
|
160 |
+
ax.axis('off')
|
161 |
+
|
162 |
+
# Reduce margins and spaces between images
|
163 |
+
plt.subplots_adjust(wspace=0.005, hspace=0.005, left=0.01, right=0.99, top=0.99, bottom=0.01)
|
164 |
+
|
165 |
+
# Save the entire plot
|
166 |
+
if file_name is None:
|
167 |
+
file_name = f'feat_dim{dim-9}-{dim}'
|
168 |
+
if feat_type:
|
169 |
+
feat_type_str = '-'.join(feat_type)
|
170 |
+
file_name = file_name + f'_{feat_type_str}'
|
171 |
+
save_path = os.path.join(save_dir, file_name + '.png')
|
172 |
+
plt.savefig(save_path, bbox_inches='tight', pad_inches=0)
|
173 |
+
plt.close()
|
174 |
+
|
175 |
+
return save_path
|
176 |
+
|
177 |
+
|
178 |
+
#### Open it if you visualize feature maps in Feat2GS's teaser
|
179 |
+
# import matplotlib.colors as mcolors
|
180 |
+
# from PIL import Image
|
181 |
+
|
182 |
+
# morandi_colors = [
|
183 |
+
# '#8AA2A9', '#C98474', '#F2D0A9', '#8D9F87', '#A7A7A7', '#D98E73', '#B24C33', '#5E7460', '#4A6B8A', '#B2CBC2',
|
184 |
+
# '#BBC990', '#6B859E', '#B45342', '#4E0000', '#3D0000', '#2C0000', '#1B0000', '#0A0000', '#DCAC99', '#6F936B',
|
185 |
+
# '#EBA062', '#FED273', '#9A8EB4', '#706052', '#E9E5E5', '#C4D8D2', '#F2CBBD', '#F6F9F1', '#C5CABC', '#A3968B',
|
186 |
+
# '#5C6974', '#BE7B6E', '#C67752', '#C18830', '#8C956C', '#CAC691', '#819992', '#4D797F', '#95AEB2', '#B6C4CF',
|
187 |
+
# '#84291C', '#B9551F', '#A96400', '#374B6C', '#C8B493', '#677D5D', '#9882A2', '#2D5F53', '#D2A0AC', '#658D9A',
|
188 |
+
# '#9A7265', '#EFE1D2', '#DDD8D1', '#D2C6BC', '#E3C9BC', '#B8AB9F', '#D8BEA4', '#E0D4C5', '#B8B8B6', '#D0CAC3',
|
189 |
+
# '#9AA8B5', '#BBC9B9', '#E3E8D8', '#ADB3A4', '#C5C9BB', '#A3968B', '#C2A995', '#EDE1D1', '#EDE8E1', '#EDEBE1',
|
190 |
+
# '#CFCFCC', '#AABAC6', '#DCDEE0', '#EAE5E7', '#B7AB9F', '#F7EFE3', '#DED8CF', '#ABCA99', '#C5CD8F', '#959491',
|
191 |
+
# '#FFE481', '#C18E99', '#B07C86', '#9F6A73', '#8E5860', '#DEAD44', '#CD9B31', '#BC891E', '#AB770B', '#9A6500',
|
192 |
+
# '#778144', '#666F31', '#555D1E', '#444B0B', '#333900', '#67587B', '#564668', '#684563', '#573350', '#684550',
|
193 |
+
# '#57333D', '#46212A', '#350F17', '#240004',
|
194 |
+
# ]
|
195 |
+
|
196 |
+
# def rgb_to_hsv(rgb):
|
197 |
+
# rgb = rgb.clamp(0, 1)
|
198 |
+
|
199 |
+
# cmax, cmax_idx = rgb.max(dim=-1)
|
200 |
+
# cmin = rgb.min(dim=-1).values
|
201 |
+
|
202 |
+
# diff = cmax - cmin
|
203 |
+
|
204 |
+
# h = torch.zeros_like(cmax)
|
205 |
+
# h[cmax_idx == 0] = (((rgb[..., 1] - rgb[..., 2]) / diff) % 6)[cmax_idx == 0]
|
206 |
+
# h[cmax_idx == 1] = (((rgb[..., 2] - rgb[..., 0]) / diff) + 2)[cmax_idx == 1]
|
207 |
+
# h[cmax_idx == 2] = (((rgb[..., 0] - rgb[..., 1]) / diff) + 4)[cmax_idx == 2]
|
208 |
+
# h[diff == 0] = 0 # If cmax == cmin
|
209 |
+
# h = h / 6
|
210 |
+
|
211 |
+
# s = torch.zeros_like(cmax)
|
212 |
+
# s[cmax != 0] = (diff / cmax)[cmax != 0]
|
213 |
+
|
214 |
+
# v = cmax
|
215 |
+
|
216 |
+
# return torch.stack([h, s, v], dim=-1)
|
217 |
+
|
218 |
+
# def hsv_to_rgb(hsv):
|
219 |
+
# h, s, v = hsv[..., 0], hsv[..., 1], hsv[..., 2]
|
220 |
+
|
221 |
+
# c = v * s
|
222 |
+
# x = c * (1 - torch.abs((h * 6) % 2 - 1))
|
223 |
+
# m = v - c
|
224 |
+
|
225 |
+
# rgb = torch.zeros_like(hsv)
|
226 |
+
# mask = (h < 1/6)
|
227 |
+
# rgb[mask] = torch.stack([c[mask], x[mask], torch.zeros_like(x[mask])], dim=-1)
|
228 |
+
# mask = (1/6 <= h) & (h < 2/6)
|
229 |
+
# rgb[mask] = torch.stack([x[mask], c[mask], torch.zeros_like(x[mask])], dim=-1)
|
230 |
+
# mask = (2/6 <= h) & (h < 3/6)
|
231 |
+
# rgb[mask] = torch.stack([torch.zeros_like(x[mask]), c[mask], x[mask]], dim=-1)
|
232 |
+
# mask = (3/6 <= h) & (h < 4/6)
|
233 |
+
# rgb[mask] = torch.stack([torch.zeros_like(x[mask]), x[mask], c[mask]], dim=-1)
|
234 |
+
# mask = (4/6 <= h) & (h < 5/6)
|
235 |
+
# rgb[mask] = torch.stack([x[mask], torch.zeros_like(x[mask]), c[mask]], dim=-1)
|
236 |
+
# mask = (5/6 <= h)
|
237 |
+
# rgb[mask] = torch.stack([c[mask], torch.zeros_like(x[mask]), x[mask]], dim=-1)
|
238 |
+
|
239 |
+
# return rgb + m.unsqueeze(-1)
|
240 |
+
|
241 |
+
# def interpolate_colors(colors, n_colors):
|
242 |
+
# # Convert colors to RGB tensor
|
243 |
+
# rgb_colors = torch.tensor([mcolors.to_rgb(color) for color in colors])
|
244 |
+
|
245 |
+
# # Convert RGB to HSV
|
246 |
+
# hsv_colors = rgb_to_hsv(rgb_colors)
|
247 |
+
|
248 |
+
# # Sort by hue
|
249 |
+
# sorted_indices = torch.argsort(hsv_colors[:, 0])
|
250 |
+
# sorted_hsv_colors = hsv_colors[sorted_indices]
|
251 |
+
|
252 |
+
# # Create interpolation indices
|
253 |
+
# indices = torch.linspace(0, len(sorted_hsv_colors) - 1, n_colors)
|
254 |
+
|
255 |
+
# # Perform interpolation
|
256 |
+
# interpolated_hsv = torch.stack([
|
257 |
+
# torch.lerp(sorted_hsv_colors[int(i)],
|
258 |
+
# sorted_hsv_colors[min(int(i) + 1, len(sorted_hsv_colors) - 1)],
|
259 |
+
# i - int(i))
|
260 |
+
# for i in indices
|
261 |
+
# ])
|
262 |
+
|
263 |
+
# # Convert interpolated result back to RGB
|
264 |
+
# interpolated_rgb = hsv_to_rgb(interpolated_hsv)
|
265 |
+
|
266 |
+
# return interpolated_rgb
|
267 |
+
|
268 |
+
|
269 |
+
# def project_to_morandi(features, morandi_colors):
|
270 |
+
# features_flat = features.reshape(-1, 3)
|
271 |
+
# distances = torch.cdist(features_flat, morandi_colors)
|
272 |
+
|
273 |
+
# # Get the indices of the closest colors
|
274 |
+
# closest_color_indices = torch.argmin(distances, dim=1)
|
275 |
+
|
276 |
+
# # Use the closest Morandi colors directly
|
277 |
+
# features_morandi = morandi_colors[closest_color_indices]
|
278 |
+
|
279 |
+
# features_morandi = features_morandi.reshape(features.shape)
|
280 |
+
# return features_morandi
|
281 |
+
|
282 |
+
|
283 |
+
# def smooth_color_transform(features, morandi_colors, smoothness=0.1):
|
284 |
+
# features_flat = features.reshape(-1, 3)
|
285 |
+
# distances = torch.cdist(features_flat, morandi_colors)
|
286 |
+
|
287 |
+
# # Calculate weights
|
288 |
+
# weights = torch.exp(-distances / smoothness)
|
289 |
+
# weights = weights / weights.sum(dim=1, keepdim=True)
|
290 |
+
|
291 |
+
# # Weighted average
|
292 |
+
# features_morandi = torch.matmul(weights, morandi_colors)
|
293 |
+
|
294 |
+
# features_morandi = features_morandi.reshape(features.shape)
|
295 |
+
# return features_morandi
|
296 |
+
|
297 |
+
# def histogram_matching(source, template):
|
298 |
+
# """
|
299 |
+
# Match the histogram of the source tensor to that of the template tensor.
|
300 |
+
|
301 |
+
# :param source: Source tensor with shape [B, H, W, 3]
|
302 |
+
# :param template: Template tensor with shape [N, 3], where N is the number of colors
|
303 |
+
# :return: Source tensor after histogram matching
|
304 |
+
# """
|
305 |
+
# def match_cumulative_cdf(source, template):
|
306 |
+
# src_values, src_indices, src_counts = torch.unique(source, return_inverse=True, return_counts=True)
|
307 |
+
# tmpl_values, tmpl_counts = torch.unique(template, return_counts=True)
|
308 |
+
|
309 |
+
# src_quantiles = torch.cumsum(src_counts.float(), 0) / source.numel()
|
310 |
+
# tmpl_quantiles = torch.cumsum(tmpl_counts.float(), 0) / template.numel()
|
311 |
+
|
312 |
+
# idx = torch.searchsorted(tmpl_quantiles, src_quantiles)
|
313 |
+
# idx = torch.clamp(idx, 1, len(tmpl_quantiles)-1)
|
314 |
+
|
315 |
+
# slope = (tmpl_values[idx] - tmpl_values[idx-1]) / (tmpl_quantiles[idx] - tmpl_quantiles[idx-1])
|
316 |
+
# interp_a_values = torch.lerp(tmpl_values[idx-1], tmpl_values[idx],
|
317 |
+
# (src_quantiles - tmpl_quantiles[idx-1]) * slope)
|
318 |
+
|
319 |
+
# return interp_a_values[src_indices].reshape(source.shape)
|
320 |
+
|
321 |
+
# matched = torch.stack([match_cumulative_cdf(source[..., i], template[:, i]) for i in range(3)], dim=-1)
|
322 |
+
# return matched
|
323 |
+
|
324 |
+
# def process_features(features):
|
325 |
+
# device = features.device
|
326 |
+
|
327 |
+
# n_colors = 1024
|
328 |
+
# morandi_colors_tensor = interpolate_colors(morandi_colors, n_colors).to(device)
|
329 |
+
# # morandi_colors_tensor = torch.tensor([mcolors.to_rgb(color) for color in morandi_colors]).to(device)
|
330 |
+
|
331 |
+
# # features_morandi = project_to_morandi(features, morandi_colors_tensor)
|
332 |
+
# # features_morandi = histogram_matching(features, morandi_colors_tensor)
|
333 |
+
# features_morandi = smooth_color_transform(features, morandi_colors_tensor, smoothness=0.05)
|
334 |
+
|
335 |
+
# return features_morandi.cpu().numpy()
|
336 |
+
|
337 |
+
# def visualizer(features, images, save_dir, dim=9, feat_type=None, file_name=None):
|
338 |
+
# import matplotlib
|
339 |
+
# matplotlib.use('Agg')
|
340 |
+
|
341 |
+
# import matplotlib.pyplot as plt
|
342 |
+
# import numpy as np
|
343 |
+
# import os
|
344 |
+
|
345 |
+
# assert features.dim() == 4, "Input tensor must have 4 dimensions (B, H, W, C)"
|
346 |
+
|
347 |
+
# B, H, W, C = features.size()
|
348 |
+
|
349 |
+
# # Ensure features have at least 3 channels for RGB visualization
|
350 |
+
# assert C >= 3, "Features must have at least 3 channels for RGB visualization"
|
351 |
+
# features = features[..., :3]
|
352 |
+
|
353 |
+
# # Normalize features to [0, 1] range
|
354 |
+
# features_min = features.min(dim=0, keepdim=True).values.min(dim=1, keepdim=True).values.min(dim=2, keepdim=True).values
|
355 |
+
# features_max = features.max(dim=0, keepdim=True).values.max(dim=1, keepdim=True).values.max(dim=2, keepdim=True).values
|
356 |
+
# features = (features - features_min) / (features_max - features_min)
|
357 |
+
|
358 |
+
# features_processed = process_features(features)
|
359 |
+
|
360 |
+
# # Create the directory structure
|
361 |
+
# vis_dir = os.path.join(save_dir, 'vis')
|
362 |
+
|
363 |
+
# if feat_type:
|
364 |
+
# feat_type_str = '-'.join(feat_type)
|
365 |
+
# vis_dir = os.path.join(vis_dir, feat_type_str)
|
366 |
+
# os.makedirs(vis_dir, exist_ok=True)
|
367 |
+
|
368 |
+
# # Save individual images for each feature map
|
369 |
+
# for i in range(B):
|
370 |
+
# if file_name is None:
|
371 |
+
# file_name = 'feat_morandi'
|
372 |
+
# save_path = os.path.join(vis_dir, f'{file_name}_{i}.png')
|
373 |
+
|
374 |
+
# # Convert to uint8 and save directly
|
375 |
+
# img = Image.fromarray((features_processed[i] * 255).astype(np.uint8))
|
376 |
+
# img.save(save_path)
|
377 |
+
|
378 |
+
# print(f"Feature maps have been saved in the directory: {vis_dir}")
|
379 |
+
# return vis_dir
|
380 |
+
|
381 |
+
def mv_visualizer(features, images, save_dir, dim=9, feat_type=None, file_name=None):
|
382 |
+
"""
|
383 |
+
Visualize features and corresponding images, and save the result. (For MASt3R decoder or head features)
|
384 |
+
"""
|
385 |
+
import matplotlib
|
386 |
+
matplotlib.use('Agg')
|
387 |
+
from matplotlib import pyplot as plt
|
388 |
+
import os
|
389 |
+
|
390 |
+
B, H, W, _ = features.size()
|
391 |
+
features = features[..., dim-9:]
|
392 |
+
|
393 |
+
# Normalize the 3-dimensional feature to range [0, 1]
|
394 |
+
features_min = features.min(dim=0, keepdim=True).values.min(dim=1, keepdim=True).values.min(dim=2, keepdim=True).values
|
395 |
+
features_max = features.max(dim=0, keepdim=True).values.max(dim=1, keepdim=True).values.max(dim=2, keepdim=True).values
|
396 |
+
features = (features - features_min) / (features_max - features_min)
|
397 |
+
|
398 |
+
rows = (B + 1) // 2 # Adjust rows for odd B
|
399 |
+
fig, axes = plt.subplots(rows, 8, figsize=(W*8*0.01, H*rows*0.01))
|
400 |
+
|
401 |
+
for i in range(B//2):
|
402 |
+
# Odd row: image and features
|
403 |
+
image = (images[2*i]['img'].squeeze(0).permute(1, 2, 0).numpy() + 1) / 2
|
404 |
+
axes[i, 0].imshow(image)
|
405 |
+
axes[i, 0].axis('off')
|
406 |
+
for j in range(3):
|
407 |
+
axes[i, j+1].imshow(features[2*i, :, :, j*3:(j+1)*3].cpu().numpy())
|
408 |
+
axes[i, j+1].axis('off')
|
409 |
+
|
410 |
+
# Even row: image and features
|
411 |
+
if 2*i + 1 < B:
|
412 |
+
image = (images[2*i + 1]['img'].squeeze(0).permute(1, 2, 0).numpy() + 1) / 2
|
413 |
+
axes[i, 4].imshow(image)
|
414 |
+
axes[i, 4].axis('off')
|
415 |
+
for j in range(3):
|
416 |
+
axes[i, j+5].imshow(features[2*i + 1, :, :, j*3:(j+1)*3].cpu().numpy())
|
417 |
+
axes[i, j+5].axis('off')
|
418 |
+
|
419 |
+
# Handle last row if B is odd
|
420 |
+
if B % 2 != 0:
|
421 |
+
image = (images[-1]['img'].squeeze(0).permute(1, 2, 0).numpy() + 1) / 2
|
422 |
+
axes[-1, 0].imshow(image)
|
423 |
+
axes[-1, 0].axis('off')
|
424 |
+
for j in range(3):
|
425 |
+
axes[-1, j+1].imshow(features[-1, :, :, j*3:(j+1)*3].cpu().numpy())
|
426 |
+
axes[-1, j+1].axis('off')
|
427 |
+
|
428 |
+
# Hide unused columns in last row
|
429 |
+
for j in range(4, 8):
|
430 |
+
axes[-1, j].axis('off')
|
431 |
+
|
432 |
+
plt.subplots_adjust(wspace=0.005, hspace=0.005, left=0.01, right=0.99, top=0.99, bottom=0.01)
|
433 |
+
|
434 |
+
# Save the plot
|
435 |
+
if file_name is None:
|
436 |
+
file_name = f'feat_dim{dim-9}-{dim}'
|
437 |
+
if feat_type:
|
438 |
+
feat_type_str = '-'.join(feat_type)
|
439 |
+
file_name = file_name + f'_{feat_type_str}'
|
440 |
+
save_path = os.path.join(save_dir, file_name + '.png')
|
441 |
+
plt.savefig(save_path, bbox_inches='tight', pad_inches=0)
|
442 |
+
plt.close()
|
443 |
+
|
444 |
+
return save_path
|
445 |
+
|
446 |
+
|
447 |
+
def adjust_norm(image: torch.Tensor) -> torch.Tensor:
|
448 |
+
|
449 |
+
inv_normalize = tvf.Normalize(
|
450 |
+
mean=[-1, -1, -1],
|
451 |
+
std=[1/0.5, 1/0.5, 1/0.5]
|
452 |
+
)
|
453 |
+
|
454 |
+
correct_normalize = tvf.Normalize(
|
455 |
+
mean=[0.485, 0.456, 0.406],
|
456 |
+
std=[0.229, 0.224, 0.225]
|
457 |
+
)
|
458 |
+
image = inv_normalize(image)
|
459 |
+
image = correct_normalize(image)
|
460 |
+
|
461 |
+
return image
|
462 |
+
|
463 |
+
def adjust_midas_norm(image: torch.Tensor) -> torch.Tensor:
|
464 |
+
|
465 |
+
inv_normalize = tvf.Normalize(
|
466 |
+
mean=[-1, -1, -1],
|
467 |
+
std=[1/0.5, 1/0.5, 1/0.5]
|
468 |
+
)
|
469 |
+
|
470 |
+
correct_normalize = tvf.Normalize(
|
471 |
+
mean=[0.5, 0.5, 0.5],
|
472 |
+
std=[0.5, 0.5, 0.5]
|
473 |
+
)
|
474 |
+
|
475 |
+
image = inv_normalize(image)
|
476 |
+
image = correct_normalize(image)
|
477 |
+
|
478 |
+
return image
|
479 |
+
|
480 |
+
def adjust_clip_norm(image: torch.Tensor) -> torch.Tensor:
|
481 |
+
|
482 |
+
inv_normalize = tvf.Normalize(
|
483 |
+
mean=[-1, -1, -1],
|
484 |
+
std=[1/0.5, 1/0.5, 1/0.5]
|
485 |
+
)
|
486 |
+
|
487 |
+
correct_normalize = tvf.Normalize(
|
488 |
+
mean=[0.48145466, 0.4578275, 0.40821073],
|
489 |
+
std=[0.26862954, 0.26130258, 0.27577711]
|
490 |
+
)
|
491 |
+
|
492 |
+
image = inv_normalize(image)
|
493 |
+
image = correct_normalize(image)
|
494 |
+
|
495 |
+
return image
|
496 |
+
|
497 |
+
class UnNormalize(object):
|
498 |
+
def __init__(self, mean, std):
|
499 |
+
self.mean = mean
|
500 |
+
self.std = std
|
501 |
+
|
502 |
+
def __call__(self, image):
|
503 |
+
image2 = torch.clone(image)
|
504 |
+
if len(image2.shape) == 4:
|
505 |
+
image2 = image2.permute(1, 0, 2, 3)
|
506 |
+
for t, m, s in zip(image2, self.mean, self.std):
|
507 |
+
t.mul_(s).add_(m)
|
508 |
+
return image2.permute(1, 0, 2, 3)
|
509 |
+
|
510 |
+
|
511 |
+
norm = tvf.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
|
512 |
+
unnorm = UnNormalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
|
513 |
+
|
514 |
+
midas_norm = tvf.Normalize([0.5] * 3, [0.5] * 3)
|
515 |
+
midas_unnorm = UnNormalize([0.5] * 3, [0.5] * 3)
|
516 |
+
|
517 |
+
|
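Note that `adjust_norm`, `adjust_midas_norm`, and `adjust_clip_norm` all follow the same pattern: the loader's images arrive normalized to roughly [-1, 1] (mean 0.5, std 0.5), so each helper first inverts that normalization back to [0, 1] and then applies the backbone-specific statistics. A minimal sketch of the same composition, with illustrative tensors and stats only (not part of this repo):

```python
import torch
import torchvision.transforms as tvf

# Pretend input already normalized to [-1, 1].
img_pm1 = torch.rand(1, 3, 224, 224) * 2 - 1

# (x + 1) / 2: undo the (0.5, 0.5) normalization back to [0, 1].
to_unit = tvf.Normalize(mean=[-1, -1, -1], std=[2, 2, 2])
# Then apply backbone-specific statistics (ImageNet values shown here).
to_imagenet = tvf.Normalize(mean=[0.485, 0.456, 0.406],
                            std=[0.229, 0.224, 0.225])

img_for_backbone = to_imagenet(to_unit(img_pm1))
```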
518 |
+
def generate_iuv(B, H, W):
|
519 |
+
i_coords = torch.arange(B).view(B, 1, 1, 1).expand(B, H, W, 1).float() / (B - 1)
|
520 |
+
u_coords = torch.linspace(0, 1, W).view(1, 1, W, 1).expand(B, H, W, 1)
|
521 |
+
v_coords = torch.linspace(0, 1, H).view(1, H, 1, 1).expand(B, H, W, 1)
|
522 |
+
iuv_coords = torch.cat([i_coords, u_coords, v_coords], dim=-1)
|
523 |
+
return iuv_coords
|
524 |
+
|
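`generate_iuv` stacks a normalized image index `i` with per-pixel `(u, v)` coordinates, giving a 3-channel positional feature per pixel. Since the `i` channel divides by `(B - 1)`, a single image (`B == 1`) would produce NaNs; the sketch below only illustrates the expected shape and value range for `B > 1` (illustrative, not part of the repo):

```python
# Illustrative check of generate_iuv's output (assumes B > 1).
iuv = generate_iuv(B=4, H=8, W=16)
print(iuv.shape)                            # torch.Size([4, 8, 16, 3])
print(iuv.min().item(), iuv.max().item())   # 0.0 1.0
```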
525 |
+
class FeatureExtractor:
|
526 |
+
"""
|
527 |
+
Extract and process per-point (per-pixel) features from images using visual foundation models (VFMs).
|
528 |
+
Supports multiple VFM features, dimensionality reduction, feature upsampling, and visualization.
|
529 |
+
|
530 |
+
Parameters:
|
531 |
+
images (list): List of image info.
|
532 |
+
method (str): Pointmap initialization method, chosen from ["dust3r", "mast3r"].
|
533 |
+
device (str): 'cuda'.
|
534 |
+
feat_type (list): VFM(s) to use, chosen from ["dust3r", "mast3r", "dift", "dino_b16", "dinov2_b14", "radio", "clip_b16", "mae_b16", "midas_l16", "sam_base", "iuvrgb"].
|
535 |
+
feat_dim (int): PCA dimensions.
|
536 |
+
img_base_path (str): Training view data directory path.
|
537 |
+
model_path (str): Model path, './submodules/mast3r/checkpoints/'.
|
538 |
+
vis_feat (bool): Visualize and save feature maps.
|
539 |
+
vis_key (str): Feature type to visualize (MASt3R only), chosen from ["decfeat", "desc"].
|
540 |
+
focal_avg (bool): Use a shared (averaged) focal length across views.
|
541 |
+
"""
|
542 |
+
def __init__(self, images, args, method):
|
543 |
+
self.images = images
|
544 |
+
self.method = method
|
545 |
+
self.device = args.device
|
546 |
+
self.feat_type = args.feat_type
|
547 |
+
self.feat_dim = args.feat_dim
|
548 |
+
self.img_base_path = args.img_base_path
|
549 |
+
# self.use_featup = args.use_featup
|
550 |
+
self.model_path = args.model_path
|
551 |
+
self.vis_feat = args.vis_feat
|
552 |
+
self.vis_key = args.vis_key
|
553 |
+
self.focal_avg = args.focal_avg
|
554 |
+
|
555 |
+
def get_dust3r_feat(self, **kw):
|
556 |
+
model_path = os.path.join(self.model_path, "DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth")
|
557 |
+
model = AsymmetricCroCo3DStereo.from_pretrained(model_path).to(self.device)
|
558 |
+
output = inference(kw['pairs'], model, self.device, batch_size=1)
|
559 |
+
scene = global_aligner(output, device=self.device, mode=GlobalAlignerMode.PointCloudOptimizer)
|
560 |
+
if self.vis_key:
|
561 |
+
assert self.vis_key == 'decfeat', f"Expected vis_key to be 'decfeat', but got {self.vis_key}"
|
562 |
+
self.vis_decfeat(kw['pairs'], output=output)
|
563 |
+
|
564 |
+
# del model, output
|
565 |
+
# torch.cuda.empty_cache()
|
566 |
+
|
567 |
+
return scene.stacked_feat
|
568 |
+
|
569 |
+
def get_mast3r_feat(self, **kw):
|
570 |
+
model_path = os.path.join(self.model_path, "MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric.pth")
|
571 |
+
model = AsymmetricMASt3R.from_pretrained(model_path).to(self.device)
|
572 |
+
cache_dir = os.path.join(self.img_base_path, "cache")
|
573 |
+
if os.path.exists(cache_dir):
|
574 |
+
os.system(f'rm -rf {cache_dir}')
|
575 |
+
scene = sparse_global_alignment(kw['train_img_list'], kw['pairs'], cache_dir,
|
576 |
+
model, lr1=0.07, niter1=500, lr2=0.014, niter2=200, device=self.device,
|
577 |
+
opt_depth='depth' in 'refine', shared_intrinsics=self.focal_avg,
|
578 |
+
matching_conf_thr=5.)
|
579 |
+
if self.vis_key:
|
580 |
+
assert self.vis_key in ['decfeat', 'desc'], f"Expected vis_key to be 'decfeat' or 'desc', but got {self.vis_key}"
|
581 |
+
self.vis_decfeat(kw['pairs'], model=model)
|
582 |
+
|
583 |
+
# del model
|
584 |
+
# torch.cuda.empty_cache()
|
585 |
+
|
586 |
+
return scene.stacked_feat
|
587 |
+
|
588 |
+
def get_feat(self, feat_type):
|
589 |
+
"""
|
590 |
+
Get features using Probe3D.
|
591 |
+
"""
|
592 |
+
cfg = OmegaConf.load(f"configs/backbone/{feat_type}.yaml")
|
593 |
+
model = instantiate(cfg.model, output="dense", return_multilayer=False)
|
594 |
+
model = model.to(self.device)
|
595 |
+
if 'midas' in feat_type:
|
596 |
+
image_norm = adjust_midas_norm(torch.cat([i['img'] for i in self.images])).to(self.device)
|
597 |
+
# elif 'clip' in self.feat_type:
|
598 |
+
# image_norm = adjust_clip_norm(torch.cat([i['img'] for i in self.images])).to(self.device)
|
599 |
+
else:
|
600 |
+
image_norm = adjust_norm(torch.cat([i['img'] for i in self.images])).to(self.device)
|
601 |
+
|
602 |
+
with torch.no_grad():
|
603 |
+
feats = model(image_norm).permute(0, 2, 3, 1)
|
604 |
+
|
605 |
+
# del model
|
606 |
+
# torch.cuda.empty_cache()
|
607 |
+
|
608 |
+
return feats
|
609 |
+
|
610 |
+
# def get_feat(self, feat_type):
|
611 |
+
# """
|
612 |
+
# Get features using FeatUp.
|
613 |
+
# """
|
614 |
+
# original_feat_type = feat_type
|
615 |
+
# use_norm = False if 'maskclip' in feat_type else True
|
616 |
+
# if 'featup' in original_feat_type:
|
617 |
+
# feat_type = feat_type.split('_featup')[0]
|
618 |
+
# # feat_upsampler = torch.hub.load("mhamilton723/FeatUp", feat_type, use_norm=use_norm).to(device)
|
619 |
+
# feat_upsampler = torch.hub.load("/home/chenyue/.cache/torch/hub/mhamilton723_FeatUp_main/", feat_type, use_norm=use_norm, source='local').to(self.device) ## offline
|
620 |
+
# image_norm = adjust_norm(torch.cat([i['img'] for i in self.images])).to(self.device)
|
621 |
+
# image_norm = F.interpolate(image_norm, size=(224, 224), mode='bilinear', align_corners=False)
|
622 |
+
# if 'featup' in original_feat_type:
|
623 |
+
# feats = feat_upsampler(image_norm).permute(0, 2, 3, 1)
|
624 |
+
# else:
|
625 |
+
# feats = feat_upsampler.model(image_norm).permute(0, 2, 3, 1)
|
626 |
+
# return feats
|
627 |
+
|
628 |
+
def get_iuvrgb(self):
|
629 |
+
rgb = torch.cat([i['img'] for i in self.images]).permute(0, 2, 3, 1).to(self.device)
|
630 |
+
feats = torch.cat([generate_iuv(*rgb.shape[:-1]).to(self.device), rgb], dim=-1)
|
631 |
+
return feats
|
632 |
+
|
633 |
+
def get_iuv(self):
|
634 |
+
rgb = torch.cat([i['img'] for i in self.images]).permute(0, 2, 3, 1).to(self.device)
|
635 |
+
feats = generate_iuv(*rgb.shape[:-1]).to(self.device)
|
636 |
+
return feats
|
637 |
+
|
638 |
+
def preprocess(self, feature, feat_dim, vis_feat=False, is_upsample=True):
|
639 |
+
"""
|
640 |
+
Preprocess features by applying PCA, upsampling, and optionally visualizing.
|
641 |
+
"""
|
642 |
+
if feat_dim:
|
643 |
+
feature = pca(feature, feat_dim)
|
644 |
+
# else:
|
645 |
+
# feature_min = feature.min(dim=0, keepdim=True).values.min(dim=1, keepdim=True).values
|
646 |
+
# feature_max = feature.max(dim=0, keepdim=True).values.max(dim=1, keepdim=True).values
|
647 |
+
# feature = (feature - feature_min) / (feature_max - feature_min + 1e-6)
|
648 |
+
# feature = feature - feature.mean(dim=[0,1,2], keepdim=True)
|
649 |
+
|
650 |
+
torch.cuda.empty_cache()
|
651 |
+
|
652 |
+
if (feature[0].shape[0:-1] != self.images[0]['true_shape'][0]).all() and is_upsample:
|
653 |
+
feature = upsampler(feature, *[s for s in self.images[0]['true_shape'][0]])
|
654 |
+
|
655 |
+
print(f"Feature map size >>> height: {feature[0].shape[0]}, width: {feature[0].shape[1]}, channels: {feature[0].shape[2]}")
|
656 |
+
if vis_feat:
|
657 |
+
save_path = visualizer(feature, self.images, self.img_base_path, feat_type=self.feat_type)
|
658 |
+
print(f"The encoder feature visualization has been saved at >>>>> {save_path}")
|
659 |
+
|
660 |
+
return feature
|
661 |
+
|
662 |
+
def vis_decfeat(self, pairs, **kw):
|
663 |
+
"""
|
664 |
+
Visualize decoder (or, for MASt3R only, head) features.
|
665 |
+
"""
|
666 |
+
if 'output' in kw:
|
667 |
+
output = kw['output']
|
668 |
+
else:
|
669 |
+
output = inference(pairs, kw['model'], self.device, batch_size=1, verbose=False)
|
670 |
+
decfeat1 = output['pred1'][self.vis_key].detach()
|
671 |
+
decfeat2 = output['pred2'][self.vis_key].detach()
|
672 |
+
# decfeat1 = pca(decfeat1, 9)
|
673 |
+
# decfeat2 = pca(decfeat2, 9)
|
674 |
+
decfeat = torch.stack([decfeat1, decfeat2], dim=1).view(-1, *decfeat1.shape[1:])
|
675 |
+
decfeat = torch.cat(torch.chunk(decfeat,2)[::-1], dim=0)
|
676 |
+
decfeat = pca(decfeat, 9)
|
677 |
+
if (decfeat.shape[1:-1] != self.images[0]['true_shape'][0]).all():
|
678 |
+
decfeat = upsampler(decfeat, *[s for s in self.images[0]['true_shape'][0]])
|
679 |
+
pair_images = [im for p in pairs[3:] + pairs[:3] for im in p]
|
680 |
+
save_path = mv_visualizer(decfeat, pair_images, self.img_base_path,
|
681 |
+
feat_type=self.feat_type, file_name=f'{self.vis_key}_pcaall_dim0-9')
|
682 |
+
print(f"The decoder feature visualization has been saved at >>>>> {save_path}")
|
683 |
+
|
684 |
+
def forward(self, **kw):
|
685 |
+
feat_dim = self.feat_dim
|
686 |
+
vis_feat = self.vis_feat and len(self.feat_type) == 1
|
687 |
+
is_upsample = len(self.feat_type) == 1
|
688 |
+
|
689 |
+
all_feats = {}
|
690 |
+
for feat_type in self.feat_type:
|
691 |
+
if feat_type == self.method:
|
692 |
+
feats = kw['scene'].stacked_feat
|
693 |
+
elif feat_type in ['dust3r', 'mast3r']:
|
694 |
+
feats = getattr(self, f"get_{feat_type}_feat")(**kw)
|
695 |
+
elif feat_type in ['iuv', 'iuvrgb']:
|
696 |
+
feats = getattr(self, f"get_{feat_type}")()
|
697 |
+
feat_dim = None
|
698 |
+
else:
|
699 |
+
feats = self.get_feat(feat_type)
|
700 |
+
|
701 |
+
# feats = to_numpy(self.preprocess(feats))
|
702 |
+
all_feats[feat_type] = self.preprocess(feats.detach().clone(), feat_dim, vis_feat, is_upsample)
|
703 |
+
|
704 |
+
if len(self.feat_type) > 1:
|
705 |
+
all_feats = {k: (v - v.min()) / (v.max() - v.min()) for k, v in all_feats.items()}
|
706 |
+
|
707 |
+
target_size = tuple(s // 16 for s in self.images[0]['true_shape'][0][:2])
|
708 |
+
tmp_feats = []
|
709 |
+
kickoff = []
|
710 |
+
|
711 |
+
for k, v in all_feats.items():
|
712 |
+
if k in ['iuv', 'iuvrgb']:
|
713 |
+
# self.feat_dim -= v.shape[-1]
|
714 |
+
kickoff.append(v)
|
715 |
+
else:
|
716 |
+
if v.shape[1:3] != target_size:
|
717 |
+
v = F.interpolate(v.permute(0, 3, 1, 2), size=target_size,
|
718 |
+
mode='bilinear', align_corners=False).permute(0, 2, 3, 1)
|
719 |
+
tmp_feats.append(v)
|
720 |
+
|
721 |
+
feats = self.preprocess(torch.cat(tmp_feats, dim=-1), self.feat_dim, self.vis_feat and not kickoff)
|
722 |
+
|
723 |
+
if kickoff:
|
724 |
+
feats = torch.cat([feats] + kickoff, dim=-1)
|
725 |
+
feats = self.preprocess(feats, self.feat_dim, self.vis_feat, is_upsample=False)
|
726 |
+
|
727 |
+
else:
|
728 |
+
feats = all_feats[self.feat_type[0]]
|
729 |
+
|
730 |
+
torch.cuda.empty_cache()
|
731 |
+
return to_numpy(feats)
|
732 |
+
|
733 |
+
def __call__(self, **kw):
|
734 |
+
return self.forward(**kw)
|
735 |
+
|
736 |
+
|
737 |
+
class InitMethod:
|
738 |
+
"""
|
739 |
+
Initialize pointmaps and camera parameters via DUSt3R or MASt3R.
|
740 |
+
"""
|
741 |
+
def __init__(self, args):
|
742 |
+
self.method = args.method
|
743 |
+
self.n_views = args.n_views
|
744 |
+
self.device = args.device
|
745 |
+
self.img_base_path = args.img_base_path
|
746 |
+
self.focal_avg = args.focal_avg
|
747 |
+
self.tsdf_thresh = args.tsdf_thresh
|
748 |
+
self.min_conf_thr = args.min_conf_thr
|
749 |
+
if self.method == 'dust3r':
|
750 |
+
self.model_path = os.path.join(args.model_path, "DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth")
|
751 |
+
else:
|
752 |
+
self.model_path = os.path.join(args.model_path, "MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric.pth")
|
753 |
+
|
754 |
+
def get_dust3r(self):
|
755 |
+
return AsymmetricCroCo3DStereo.from_pretrained(self.model_path).to(self.device)
|
756 |
+
|
757 |
+
def get_mast3r(self):
|
758 |
+
return AsymmetricMASt3R.from_pretrained(self.model_path).to(self.device)
|
759 |
+
|
760 |
+
def infer_dust3r(self, **kw):
|
761 |
+
output = inference(kw['pairs'], kw['model'], self.device, batch_size=1)
|
762 |
+
scene = global_aligner(output, device=self.device, mode=GlobalAlignerMode.PointCloudOptimizer)
|
763 |
+
loss = compute_global_alignment(scene=scene, init="mst", niter=300, schedule='linear', lr=0.01,
|
764 |
+
focal_avg=self.focal_avg, known_focal=kw.get('known_focal', None))
|
765 |
+
scene = scene.clean_pointcloud()
|
766 |
+
return scene
|
767 |
+
|
768 |
+
def infer_mast3r(self, **kw):
|
769 |
+
cache_dir = os.path.join(self.img_base_path, "cache")
|
770 |
+
if os.path.exists(cache_dir):
|
771 |
+
os.system(f'rm -rf {cache_dir}')
|
772 |
+
|
773 |
+
scene = sparse_global_alignment(kw['train_img_list'], kw['pairs'], cache_dir,
|
774 |
+
kw['model'], lr1=0.07, niter1=500, lr2=0.014, niter2=200, device=self.device,
|
775 |
+
opt_depth='depth' in 'refine', shared_intrinsics=self.focal_avg,
|
776 |
+
matching_conf_thr=5.)
|
777 |
+
return scene
|
778 |
+
|
779 |
+
def get_dust3r_info(self, scene):
|
780 |
+
imgs = to_numpy(scene.imgs)
|
781 |
+
focals = scene.get_focals()
|
782 |
+
poses = to_numpy(scene.get_im_poses())
|
783 |
+
pts3d = to_numpy(scene.get_pts3d())
|
784 |
+
# pts3d = to_numpy(scene.get_planes3d())
|
785 |
+
scene.min_conf_thr = float(scene.conf_trf(torch.tensor(1.0)))
|
786 |
+
confidence_masks = to_numpy(scene.get_masks())
|
787 |
+
intrinsics = to_numpy(scene.get_intrinsics())
|
788 |
+
return imgs, focals, poses, intrinsics, pts3d, confidence_masks
|
789 |
+
|
790 |
+
def get_mast3r_info(self, scene):
|
791 |
+
imgs = to_numpy(scene.imgs)
|
792 |
+
focals = scene.get_focals()[:,None]
|
793 |
+
poses = to_numpy(scene.get_im_poses())
|
794 |
+
intrinsics = to_numpy(scene.intrinsics)
|
795 |
+
tsdf = TSDFPostProcess(scene, TSDF_thresh=self.tsdf_thresh)
|
796 |
+
pts3d, _, confs = to_numpy(tsdf.get_dense_pts3d(clean_depth=True))
|
797 |
+
pts3d = [arr.reshape((*imgs[0].shape[:2], 3)) for arr in pts3d]
|
798 |
+
confidence_masks = np.array(to_numpy([c > self.min_conf_thr for c in confs]))
|
799 |
+
return imgs, focals, poses, intrinsics, pts3d, confidence_masks
|
800 |
+
|
801 |
+
def get_dust3r_depth(self, scene):
|
802 |
+
return to_numpy(scene.get_depthmaps())
|
803 |
+
|
804 |
+
def get_mast3r_depth(self, scene):
|
805 |
+
imgs = to_numpy(scene.imgs)
|
806 |
+
tsdf = TSDFPostProcess(scene, TSDF_thresh=self.tsdf_thresh)
|
807 |
+
_, depthmaps, _ = to_numpy(tsdf.get_dense_pts3d(clean_depth=True))
|
808 |
+
depthmaps = [arr.reshape((*imgs[0].shape[:2], 3)) for arr in depthmaps]
|
809 |
+
return depthmaps
|
810 |
+
|
811 |
+
def get_model(self):
|
812 |
+
return getattr(self, f"get_{self.method}")()
|
813 |
+
|
814 |
+
def infer(self, **kw):
|
815 |
+
return getattr(self, f"infer_{self.method}")(**kw)
|
816 |
+
|
817 |
+
def get_info(self, scene):
|
818 |
+
return getattr(self, f"get_{self.method}_info")(scene)
|
819 |
+
|
820 |
+
def get_depth(self, scene):
|
821 |
+
return getattr(self, f"get_{self.method}_depth")(scene)
|
822 |
+
|
823 |
+
|
824 |
+
|
825 |
+
|
826 |
+
|
827 |
+
|
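Taken together, `InitMethod` and `FeatureExtractor` are meant to be driven by the training script: the init method builds the DUSt3R/MASt3R scene (poses, intrinsics, pointmaps), and the extractor attaches per-pixel VFM features to those points. A rough usage sketch of the intended call order; `args`, `images`, `pairs`, and `train_img_list` are assumed to be prepared elsewhere (e.g. by the dataset loader), so this is a hedged illustration rather than the exact code in `train_feat2gs.py`:

```python
# Hypothetical driver sketch.
init = InitMethod(args)                      # picks dust3r or mast3r
model = init.get_model()
scene = init.infer(pairs=pairs, model=model, train_img_list=train_img_list)
imgs, focals, poses, intrinsics, pts3d, masks = init.get_info(scene)

extractor = FeatureExtractor(images, args, method=args.method)
feats = extractor(scene=scene, pairs=pairs, train_img_list=train_img_list)
# feats: numpy array of per-pixel features, aligned with pts3d / masks
```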
utils/general_utils.py
ADDED
@@ -0,0 +1,133 @@
1 |
+
#
|
2 |
+
# Copyright (C) 2023, Inria
|
3 |
+
# GRAPHDECO research group, https://team.inria.fr/graphdeco
|
4 |
+
# All rights reserved.
|
5 |
+
#
|
6 |
+
# This software is free for non-commercial, research and evaluation use
|
7 |
+
# under the terms of the LICENSE.md file.
|
8 |
+
#
|
9 |
+
# For inquiries contact [email protected]
|
10 |
+
#
|
11 |
+
|
12 |
+
import torch
|
13 |
+
import sys
|
14 |
+
from datetime import datetime
|
15 |
+
import numpy as np
|
16 |
+
import random
|
17 |
+
|
18 |
+
def inverse_sigmoid(x):
|
19 |
+
return torch.log(x/(1-x))
|
20 |
+
|
21 |
+
def PILtoTorch(pil_image, resolution):
|
22 |
+
resized_image_PIL = pil_image.resize(resolution)
|
23 |
+
resized_image = torch.from_numpy(np.array(resized_image_PIL)) / 255.0
|
24 |
+
if len(resized_image.shape) == 3:
|
25 |
+
return resized_image.permute(2, 0, 1)
|
26 |
+
else:
|
27 |
+
return resized_image.unsqueeze(dim=-1).permute(2, 0, 1)
|
28 |
+
|
29 |
+
def get_expon_lr_func(
|
30 |
+
lr_init, lr_final, lr_delay_steps=0, lr_delay_mult=1.0, max_steps=1000000
|
31 |
+
):
|
32 |
+
"""
|
33 |
+
Copied from Plenoxels
|
34 |
+
|
35 |
+
Continuous learning rate decay function. Adapted from JaxNeRF
|
36 |
+
The returned rate is lr_init when step=0 and lr_final when step=max_steps, and
|
37 |
+
is log-linearly interpolated elsewhere (equivalent to exponential decay).
|
38 |
+
If lr_delay_steps>0 then the learning rate will be scaled by some smooth
|
39 |
+
function of lr_delay_mult, such that the initial learning rate is
|
40 |
+
lr_init*lr_delay_mult at the beginning of optimization but will be eased back
|
41 |
+
to the normal learning rate when steps>lr_delay_steps.
|
42 |
+
:param conf: config subtree 'lr' or similar
|
43 |
+
:param max_steps: int, the number of steps during optimization.
|
44 |
+
:return HoF which takes step as input
|
45 |
+
"""
|
46 |
+
|
47 |
+
def helper(step):
|
48 |
+
if step < 0 or (lr_init == 0.0 and lr_final == 0.0):
|
49 |
+
# Disable this parameter
|
50 |
+
return 0.0
|
51 |
+
if lr_delay_steps > 0:
|
52 |
+
# A kind of reverse cosine decay.
|
53 |
+
delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin(
|
54 |
+
0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1)
|
55 |
+
)
|
56 |
+
else:
|
57 |
+
delay_rate = 1.0
|
58 |
+
t = np.clip(step / max_steps, 0, 1)
|
59 |
+
log_lerp = np.exp(np.log(lr_init) * (1 - t) + np.log(lr_final) * t)
|
60 |
+
return delay_rate * log_lerp
|
61 |
+
|
62 |
+
return helper
|
63 |
+
|
64 |
+
def strip_lowerdiag(L):
|
65 |
+
uncertainty = torch.zeros((L.shape[0], 6), dtype=torch.float, device="cuda")
|
66 |
+
|
67 |
+
uncertainty[:, 0] = L[:, 0, 0]
|
68 |
+
uncertainty[:, 1] = L[:, 0, 1]
|
69 |
+
uncertainty[:, 2] = L[:, 0, 2]
|
70 |
+
uncertainty[:, 3] = L[:, 1, 1]
|
71 |
+
uncertainty[:, 4] = L[:, 1, 2]
|
72 |
+
uncertainty[:, 5] = L[:, 2, 2]
|
73 |
+
return uncertainty
|
74 |
+
|
75 |
+
def strip_symmetric(sym):
|
76 |
+
return strip_lowerdiag(sym)
|
77 |
+
|
78 |
+
def build_rotation(r):
|
79 |
+
norm = torch.sqrt(r[:,0]*r[:,0] + r[:,1]*r[:,1] + r[:,2]*r[:,2] + r[:,3]*r[:,3])
|
80 |
+
|
81 |
+
q = r / norm[:, None]
|
82 |
+
|
83 |
+
R = torch.zeros((q.size(0), 3, 3), device='cuda')
|
84 |
+
|
85 |
+
r = q[:, 0]
|
86 |
+
x = q[:, 1]
|
87 |
+
y = q[:, 2]
|
88 |
+
z = q[:, 3]
|
89 |
+
|
90 |
+
R[:, 0, 0] = 1 - 2 * (y*y + z*z)
|
91 |
+
R[:, 0, 1] = 2 * (x*y - r*z)
|
92 |
+
R[:, 0, 2] = 2 * (x*z + r*y)
|
93 |
+
R[:, 1, 0] = 2 * (x*y + r*z)
|
94 |
+
R[:, 1, 1] = 1 - 2 * (x*x + z*z)
|
95 |
+
R[:, 1, 2] = 2 * (y*z - r*x)
|
96 |
+
R[:, 2, 0] = 2 * (x*z - r*y)
|
97 |
+
R[:, 2, 1] = 2 * (y*z + r*x)
|
98 |
+
R[:, 2, 2] = 1 - 2 * (x*x + y*y)
|
99 |
+
return R
|
100 |
+
|
101 |
+
def build_scaling_rotation(s, r):
|
102 |
+
L = torch.zeros((s.shape[0], 3, 3), dtype=torch.float, device="cuda")
|
103 |
+
R = build_rotation(r)
|
104 |
+
|
105 |
+
L[:,0,0] = s[:,0]
|
106 |
+
L[:,1,1] = s[:,1]
|
107 |
+
L[:,2,2] = s[:,2]
|
108 |
+
|
109 |
+
L = R @ L
|
110 |
+
return L
|
111 |
+
|
112 |
+
def safe_state(silent):
|
113 |
+
old_f = sys.stdout
|
114 |
+
class F:
|
115 |
+
def __init__(self, silent):
|
116 |
+
self.silent = silent
|
117 |
+
|
118 |
+
def write(self, x):
|
119 |
+
if not self.silent:
|
120 |
+
if x.endswith("\n"):
|
121 |
+
old_f.write(x.replace("\n", " [{}]\n".format(str(datetime.now().strftime("%d/%m %H:%M:%S")))))
|
122 |
+
else:
|
123 |
+
old_f.write(x)
|
124 |
+
|
125 |
+
def flush(self):
|
126 |
+
old_f.flush()
|
127 |
+
|
128 |
+
sys.stdout = F(silent)
|
129 |
+
|
130 |
+
random.seed(0)
|
131 |
+
np.random.seed(0)
|
132 |
+
torch.manual_seed(0)
|
133 |
+
torch.cuda.set_device(torch.device("cuda:0"))
|
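The log-linear schedule returned by `get_expon_lr_func` is typically queried once per optimization step and written into the optimizer's param groups. A small sketch of that pattern (the parameter values here are illustrative only):

```python
# Illustrative use of the exponential-decay schedule.
lr_fn = get_expon_lr_func(lr_init=1.6e-4, lr_final=1.6e-6, max_steps=30_000)

for step in range(30_000):
    lr = lr_fn(step)          # lr_init at step 0, decays to lr_final at max_steps
    # for group in optimizer.param_groups:
    #     group["lr"] = lr
```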
utils/graphics_utils.py
ADDED
@@ -0,0 +1,210 @@
1 |
+
#
|
2 |
+
# Copyright (C) 2023, Inria
|
3 |
+
# GRAPHDECO research group, https://team.inria.fr/graphdeco
|
4 |
+
# All rights reserved.
|
5 |
+
#
|
6 |
+
# This software is free for non-commercial, research and evaluation use
|
7 |
+
# under the terms of the LICENSE.md file.
|
8 |
+
#
|
9 |
+
# For inquiries contact [email protected]
|
10 |
+
#
|
11 |
+
|
12 |
+
import torch
|
13 |
+
import math
|
14 |
+
import numpy as np
|
15 |
+
from typing import NamedTuple
|
16 |
+
|
17 |
+
import torch.nn.functional as F
|
18 |
+
from torch import Tensor
|
19 |
+
|
20 |
+
class BasicPointCloud(NamedTuple):
|
21 |
+
points : np.array
|
22 |
+
colors : np.array
|
23 |
+
normals : np.array
|
24 |
+
features: np.array
|
25 |
+
|
26 |
+
def geom_transform_points(points, transf_matrix):
|
27 |
+
P, _ = points.shape
|
28 |
+
ones = torch.ones(P, 1, dtype=points.dtype, device=points.device)
|
29 |
+
points_hom = torch.cat([points, ones], dim=1)
|
30 |
+
points_out = torch.matmul(points_hom, transf_matrix.unsqueeze(0))
|
31 |
+
|
32 |
+
denom = points_out[..., 3:] + 0.0000001
|
33 |
+
return (points_out[..., :3] / denom).squeeze(dim=0)
|
34 |
+
|
35 |
+
def getWorld2View(R, t):
|
36 |
+
Rt = np.zeros((4, 4))
|
37 |
+
Rt[:3, :3] = R.transpose()
|
38 |
+
Rt[:3, 3] = t
|
39 |
+
Rt[3, 3] = 1.0
|
40 |
+
return np.float32(Rt)
|
41 |
+
|
42 |
+
def getWorld2View2(R, t, translate=np.array([.0, .0, .0]), scale=1.0):
|
43 |
+
Rt = np.zeros((4, 4))
|
44 |
+
Rt[:3, :3] = R.transpose()
|
45 |
+
Rt[:3, 3] = t
|
46 |
+
Rt[3, 3] = 1.0
|
47 |
+
|
48 |
+
C2W = np.linalg.inv(Rt)
|
49 |
+
cam_center = C2W[:3, 3]
|
50 |
+
cam_center = (cam_center + translate) * scale
|
51 |
+
C2W[:3, 3] = cam_center
|
52 |
+
Rt = np.linalg.inv(C2W)
|
53 |
+
return np.float32(Rt)
|
54 |
+
|
55 |
+
def getWorld2View2_torch(R, t, translate=torch.tensor([0.0, 0.0, 0.0]), scale=1.0):
|
56 |
+
translate = torch.tensor(translate, dtype=torch.float32)
|
57 |
+
|
58 |
+
# Initialize the transformation matrix
|
59 |
+
Rt = torch.zeros((4, 4), dtype=torch.float32)
|
60 |
+
Rt[:3, :3] = R.t() # Transpose of R
|
61 |
+
Rt[:3, 3] = t
|
62 |
+
Rt[3, 3] = 1.0
|
63 |
+
|
64 |
+
# Compute the inverse to get the camera-to-world transformation
|
65 |
+
C2W = torch.linalg.inv(Rt)
|
66 |
+
cam_center = C2W[:3, 3]
|
67 |
+
cam_center = (cam_center + translate) * scale
|
68 |
+
C2W[:3, 3] = cam_center
|
69 |
+
|
70 |
+
# Invert again to get the world-to-view transformation
|
71 |
+
Rt = torch.linalg.inv(C2W)
|
72 |
+
|
73 |
+
return Rt
|
74 |
+
|
75 |
+
def getProjectionMatrix(znear, zfar, fovX, fovY):
|
76 |
+
tanHalfFovY = math.tan((fovY / 2))
|
77 |
+
tanHalfFovX = math.tan((fovX / 2))
|
78 |
+
|
79 |
+
top = tanHalfFovY * znear
|
80 |
+
bottom = -top
|
81 |
+
right = tanHalfFovX * znear
|
82 |
+
left = -right
|
83 |
+
|
84 |
+
P = torch.zeros(4, 4)
|
85 |
+
|
86 |
+
z_sign = 1.0
|
87 |
+
|
88 |
+
P[0, 0] = 2.0 * znear / (right - left)
|
89 |
+
P[1, 1] = 2.0 * znear / (top - bottom)
|
90 |
+
P[0, 2] = (right + left) / (right - left)
|
91 |
+
P[1, 2] = (top + bottom) / (top - bottom)
|
92 |
+
P[3, 2] = z_sign
|
93 |
+
P[2, 2] = z_sign * zfar / (zfar - znear)
|
94 |
+
P[2, 3] = -(zfar * znear) / (zfar - znear)
|
95 |
+
return P
|
96 |
+
|
97 |
+
def fov2focal(fov, pixels):
|
98 |
+
return pixels / (2 * math.tan(fov / 2))
|
99 |
+
|
100 |
+
def focal2fov(focal, pixels):
|
101 |
+
return 2*math.atan(pixels/(2*focal))
|
102 |
+
|
103 |
+
def resize_render(view, size=None):
|
104 |
+
image_size = size if size is not None else max(view.image_width, view.image_height)
|
105 |
+
view.original_image = torch.zeros((3, image_size, image_size), device=view.original_image.device)
|
106 |
+
focal_length_x = fov2focal(view.FoVx, view.image_width)
|
107 |
+
focal_length_y = fov2focal(view.FoVy, view.image_height)
|
108 |
+
view.image_width = image_size
|
109 |
+
view.image_height = image_size
|
110 |
+
view.FoVx = focal2fov(focal_length_x, image_size)
|
111 |
+
view.FoVy = focal2fov(focal_length_y, image_size)
|
112 |
+
return view
|
113 |
+
|
114 |
+
def make_video_divisble(
|
115 |
+
video: torch.Tensor | np.ndarray, block_size=16
|
116 |
+
) -> torch.Tensor | np.ndarray:
|
117 |
+
H, W = video.shape[1:3]
|
118 |
+
H_new = H - H % block_size
|
119 |
+
W_new = W - W % block_size
|
120 |
+
return video[:, :H_new, :W_new]
|
121 |
+
|
122 |
+
|
123 |
+
def depth_to_points(
|
124 |
+
depths: Tensor, camtoworlds: Tensor, Ks: Tensor, z_depth: bool = True
|
125 |
+
) -> Tensor:
|
126 |
+
"""Convert depth maps to 3D points
|
127 |
+
|
128 |
+
Args:
|
129 |
+
depths: Depth maps [..., H, W, 1]
|
130 |
+
camtoworlds: Camera-to-world transformation matrices [..., 4, 4]
|
131 |
+
Ks: Camera intrinsics [..., 3, 3]
|
132 |
+
z_depth: Whether the depth is in z-depth (True) or ray depth (False)
|
133 |
+
|
134 |
+
Returns:
|
135 |
+
points: 3D points in the world coordinate system [..., H, W, 3]
|
136 |
+
"""
|
137 |
+
assert depths.shape[-1] == 1, f"Invalid depth shape: {depths.shape}"
|
138 |
+
assert camtoworlds.shape[-2:] == (
|
139 |
+
4,
|
140 |
+
4,
|
141 |
+
), f"Invalid viewmats shape: {camtoworlds.shape}"
|
142 |
+
assert Ks.shape[-2:] == (3, 3), f"Invalid Ks shape: {Ks.shape}"
|
143 |
+
assert (
|
144 |
+
depths.shape[:-3] == camtoworlds.shape[:-2] == Ks.shape[:-2]
|
145 |
+
), f"Shape mismatch! depths: {depths.shape}, viewmats: {camtoworlds.shape}, Ks: {Ks.shape}"
|
146 |
+
|
147 |
+
device = depths.device
|
148 |
+
height, width = depths.shape[-3:-1]
|
149 |
+
|
150 |
+
x, y = torch.meshgrid(
|
151 |
+
torch.arange(width, device=device),
|
152 |
+
torch.arange(height, device=device),
|
153 |
+
indexing="xy",
|
154 |
+
) # [H, W]
|
155 |
+
|
156 |
+
fx = Ks[..., 0, 0] # [...]
|
157 |
+
fy = Ks[..., 1, 1] # [...]
|
158 |
+
cx = Ks[..., 0, 2] # [...]
|
159 |
+
cy = Ks[..., 1, 2] # [...]
|
160 |
+
|
161 |
+
# camera directions in camera coordinates
|
162 |
+
camera_dirs = F.pad(
|
163 |
+
torch.stack(
|
164 |
+
[
|
165 |
+
(x - cx[..., None, None] + 0.5) / fx[..., None, None],
|
166 |
+
(y - cy[..., None, None] + 0.5) / fy[..., None, None],
|
167 |
+
],
|
168 |
+
dim=-1,
|
169 |
+
),
|
170 |
+
(0, 1),
|
171 |
+
value=1.0,
|
172 |
+
) # [..., H, W, 3]
|
173 |
+
|
174 |
+
# ray directions in world coordinates
|
175 |
+
directions = torch.einsum(
|
176 |
+
"...ij,...hwj->...hwi", camtoworlds[..., :3, :3], camera_dirs
|
177 |
+
) # [..., H, W, 3]
|
178 |
+
origins = camtoworlds[..., :3, -1] # [..., 3]
|
179 |
+
|
180 |
+
if not z_depth:
|
181 |
+
directions = F.normalize(directions, dim=-1)
|
182 |
+
|
183 |
+
points = origins[..., None, None, :] + depths * directions
|
184 |
+
return points
|
185 |
+
|
186 |
+
|
187 |
+
def depth_to_normal(
|
188 |
+
depths: Tensor, camtoworlds: Tensor, Ks: Tensor, z_depth: bool = True
|
189 |
+
) -> Tensor:
|
190 |
+
"""Convert depth maps to surface normals
|
191 |
+
|
192 |
+
Args:
|
193 |
+
depths: Depth maps [..., H, W, 1]
|
194 |
+
camtoworlds: Camera-to-world transformation matrices [..., 4, 4]
|
195 |
+
Ks: Camera intrinsics [..., 3, 3]
|
196 |
+
z_depth: Whether the depth is in z-depth (True) or ray depth (False)
|
197 |
+
|
198 |
+
Returns:
|
199 |
+
normals: Surface normals in the world coordinate system [..., H, W, 3]
|
200 |
+
"""
|
201 |
+
points = depth_to_points(depths, camtoworlds, Ks, z_depth=z_depth) # [..., H, W, 3]
|
202 |
+
dx = torch.cat(
|
203 |
+
[points[..., 2:, 1:-1, :] - points[..., :-2, 1:-1, :]], dim=-3
|
204 |
+
) # [..., H-2, W-2, 3]
|
205 |
+
dy = torch.cat(
|
206 |
+
[points[..., 1:-1, 2:, :] - points[..., 1:-1, :-2, :]], dim=-2
|
207 |
+
) # [..., H-2, W-2, 3]
|
208 |
+
normals = F.normalize(torch.cross(dx, dy, dim=-1), dim=-1) # [..., H-2, W-2, 3]
|
209 |
+
normals = F.pad(normals, (0, 0, 1, 1, 1, 1), value=0.0) # [..., H, W, 3]
|
210 |
+
return normals
|
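`depth_to_points` and `depth_to_normal` are shape-polymorphic over leading batch dimensions; the assertions only constrain the trailing dimensions. A quick sketch with dummy tensors, just to show the expected shapes (the values are meaningless):

```python
import torch

B, H, W = 2, 64, 96
depths = torch.rand(B, H, W, 1) + 0.5                  # [B, H, W, 1]
camtoworlds = torch.eye(4).expand(B, 4, 4).clone()     # [B, 4, 4]
Ks = torch.tensor([[80.0, 0.0, W / 2],
                   [0.0, 80.0, H / 2],
                   [0.0, 0.0, 1.0]]).expand(B, 3, 3).clone()

pts = depth_to_points(depths, camtoworlds, Ks)         # [B, H, W, 3]
nrm = depth_to_normal(depths, camtoworlds, Ks)         # [B, H, W, 3]
```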
utils/image_utils.py
ADDED
@@ -0,0 +1,118 @@
1 |
+
#
|
2 |
+
# Copyright (C) 2023, Inria
|
3 |
+
# GRAPHDECO research group, https://team.inria.fr/graphdeco
|
4 |
+
# All rights reserved.
|
5 |
+
#
|
6 |
+
# This software is free for non-commercial, research and evaluation use
|
7 |
+
# under the terms of the LICENSE.md file.
|
8 |
+
#
|
9 |
+
# For inquiries contact [email protected]
|
10 |
+
#
|
11 |
+
|
12 |
+
import torch
|
13 |
+
|
14 |
+
def mse(img1, img2):
|
15 |
+
return (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True)
|
16 |
+
|
17 |
+
def psnr(img1, img2):
|
18 |
+
mse = (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True)
|
19 |
+
return 20 * torch.log10(1.0 / torch.sqrt(mse))
|
20 |
+
|
21 |
+
def masked_psnr(img1, img2, mask):
|
22 |
+
mse = ((((img1 - img2)) ** 2) * mask).sum() / (3. * mask.sum())
|
23 |
+
return 20 * torch.log10(1.0 / torch.sqrt(mse))
|
24 |
+
|
25 |
+
|
26 |
+
def accuracy_torch(gt_points, rec_points, gt_normals=None, rec_normals=None, batch_size=5000):
|
27 |
+
n_points = rec_points.shape[0]
|
28 |
+
all_distances = []
|
29 |
+
all_indices = []
|
30 |
+
|
31 |
+
for i in range(0, n_points, batch_size):
|
32 |
+
end_idx = min(i + batch_size, n_points)
|
33 |
+
batch_points = rec_points[i:end_idx]
|
34 |
+
|
35 |
+
distances = torch.cdist(batch_points, gt_points) # (batch_size, M)
|
36 |
+
batch_distances, batch_indices = torch.min(distances, dim=1) # (batch_size,)
|
37 |
+
|
38 |
+
all_distances.append(batch_distances)
|
39 |
+
all_indices.append(batch_indices)
|
40 |
+
|
41 |
+
distances = torch.cat(all_distances)
|
42 |
+
indices = torch.cat(all_indices)
|
43 |
+
|
44 |
+
acc = torch.mean(distances)
|
45 |
+
acc_median = torch.median(distances)
|
46 |
+
|
47 |
+
if gt_normals is not None and rec_normals is not None:
|
48 |
+
normal_dot = torch.sum(gt_normals[indices] * rec_normals, dim=-1)
|
49 |
+
normal_dot = torch.abs(normal_dot)
|
50 |
+
return acc, acc_median, torch.mean(normal_dot), torch.median(normal_dot)
|
51 |
+
|
52 |
+
return acc, acc_median
|
53 |
+
|
54 |
+
def completion_torch(gt_points, rec_points, gt_normals=None, rec_normals=None, batch_size=5000):
|
55 |
+
|
56 |
+
n_points = gt_points.shape[0]
|
57 |
+
all_distances = []
|
58 |
+
all_indices = []
|
59 |
+
|
60 |
+
for i in range(0, n_points, batch_size):
|
61 |
+
end_idx = min(i + batch_size, n_points)
|
62 |
+
batch_points = gt_points[i:end_idx]
|
63 |
+
|
64 |
+
distances = torch.cdist(batch_points, rec_points) # (batch_size, M)
|
65 |
+
batch_distances, batch_indices = torch.min(distances, dim=1) # (batch_size,)
|
66 |
+
|
67 |
+
all_distances.append(batch_distances)
|
68 |
+
all_indices.append(batch_indices)
|
69 |
+
|
70 |
+
distances = torch.cat(all_distances)
|
71 |
+
indices = torch.cat(all_indices)
|
72 |
+
|
73 |
+
comp = torch.mean(distances)
|
74 |
+
comp_median = torch.median(distances)
|
75 |
+
|
76 |
+
if gt_normals is not None and rec_normals is not None:
|
77 |
+
normal_dot = torch.sum(gt_normals * rec_normals[indices], dim=-1)
|
78 |
+
normal_dot = torch.abs(normal_dot)
|
79 |
+
return comp, comp_median, torch.mean(normal_dot), torch.median(normal_dot)
|
80 |
+
|
81 |
+
return comp, comp_median
|
82 |
+
|
83 |
+
def accuracy_per_point(gt_points, rec_points, batch_size=5000):
|
84 |
+
n_points = rec_points.shape[0]
|
85 |
+
all_distances = []
|
86 |
+
all_indices = []
|
87 |
+
|
88 |
+
for i in range(0, n_points, batch_size):
|
89 |
+
end_idx = min(i + batch_size, n_points)
|
90 |
+
batch_points = rec_points[i:end_idx]
|
91 |
+
|
92 |
+
distances = torch.cdist(batch_points, gt_points) # (batch_size, M)
|
93 |
+
batch_distances, batch_indices = torch.min(distances, dim=1) # (batch_size,)
|
94 |
+
|
95 |
+
all_distances.append(batch_distances)
|
96 |
+
all_indices.append(batch_indices)
|
97 |
+
|
98 |
+
distances = torch.cat(all_distances)
|
99 |
+
return distances
|
100 |
+
|
101 |
+
def completion_per_point(gt_points, rec_points, batch_size=5000):
|
102 |
+
|
103 |
+
n_points = gt_points.shape[0]
|
104 |
+
all_distances = []
|
105 |
+
all_indices = []
|
106 |
+
|
107 |
+
for i in range(0, n_points, batch_size):
|
108 |
+
end_idx = min(i + batch_size, n_points)
|
109 |
+
batch_points = gt_points[i:end_idx]
|
110 |
+
|
111 |
+
distances = torch.cdist(batch_points, rec_points) # (batch_size, M)
|
112 |
+
batch_distances, batch_indices = torch.min(distances, dim=1) # (batch_size,)
|
113 |
+
|
114 |
+
all_distances.append(batch_distances)
|
115 |
+
all_indices.append(batch_indices)
|
116 |
+
|
117 |
+
distances = torch.cat(all_distances)
|
118 |
+
return distances
|
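`accuracy_torch` and `completion_torch` together form a chunked Chamfer-style evaluation: accuracy measures reconstruction-to-GT distances, completion the reverse, each reported as mean and median. A usage sketch with random point clouds (purely illustrative; the 0.5 weighting is one common way to combine the two means, not something read from this repo):

```python
import torch

gt_points = torch.rand(20_000, 3)
rec_points = torch.rand(15_000, 3)

acc, acc_med = accuracy_torch(gt_points, rec_points)       # rec -> gt distances
comp, comp_med = completion_torch(gt_points, rec_points)   # gt -> rec distances
chamfer = 0.5 * (acc + comp)
```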
utils/loss_utils.py
ADDED
@@ -0,0 +1,247 @@
1 |
+
#
|
2 |
+
# Copyright (C) 2023, Inria
|
3 |
+
# GRAPHDECO research group, https://team.inria.fr/graphdeco
|
4 |
+
# All rights reserved.
|
5 |
+
#
|
6 |
+
# This software is free for non-commercial, research and evaluation use
|
7 |
+
# under the terms of the LICENSE.md file.
|
8 |
+
#
|
9 |
+
# For inquiries contact [email protected]
|
10 |
+
#
|
11 |
+
|
12 |
+
import torch
|
13 |
+
import torch.nn.functional as F
|
14 |
+
from torch.autograd import Variable
|
15 |
+
from math import exp
|
16 |
+
import einops
|
17 |
+
|
18 |
+
def l1_loss(network_output, gt):
|
19 |
+
return torch.abs((network_output - gt)).mean()
|
20 |
+
|
21 |
+
def l2_loss(network_output, gt):
|
22 |
+
return ((network_output - gt) ** 2).mean()
|
23 |
+
|
24 |
+
def gaussian(window_size, sigma):
|
25 |
+
gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])
|
26 |
+
return gauss / gauss.sum()
|
27 |
+
|
28 |
+
def create_window(window_size, channel):
|
29 |
+
_1D_window = gaussian(window_size, 1.5).unsqueeze(1)
|
30 |
+
_2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
|
31 |
+
window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
|
32 |
+
return window
|
33 |
+
|
34 |
+
def masked_ssim(img1, img2, mask):
|
35 |
+
ssim_map = ssim(img1, img2, get_ssim_map=True)
|
36 |
+
return (ssim_map * mask).sum() / (3. * mask.sum())
|
37 |
+
|
38 |
+
|
39 |
+
def ssim(img1, img2, window_size=11, size_average=True, get_ssim_map=False):
|
40 |
+
channel = img1.size(-3)
|
41 |
+
window = create_window(window_size, channel)
|
42 |
+
|
43 |
+
if img1.is_cuda:
|
44 |
+
window = window.cuda(img1.get_device())
|
45 |
+
window = window.type_as(img1)
|
46 |
+
|
47 |
+
return _ssim(img1, img2, window, window_size, channel, size_average, get_ssim_map)
|
48 |
+
|
49 |
+
def _ssim(img1, img2, window, window_size, channel, size_average=True, get_ssim_map=False):
|
50 |
+
mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
|
51 |
+
mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
|
52 |
+
|
53 |
+
mu1_sq = mu1.pow(2)
|
54 |
+
mu2_sq = mu2.pow(2)
|
55 |
+
mu1_mu2 = mu1 * mu2
|
56 |
+
|
57 |
+
sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
|
58 |
+
sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
|
59 |
+
sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2
|
60 |
+
|
61 |
+
C1 = 0.01 ** 2
|
62 |
+
C2 = 0.03 ** 2
|
63 |
+
|
64 |
+
ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
|
65 |
+
|
66 |
+
if get_ssim_map:
|
67 |
+
return ssim_map
|
68 |
+
elif size_average:
|
69 |
+
return ssim_map.mean()
|
70 |
+
else:
|
71 |
+
return ssim_map.mean(1).mean(1).mean(1)
|
72 |
+
|
73 |
+
|
74 |
+
# --- Projections ---
|
75 |
+
|
76 |
+
def homogenize_points(points):
|
77 |
+
"""Append a '1' along the final dimension of the tensor (i.e. convert xyz->xyz1)"""
|
78 |
+
return torch.cat([points, torch.ones_like(points[..., :1])], dim=-1)
|
79 |
+
|
80 |
+
|
81 |
+
def normalize_homogenous_points(points):
|
82 |
+
"""Normalize the point vectors"""
|
83 |
+
return points / points[..., -1:]
|
84 |
+
|
85 |
+
|
86 |
+
def pixel_space_to_camera_space(pixel_space_points, depth, intrinsics):
|
87 |
+
"""
|
88 |
+
Convert pixel space points to camera space points.
|
89 |
+
|
90 |
+
Args:
|
91 |
+
pixel_space_points (torch.Tensor): Pixel space points with shape (h, w, 2)
|
92 |
+
depth (torch.Tensor): Depth map with shape (b, v, h, w, 1)
|
93 |
+
intrinsics (torch.Tensor): Camera intrinsics with shape (b, v, 3, 3)
|
94 |
+
|
95 |
+
Returns:
|
96 |
+
torch.Tensor: Camera space points with shape (b, v, h, w, 3).
|
97 |
+
"""
|
98 |
+
pixel_space_points = homogenize_points(pixel_space_points)
|
99 |
+
camera_space_points = torch.einsum('b v i j , h w j -> b v h w i', intrinsics.inverse(), pixel_space_points)
|
100 |
+
camera_space_points = camera_space_points * depth
|
101 |
+
return camera_space_points
|
102 |
+
|
103 |
+
|
104 |
+
def camera_space_to_world_space(camera_space_points, c2w):
|
105 |
+
"""
|
106 |
+
Convert camera space points to world space points.
|
107 |
+
|
108 |
+
Args:
|
109 |
+
camera_space_points (torch.Tensor): Camera space points with shape (b, v, h, w, 3)
|
110 |
+
c2w (torch.Tensor): Camera to world extrinsics matrix with shape (b, v, 4, 4)
|
111 |
+
|
112 |
+
Returns:
|
113 |
+
torch.Tensor: World space points with shape (b, v, h, w, 3).
|
114 |
+
"""
|
115 |
+
camera_space_points = homogenize_points(camera_space_points)
|
116 |
+
world_space_points = torch.einsum('b v i j , b v h w j -> b v h w i', c2w, camera_space_points)
|
117 |
+
return world_space_points[..., :3]
|
118 |
+
|
119 |
+
|
120 |
+
def camera_space_to_pixel_space(camera_space_points, intrinsics):
|
121 |
+
"""
|
122 |
+
Convert camera space points to pixel space points.
|
123 |
+
|
124 |
+
Args:
|
125 |
+
camera_space_points (torch.Tensor): Camera space points with shape (b, v1, v2, h, w, 3)
|
126 |
+
intrinsics (torch.Tensor): Camera intrinsics matrix with shape (b, v2, 3, 3)
|
127 |
+
|
128 |
+
Returns:
|
129 |
+
torch.Tensor: Pixel space points with shape (b, v1, v2, h, w, 2).
|
130 |
+
"""
|
131 |
+
camera_space_points = normalize_homogenous_points(camera_space_points)
|
132 |
+
pixel_space_points = torch.einsum('b u i j , b v u h w j -> b v u h w i', intrinsics, camera_space_points)
|
133 |
+
return pixel_space_points[..., :2]
|
134 |
+
|
135 |
+
|
136 |
+
def world_space_to_camera_space(world_space_points, c2w):
|
137 |
+
"""
|
138 |
+
Convert world space points to camera space points.
|
139 |
+
|
140 |
+
Args:
|
141 |
+
world_space_points (torch.Tensor): World space points with shape (b, v1, h, w, 3)
|
142 |
+
c2w (torch.Tensor): Camera to world extrinsics matrix with shape (b, v2, 4, 4)
|
143 |
+
|
144 |
+
Returns:
|
145 |
+
torch.Tensor: Camera space points with shape (b, v1, v2, h, w, 3).
|
146 |
+
"""
|
147 |
+
world_space_points = homogenize_points(world_space_points)
|
148 |
+
camera_space_points = torch.einsum('b u i j , b v h w j -> b v u h w i', c2w.inverse(), world_space_points)
|
149 |
+
return camera_space_points[..., :3]
|
150 |
+
|
151 |
+
|
152 |
+
def unproject_depth(depth, intrinsics, c2w):
|
153 |
+
"""
|
154 |
+
Turn the depth map into a 3D point cloud in world space
|
155 |
+
|
156 |
+
Args:
|
157 |
+
depth: (b, v, h, w, 1)
|
158 |
+
intrinsics: (b, v, 3, 3)
|
159 |
+
c2w: (b, v, 4, 4)
|
160 |
+
|
161 |
+
Returns:
|
162 |
+
torch.Tensor: World space points with shape (b, v, h, w, 3).
|
163 |
+
"""
|
164 |
+
|
165 |
+
# Compute indices of pixels
|
166 |
+
h, w = depth.shape[-3], depth.shape[-2]
|
167 |
+
x_grid, y_grid = torch.meshgrid(
|
168 |
+
torch.arange(w, device=depth.device, dtype=torch.float32),
|
169 |
+
torch.arange(h, device=depth.device, dtype=torch.float32),
|
170 |
+
indexing='xy'
|
171 |
+
) # (h, w), (h, w)
|
172 |
+
|
173 |
+
# Compute coordinates of pixels in camera space
|
174 |
+
pixel_space_points = torch.stack((x_grid, y_grid), dim=-1) # (..., h, w, 2)
|
175 |
+
camera_points = pixel_space_to_camera_space(pixel_space_points, depth, intrinsics) # (..., h, w, 3)
|
176 |
+
|
177 |
+
# Convert points to world space
|
178 |
+
world_points = camera_space_to_world_space(camera_points, c2w) # (..., h, w, 3)
|
179 |
+
|
180 |
+
return world_points
|
181 |
+
|
182 |
+
|
183 |
+
@torch.no_grad()
|
184 |
+
def calculate_in_frustum_mask(depth_1, intrinsics_1, c2w_1, depth_2, intrinsics_2, c2w_2, atol=1e-2):
|
185 |
+
"""
|
186 |
+
A function that takes in the depth, intrinsics and c2w matrices of two sets
|
187 |
+
of views, and then works out which of the pixels in the first set of views
|
188 |
+
has a direct corresponding pixel in any of views in the second set
|
189 |
+
|
190 |
+
Args:
|
191 |
+
depth_1: (b, v1, h, w)
|
192 |
+
intrinsics_1: (b, v1, 3, 3)
|
193 |
+
c2w_1: (b, v1, 4, 4)
|
194 |
+
depth_2: (b, v2, h, w)
|
195 |
+
intrinsics_2: (b, v2, 3, 3)
|
196 |
+
c2w_2: (b, v2, 4, 4)
|
197 |
+
|
198 |
+
Returns:
|
199 |
+
torch.Tensor: Camera space points with shape (b, v1, h, w).
|
200 |
+
"""
|
201 |
+
|
202 |
+
_, v1, h, w = depth_1.shape
|
203 |
+
_, v2, _, _ = depth_2.shape
|
204 |
+
|
205 |
+
# Unproject the depth to get the 3D points in world space
|
206 |
+
points_3d = unproject_depth(depth_1[..., None], intrinsics_1, c2w_1) # (b, v1, h, w, 3)
|
207 |
+
|
208 |
+
# Project the 3D points into the pixel space of all the second views simultaneously
|
209 |
+
camera_points = world_space_to_camera_space(points_3d, c2w_2) # (b, v1, v2, h, w, 3)
|
210 |
+
points_2d = camera_space_to_pixel_space(camera_points, intrinsics_2) # (b, v1, v2, h, w, 2)
|
211 |
+
|
212 |
+
# Calculate the depth of each point
|
213 |
+
rendered_depth = camera_points[..., 2] # (b, v1, v2, h, w)
|
214 |
+
|
215 |
+
# We use three conditions to determine if a point should be masked
|
216 |
+
|
217 |
+
# Condition 1: Check if the points are in the frustum of any of the v2 views
|
218 |
+
in_frustum_mask = (
|
219 |
+
(points_2d[..., 0] > 0) &
|
220 |
+
(points_2d[..., 0] < w) &
|
221 |
+
(points_2d[..., 1] > 0) &
|
222 |
+
(points_2d[..., 1] < h)
|
223 |
+
) # (b, v1, v2, h, w)
|
224 |
+
in_frustum_mask = in_frustum_mask.any(dim=-3) # (b, v1, h, w)
|
225 |
+
|
226 |
+
# Condition 2: Check if the points have non-zero (i.e. valid) depth in the input view
|
227 |
+
non_zero_depth = depth_1 > 1e-6
|
228 |
+
|
229 |
+
# Condition 3: Check if the points have matching depth to any of the v2
|
230 |
+
# views torch.nn.functional.grid_sample expects the input coordinates to
|
231 |
+
# be normalized to the range [-1, 1], so we normalize first
|
232 |
+
points_2d[..., 0] /= w
|
233 |
+
points_2d[..., 1] /= h
|
234 |
+
points_2d = points_2d * 2 - 1
|
235 |
+
matching_depth = torch.ones_like(rendered_depth, dtype=torch.bool)
|
236 |
+
for b in range(depth_1.shape[0]):
|
237 |
+
for i in range(v1):
|
238 |
+
for j in range(v2):
|
239 |
+
depth = einops.rearrange(depth_2[b, j], 'h w -> 1 1 h w')
|
240 |
+
coords = einops.rearrange(points_2d[b, i, j], 'h w c -> 1 h w c')
|
241 |
+
sampled_depths = torch.nn.functional.grid_sample(depth, coords, align_corners=False)[0, 0]
|
242 |
+
matching_depth[b, i, j] = torch.isclose(rendered_depth[b, i, j], sampled_depths, atol=atol)
|
243 |
+
|
244 |
+
matching_depth = matching_depth.any(dim=-3) # (..., v1, h, w)
|
245 |
+
|
246 |
+
mask = in_frustum_mask & non_zero_depth & matching_depth
|
247 |
+
return mask
|
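The `l1_loss` and `ssim` helpers above are the building blocks of the usual 3DGS photometric objective, a weighted mix of L1 and D-SSIM. A sketch of that combination; the 0.2 weight is the conventional value, stated here as an assumption rather than read from this repo's training script:

```python
# Typical 3DGS-style photometric loss built from the helpers above.
lambda_dssim = 0.2  # assumed conventional weight

def photometric_loss(rendered, gt):
    # rendered, gt: (3, H, W) tensors in [0, 1]
    l1 = l1_loss(rendered, gt)
    dssim = 1.0 - ssim(rendered, gt)
    return (1.0 - lambda_dssim) * l1 + lambda_dssim * dssim
```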
utils/pose_utils.py
ADDED
@@ -0,0 +1,570 @@
1 |
+
import math
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
import torch.nn.functional as F
|
5 |
+
from typing import Tuple
|
6 |
+
from utils.stepfun import sample_np, sample
|
7 |
+
import scipy
|
8 |
+
|
9 |
+
|
10 |
+
def quad2rotation(q):
|
11 |
+
"""
|
12 |
+
Convert quaternions to rotation matrices in batch. All operations are in PyTorch, so gradients can propagate.
|
13 |
+
|
14 |
+
Args:
|
15 |
+
quad (tensor, batch_size*4): quaternion.
|
16 |
+
|
17 |
+
Returns:
|
18 |
+
rot_mat (tensor, batch_size*3*3): rotation.
|
19 |
+
"""
|
20 |
+
# bs = quad.shape[0]
|
21 |
+
# qr, qi, qj, qk = quad[:, 0], quad[:, 1], quad[:, 2], quad[:, 3]
|
22 |
+
# two_s = 2.0 / (quad * quad).sum(-1)
|
23 |
+
# rot_mat = torch.zeros(bs, 3, 3).to(quad.get_device())
|
24 |
+
# rot_mat[:, 0, 0] = 1 - two_s * (qj**2 + qk**2)
|
25 |
+
# rot_mat[:, 0, 1] = two_s * (qi * qj - qk * qr)
|
26 |
+
# rot_mat[:, 0, 2] = two_s * (qi * qk + qj * qr)
|
27 |
+
# rot_mat[:, 1, 0] = two_s * (qi * qj + qk * qr)
|
28 |
+
# rot_mat[:, 1, 1] = 1 - two_s * (qi**2 + qk**2)
|
29 |
+
# rot_mat[:, 1, 2] = two_s * (qj * qk - qi * qr)
|
30 |
+
# rot_mat[:, 2, 0] = two_s * (qi * qk - qj * qr)
|
31 |
+
# rot_mat[:, 2, 1] = two_s * (qj * qk + qi * qr)
|
32 |
+
# rot_mat[:, 2, 2] = 1 - two_s * (qi**2 + qj**2)
|
33 |
+
# return rot_mat
|
34 |
+
if not isinstance(q, torch.Tensor):
|
35 |
+
q = torch.tensor(q).cuda()
|
36 |
+
|
37 |
+
norm = torch.sqrt(
|
38 |
+
q[:, 0] * q[:, 0] + q[:, 1] * q[:, 1] + q[:, 2] * q[:, 2] + q[:, 3] * q[:, 3]
|
39 |
+
)
|
40 |
+
q = q / norm[:, None]
|
41 |
+
rot = torch.zeros((q.size(0), 3, 3)).to(q)
|
42 |
+
r = q[:, 0]
|
43 |
+
x = q[:, 1]
|
44 |
+
y = q[:, 2]
|
45 |
+
z = q[:, 3]
|
46 |
+
rot[:, 0, 0] = 1 - 2 * (y * y + z * z)
|
47 |
+
rot[:, 0, 1] = 2 * (x * y - r * z)
|
48 |
+
rot[:, 0, 2] = 2 * (x * z + r * y)
|
49 |
+
rot[:, 1, 0] = 2 * (x * y + r * z)
|
50 |
+
rot[:, 1, 1] = 1 - 2 * (x * x + z * z)
|
51 |
+
rot[:, 1, 2] = 2 * (y * z - r * x)
|
52 |
+
rot[:, 2, 0] = 2 * (x * z - r * y)
|
53 |
+
rot[:, 2, 1] = 2 * (y * z + r * x)
|
54 |
+
rot[:, 2, 2] = 1 - 2 * (x * x + y * y)
|
55 |
+
return rot
|
56 |
+
|
57 |
+
def get_camera_from_tensor(inputs):
|
58 |
+
"""
|
59 |
+
Convert quaternion and translation to transformation matrix.
|
60 |
+
|
61 |
+
"""
|
62 |
+
if not isinstance(inputs, torch.Tensor):
|
63 |
+
inputs = torch.tensor(inputs).cuda()
|
64 |
+
|
65 |
+
N = len(inputs.shape)
|
66 |
+
if N == 1:
|
67 |
+
inputs = inputs.unsqueeze(0)
|
68 |
+
# quad, T = inputs[:, :4], inputs[:, 4:]
|
69 |
+
# # normalize quad
|
70 |
+
# quad = F.normalize(quad)
|
71 |
+
# R = quad2rotation(quad)
|
72 |
+
# RT = torch.cat([R, T[:, :, None]], 2)
|
73 |
+
# # Add homogenous row
|
74 |
+
# homogenous_row = torch.tensor([0, 0, 0, 1]).cuda()
|
75 |
+
# RT = torch.cat([RT, homogenous_row[None, None, :].repeat(N, 1, 1)], 1)
|
76 |
+
# if N == 1:
|
77 |
+
# RT = RT[0]
|
78 |
+
# return RT
|
79 |
+
|
80 |
+
quad, T = inputs[:, :4], inputs[:, 4:]
|
81 |
+
w2c = torch.eye(4).to(inputs).float()
|
82 |
+
w2c[:3, :3] = quad2rotation(quad)
|
83 |
+
w2c[:3, 3] = T
|
84 |
+
return w2c
|
85 |
+
|
86 |
+
def quadmultiply(q1, q2):
|
87 |
+
"""
|
88 |
+
Multiply two quaternions together using quaternion arithmetic
|
89 |
+
"""
|
90 |
+
# Extract scalar and vector parts of the quaternions
|
91 |
+
w1, x1, y1, z1 = q1.unbind(dim=-1)
|
92 |
+
w2, x2, y2, z2 = q2.unbind(dim=-1)
|
93 |
+
# Calculate the quaternion product
|
94 |
+
result_quaternion = torch.stack(
|
95 |
+
[
|
96 |
+
w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2,
|
97 |
+
w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2,
|
98 |
+
w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2,
|
99 |
+
w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2,
|
100 |
+
],
|
101 |
+
dim=-1,
|
102 |
+
)
|
103 |
+
|
104 |
+
return result_quaternion
|
105 |
+
|
106 |
+
def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
|
107 |
+
"""
|
108 |
+
Returns torch.sqrt(torch.max(0, x))
|
109 |
+
but with a zero subgradient where x is 0.
|
110 |
+
Source: https://pytorch3d.readthedocs.io/en/latest/_modules/pytorch3d/transforms/rotation_conversions.html#matrix_to_quaternion
|
111 |
+
"""
|
112 |
+
ret = torch.zeros_like(x)
|
113 |
+
positive_mask = x > 0
|
114 |
+
ret[positive_mask] = torch.sqrt(x[positive_mask])
|
115 |
+
return ret
|
116 |
+
|
117 |
+
def rotation2quad(matrix: torch.Tensor) -> torch.Tensor:
|
118 |
+
"""
|
119 |
+
Convert rotations given as rotation matrices to quaternions.
|
120 |
+
|
121 |
+
Args:
|
122 |
+
matrix: Rotation matrices as tensor of shape (..., 3, 3).
|
123 |
+
|
124 |
+
Returns:
|
125 |
+
quaternions with real part first, as tensor of shape (..., 4).
|
126 |
+
Source: https://pytorch3d.readthedocs.io/en/latest/_modules/pytorch3d/transforms/rotation_conversions.html#matrix_to_quaternion
|
127 |
+
"""
|
128 |
+
if matrix.size(-1) != 3 or matrix.size(-2) != 3:
|
129 |
+
raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
|
130 |
+
|
131 |
+
if not isinstance(matrix, torch.Tensor):
|
132 |
+
matrix = torch.tensor(matrix).cuda()
|
133 |
+
|
134 |
+
batch_dim = matrix.shape[:-2]
|
135 |
+
m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(
|
136 |
+
matrix.reshape(batch_dim + (9,)), dim=-1
|
137 |
+
)
|
138 |
+
|
139 |
+
q_abs = _sqrt_positive_part(
|
140 |
+
torch.stack(
|
141 |
+
[
|
142 |
+
1.0 + m00 + m11 + m22,
|
143 |
+
1.0 + m00 - m11 - m22,
|
144 |
+
1.0 - m00 + m11 - m22,
|
145 |
+
1.0 - m00 - m11 + m22,
|
146 |
+
],
|
147 |
+
dim=-1,
|
148 |
+
)
|
149 |
+
)
|
150 |
+
|
151 |
+
# we produce the desired quaternion multiplied by each of r, i, j, k
|
152 |
+
quat_by_rijk = torch.stack(
|
153 |
+
[
|
154 |
+
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
|
155 |
+
# `int`.
|
156 |
+
torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),
|
157 |
+
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
|
158 |
+
# `int`.
|
159 |
+
torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),
|
160 |
+
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
|
161 |
+
# `int`.
|
162 |
+
torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1),
|
163 |
+
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
|
164 |
+
# `int`.
|
165 |
+
torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1),
|
166 |
+
],
|
167 |
+
dim=-2,
|
168 |
+
)
|
169 |
+
|
170 |
+
# We floor here at 0.1 but the exact level is not important; if q_abs is small,
|
171 |
+
# the candidate won't be picked.
|
172 |
+
flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
|
173 |
+
quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))
|
174 |
+
|
175 |
+
# if not for numerical problems, quat_candidates[i] should be same (up to a sign),
|
176 |
+
# forall i; we pick the best-conditioned one (with the largest denominator)
|
177 |
+
|
178 |
+
return quat_candidates[
|
179 |
+
F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :
|
180 |
+
].reshape(batch_dim + (4,))
|
181 |
+
|
182 |
+
|
183 |
+
def get_tensor_from_camera(RT, Tquad=False):
|
184 |
+
"""
|
185 |
+
Convert transformation matrix to quaternion and translation.
|
186 |
+
|
187 |
+
"""
|
188 |
+
# gpu_id = -1
|
189 |
+
# if type(RT) == torch.Tensor:
|
190 |
+
# if RT.get_device() != -1:
|
191 |
+
# gpu_id = RT.get_device()
|
192 |
+
# RT = RT.detach().cpu()
|
193 |
+
# RT = RT.numpy()
|
194 |
+
# from mathutils import Matrix
|
195 |
+
#
|
196 |
+
# R, T = RT[:3, :3], RT[:3, 3]
|
197 |
+
# rot = Matrix(R)
|
198 |
+
# quad = rot.to_quaternion()
|
199 |
+
# if Tquad:
|
200 |
+
# tensor = np.concatenate([T, quad], 0)
|
201 |
+
# else:
|
202 |
+
# tensor = np.concatenate([quad, T], 0)
|
203 |
+
# tensor = torch.from_numpy(tensor).float()
|
204 |
+
# if gpu_id != -1:
|
205 |
+
# tensor = tensor.to(gpu_id)
|
206 |
+
# return tensor
|
207 |
+
|
208 |
+
if not isinstance(RT, torch.Tensor):
|
209 |
+
RT = torch.tensor(RT).cuda()
|
210 |
+
|
211 |
+
rot = RT[:3, :3].unsqueeze(0).detach()
|
212 |
+
quat = rotation2quad(rot).squeeze()
|
213 |
+
tran = RT[:3, 3].detach()
|
214 |
+
|
215 |
+
return torch.cat([quat, tran])
|
216 |
+
|
217 |
+
def normalize(x):
|
218 |
+
return x / np.linalg.norm(x)
|
219 |
+
|
220 |
+
|
221 |
+
def viewmatrix(lookdir, up, position, subtract_position=False):
|
222 |
+
"""Construct lookat view matrix."""
|
223 |
+
vec2 = normalize((lookdir - position) if subtract_position else lookdir)
|
224 |
+
vec0 = normalize(np.cross(up, vec2))
|
225 |
+
vec1 = normalize(np.cross(vec2, vec0))
|
226 |
+
m = np.stack([vec0, vec1, vec2, position], axis=1)
|
227 |
+
return m
|
228 |
+
|
229 |
+
|
230 |
+
def poses_avg(poses):
|
231 |
+
"""New pose using average position, z-axis, and up vector of input poses."""
|
232 |
+
position = poses[:, :3, 3].mean(0)
|
233 |
+
z_axis = poses[:, :3, 2].mean(0)
|
234 |
+
up = poses[:, :3, 1].mean(0)
|
235 |
+
cam2world = viewmatrix(z_axis, up, position)
|
236 |
+
return cam2world
|
237 |
+
|
238 |
+
|
239 |
+
def focus_point_fn(poses):
|
240 |
+
"""Calculate nearest point to all focal axes in poses."""
|
241 |
+
directions, origins = poses[:, :3, 2:3], poses[:, :3, 3:4]
|
242 |
+
m = np.eye(3) - directions * np.transpose(directions, [0, 2, 1])
|
243 |
+
mt_m = np.transpose(m, [0, 2, 1]) @ m
|
244 |
+
focus_pt = np.linalg.inv(mt_m.mean(0)) @ (mt_m @ origins).mean(0)[:, 0]
|
245 |
+
return focus_pt
|
246 |
+
|
247 |
+
|
248 |
+
def pad_poses(p):
|
249 |
+
"""Pad [..., 3, 4] pose matrices with a homogeneous bottom row [0,0,0,1]."""
|
250 |
+
bottom = np.broadcast_to([0, 0, 0, 1.], p[..., :1, :4].shape)
|
251 |
+
return np.concatenate([p[..., :3, :4], bottom], axis=-2)
|
252 |
+
|
253 |
+
def unpad_poses(p):
|
254 |
+
"""Remove the homogeneous bottom row from [..., 4, 4] pose matrices."""
|
255 |
+
return p[..., :3, :4]
|
256 |
+
|
257 |
+
def transform_poses_pca(poses):
|
258 |
+
"""Transforms poses so principal components lie on XYZ axes.
|
259 |
+
|
260 |
+
Args:
|
261 |
+
poses: a (N, 3, 4) array containing the cameras' camera to world transforms.
|
262 |
+
|
263 |
+
Returns:
|
264 |
+
A tuple (poses, transform), with the transformed poses and the applied
|
265 |
+
camera_to_world transforms.
|
266 |
+
"""
|
267 |
+
t = poses[:, :3, 3]
|
268 |
+
t_mean = t.mean(axis=0)
|
269 |
+
t = t - t_mean
|
270 |
+
|
271 |
+
eigval, eigvec = np.linalg.eig(t.T @ t)
|
272 |
+
# Sort eigenvectors in order of largest to smallest eigenvalue.
|
273 |
+
inds = np.argsort(eigval)[::-1]
|
274 |
+
eigvec = eigvec[:, inds]
|
275 |
+
rot = eigvec.T
|
276 |
+
if np.linalg.det(rot) < 0:
|
277 |
+
rot = np.diag(np.array([1, 1, -1])) @ rot
|
278 |
+
|
279 |
+
transform = np.concatenate([rot, rot @ -t_mean[:, None]], -1)
|
280 |
+
poses_recentered = unpad_poses(transform @ pad_poses(poses))
|
281 |
+
transform = np.concatenate([transform, np.eye(4)[3:]], axis=0)
|
282 |
+
|
283 |
+
# Flip coordinate system if z component of y-axis is negative
|
284 |
+
if poses_recentered.mean(axis=0)[2, 1] < 0:
|
285 |
+
poses_recentered = np.diag(np.array([1, -1, -1])) @ poses_recentered
|
286 |
+
transform = np.diag(np.array([1, -1, -1, 1])) @ transform
|
287 |
+
|
288 |
+
# Just make sure it's in the [-1, 1]^3 cube
|
289 |
+
scale_factor = 1. / np.max(np.abs(poses_recentered[:, :3, 3]))
|
290 |
+
poses_recentered[:, :3, 3] *= scale_factor
|
291 |
+
transform = np.diag(np.array([scale_factor] * 3 + [1])) @ transform
|
292 |
+
return poses_recentered, transform
|
293 |
+
|
294 |
+
|
295 |
+
def recenter_poses(poses: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
|
296 |
+
"""Recenter poses around the origin."""
|
297 |
+
cam2world = poses_avg(poses)
|
298 |
+
transform = np.linalg.inv(pad_poses(cam2world))
|
299 |
+
poses = transform @ pad_poses(poses)
|
300 |
+
return unpad_poses(poses), transform
|
301 |
+
|
302 |
+
def generate_ellipse_path(views, n_frames=600, const_speed=True, z_variation=0., z_phase=0.):
|
303 |
+
poses = []
|
304 |
+
for view in views:
|
305 |
+
tmp_view = np.eye(4)
|
306 |
+
tmp_view[:3] = np.concatenate([view.R.T, view.T[:, None]], 1)
|
307 |
+
tmp_view = np.linalg.inv(tmp_view)
|
308 |
+
tmp_view[:, 1:3] *= -1
|
309 |
+
poses.append(tmp_view)
|
310 |
+
poses = np.stack(poses, 0)
|
311 |
+
poses, transform = transform_poses_pca(poses)
|
312 |
+
|
313 |
+
|
314 |
+
# Calculate the focal point for the path (cameras point toward this).
|
315 |
+
center = focus_point_fn(poses)
|
316 |
+
# Path height sits at z=0 (in middle of zero-mean capture pattern).
|
317 |
+
offset = np.array([center[0] , center[1], 0 ])
|
318 |
+
# Calculate scaling for ellipse axes based on input camera positions.
|
319 |
+
sc = np.percentile(np.abs(poses[:, :3, 3] - offset), 90, axis=0)
|
320 |
+
|
321 |
+
# Use ellipse that is symmetric about the focal point in xy.
|
322 |
+
low = -sc + offset
|
323 |
+
high = sc + offset
|
324 |
+
# Optional height variation need not be symmetric
|
325 |
+
z_low = np.percentile((poses[:, :3, 3]), 10, axis=0)
|
326 |
+
z_high = np.percentile((poses[:, :3, 3]), 90, axis=0)
|
327 |
+
|
328 |
+
|
329 |
+
def get_positions(theta):
|
330 |
+
# Interpolate between bounds with trig functions to get ellipse in x-y.
|
331 |
+
# Optionally also interpolate in z to change camera height along path.
|
332 |
+
return np.stack([
|
333 |
+
(low[0] + (high - low)[0] * (np.cos(theta) * .5 + .5)),
|
334 |
+
(low[1] + (high - low)[1] * (np.sin(theta) * .5 + .5)),
|
335 |
+
z_variation * (z_low[2] + (z_high - z_low)[2] *
|
336 |
+
(np.cos(theta + 2 * np.pi * z_phase) * .5 + .5)),
|
337 |
+
], -1)
|
338 |
+
|
339 |
+
theta = np.linspace(0, 2. * np.pi, n_frames + 1, endpoint=True)
|
340 |
+
positions = get_positions(theta)
|
341 |
+
|
342 |
+
if const_speed:
|
343 |
+
# Resample theta angles so that the velocity is closer to constant.
|
344 |
+
lengths = np.linalg.norm(positions[1:] - positions[:-1], axis=-1)
|
345 |
+
theta = sample_np(None, theta, np.log(lengths), n_frames + 1)
|
346 |
+
positions = get_positions(theta)
|
347 |
+
|
348 |
+
# Throw away duplicated last position.
|
349 |
+
positions = positions[:-1]
|
350 |
+
|
351 |
+
# Set path's up vector to axis closest to average of input pose up vectors.
|
352 |
+
avg_up = poses[:, :3, 1].mean(0)
|
353 |
+
avg_up = avg_up / np.linalg.norm(avg_up)
|
354 |
+
ind_up = np.argmax(np.abs(avg_up))
|
355 |
+
up = np.eye(3)[ind_up] * np.sign(avg_up[ind_up])
|
356 |
+
# up = normalize(poses[:, :3, 1].sum(0))
|
357 |
+
|
358 |
+
render_poses = []
|
359 |
+
for p in positions:
|
360 |
+
render_pose = np.eye(4)
|
361 |
+
render_pose[:3] = viewmatrix(p - center, up, p)
|
362 |
+
render_pose = np.linalg.inv(transform) @ render_pose
|
363 |
+
render_pose[:3, 1:3] *= -1
|
364 |
+
render_poses.append(np.linalg.inv(render_pose))
|
365 |
+
return render_poses
|
366 |
+
|
367 |
+
|
368 |
+
|
369 |
+
def generate_spiral_path(poses_arr,
|
370 |
+
n_frames: int = 180,
|
371 |
+
n_rots: int = 2,
|
372 |
+
zrate: float = .5) -> np.ndarray:
|
373 |
+
"""Calculates a forward facing spiral path for rendering."""
|
374 |
+
poses = poses_arr[:, :-2].reshape([-1, 3, 5])
|
375 |
+
bounds = poses_arr[:, -2:]
|
376 |
+
fix_rotation = np.array([
|
377 |
+
[0, -1, 0, 0],
|
378 |
+
[1, 0, 0, 0],
|
379 |
+
[0, 0, 1, 0],
|
380 |
+
[0, 0, 0, 1],
|
381 |
+
], dtype=np.float32)
|
382 |
+
poses = poses[:, :3, :4] @ fix_rotation
|
383 |
+
|
384 |
+
scale = 1. / (bounds.min() * .75)
|
385 |
+
poses[:, :3, 3] *= scale
|
386 |
+
bounds *= scale
|
387 |
+
poses, transform = recenter_poses(poses)
|
388 |
+
|
389 |
+
close_depth, inf_depth = bounds.min() * .9, bounds.max() * 5.
|
390 |
+
dt = .75
|
391 |
+
focal = 1 / (((1 - dt) / close_depth + dt / inf_depth))
|
392 |
+
|
393 |
+
# Get radii for spiral path using 90th percentile of camera positions.
|
394 |
+
positions = poses[:, :3, 3]
|
395 |
+
radii = np.percentile(np.abs(positions), 90, 0)
|
396 |
+
radii = np.concatenate([radii, [1.]])
|
397 |
+
|
398 |
+
# Generate poses for spiral path.
|
399 |
+
render_poses = []
|
400 |
+
cam2world = poses_avg(poses)
|
401 |
+
up = poses[:, :3, 1].mean(0)
|
402 |
+
for theta in np.linspace(0., 2. * np.pi * n_rots, n_frames, endpoint=False):
|
403 |
+
t = radii * [np.cos(theta), -np.sin(theta), -np.sin(theta * zrate), 1.]
|
404 |
+
position = cam2world @ t
|
405 |
+
lookat = cam2world @ [0, 0, -focal, 1.]
|
406 |
+
z_axis = position - lookat
|
407 |
+
render_pose = np.eye(4)
|
408 |
+
render_pose[:3] = viewmatrix(z_axis, up, position)
|
409 |
+
render_pose = np.linalg.inv(transform) @ render_pose
|
410 |
+
render_pose[:3, 1:3] *= -1
|
411 |
+
render_pose[:3, 3] /= scale
|
412 |
+
render_poses.append(np.linalg.inv(render_pose))
|
413 |
+
render_poses = np.stack(render_poses, axis=0)
|
414 |
+
return render_poses
|
415 |
+
|
416 |
+
|
417 |
+
|
418 |
+
def generate_interpolated_path(
|
419 |
+
views,
|
420 |
+
n_interp,
|
421 |
+
spline_degree = 5,
|
422 |
+
smoothness = 0.03,
|
423 |
+
rot_weight = 0.1,
|
424 |
+
lock_up = False,
|
425 |
+
fixed_up_vector = None,
|
426 |
+
lookahead_i = None,
|
427 |
+
frames_per_colmap = None,
|
428 |
+
const_speed = False,
|
429 |
+
n_buffer = None,
|
430 |
+
periodic = False,
|
431 |
+
n_interp_as_total = False,
|
432 |
+
):
|
433 |
+
"""Creates a smooth spline path between input keyframe camera poses.
|
434 |
+
|
435 |
+
Spline is calculated with poses in format (position, lookat-point, up-point).
|
436 |
+
Args:
|
437 |
+
poses: (n, 3, 4) array of input pose keyframes.
|
438 |
+
n_interp: returned path will have n_interp * (n - 1) total poses.
|
439 |
+
spline_degree: polynomial degree of B-spline.
|
440 |
+
smoothness: parameter for spline smoothing, 0 forces exact interpolation.
|
441 |
+
rot_weight: relative weighting of rotation/translation in spline solve.
|
442 |
+
lock_up: if True, use the given up vector and let the lookat direction vary.
|
443 |
+
fixed_up_vector: replace the interpolated `up` with a fixed vector.
|
444 |
+
lookahead_i: force the look direction to look at the pose `i` frames ahead.
|
445 |
+
frames_per_colmap: conversion factor for the desired average velocity.
|
446 |
+
const_speed: renormalize spline to have constant delta between each pose.
|
447 |
+
n_buffer: Number of buffer frames to insert at the start and end of the
|
448 |
+
path. Helps keep the ends of a spline path straight.
|
449 |
+
periodic: make the spline path periodic (perfect loop).
|
450 |
+
n_interp_as_total: use n_interp as total number of poses in path rather than
|
451 |
+
the number of poses to interpolate between each input.
|
452 |
+
|
453 |
+
Returns:
|
454 |
+
Array of new camera poses with shape (n_interp * (n - 1), 3, 4), or
|
455 |
+
(n_interp, 3, 4) if n_interp_as_total is set.
|
456 |
+
"""
|
457 |
+
poses = []
|
458 |
+
for view in views:
|
459 |
+
tmp_view = np.eye(4)
|
460 |
+
tmp_view[:3] = np.concatenate([view.R.T, view.T[:, None]], 1)
|
461 |
+
tmp_view = np.linalg.inv(tmp_view)
|
462 |
+
tmp_view[:, 1:3] *= -1
|
463 |
+
poses.append(tmp_view)
|
464 |
+
poses = np.stack(poses, 0)
|
465 |
+
|
466 |
+
def poses_to_points(poses, dist):
|
467 |
+
"""Converts from pose matrices to (position, lookat, up) format."""
|
468 |
+
pos = poses[:, :3, -1]
|
469 |
+
lookat = poses[:, :3, -1] - dist * poses[:, :3, 2]
|
470 |
+
up = poses[:, :3, -1] + dist * poses[:, :3, 1]
|
471 |
+
return np.stack([pos, lookat, up], 1)
|
472 |
+
|
473 |
+
def points_to_poses(points):
|
474 |
+
"""Converts from (position, lookat, up) format to pose matrices."""
|
475 |
+
poses = []
|
476 |
+
for i in range(len(points)):
|
477 |
+
pos, lookat_point, up_point = points[i]
|
478 |
+
if lookahead_i is not None:
|
479 |
+
if i + lookahead_i < len(points):
|
480 |
+
lookat = pos - points[i + lookahead_i][0]
|
481 |
+
else:
|
482 |
+
lookat = pos - lookat_point
|
483 |
+
up = (up_point - pos) if fixed_up_vector is None else fixed_up_vector
|
484 |
+
poses.append(viewmatrix(lookat, up, pos))
|
485 |
+
return np.array(poses)
|
486 |
+
|
487 |
+
def insert_buffer_poses(poses, n_buffer):
|
488 |
+
"""Insert extra poses at the start and end of the path."""
|
489 |
+
|
490 |
+
def average_distance(points):
|
491 |
+
distances = np.linalg.norm(points[1:] - points[0:-1], axis=-1)
|
492 |
+
return np.mean(distances)
|
493 |
+
|
494 |
+
def shift(pose, dz):
|
495 |
+
result = np.copy(pose)
|
496 |
+
z = result[:3, 2]
|
497 |
+
z /= np.linalg.norm(z)
|
498 |
+
# Move along forward-backward axis. -z is forward.
|
499 |
+
result[:3, 3] += z * dz
|
500 |
+
return result
|
501 |
+
|
502 |
+
dz = average_distance(poses[:, :3, 3])
|
503 |
+
prefix = np.stack([shift(poses[0], (i + 1) * dz) for i in range(n_buffer)])
|
504 |
+
prefix = prefix[::-1] # reverse order
|
505 |
+
suffix = np.stack(
|
506 |
+
[shift(poses[-1], -(i + 1) * dz) for i in range(n_buffer)]
|
507 |
+
)
|
508 |
+
result = np.concatenate([prefix, poses, suffix])
|
509 |
+
return result
|
510 |
+
|
511 |
+
def remove_buffer_poses(poses, u, n_frames, u_keyframes, n_buffer):
|
512 |
+
u_keyframes = u_keyframes[n_buffer:-n_buffer]
|
513 |
+
mask = (u >= u_keyframes[0]) & (u <= u_keyframes[-1])
|
514 |
+
poses = poses[mask]
|
515 |
+
u = u[mask]
|
516 |
+
n_frames = len(poses)
|
517 |
+
return poses, u, n_frames, u_keyframes
|
518 |
+
|
519 |
+
def interp(points, u, k, s):
|
520 |
+
"""Runs multidimensional B-spline interpolation on the input points."""
|
521 |
+
sh = points.shape
|
522 |
+
pts = np.reshape(points, (sh[0], -1))
|
523 |
+
k = min(k, sh[0] - 1)
|
524 |
+
tck, u_keyframes = scipy.interpolate.splprep(pts.T, k=k, s=s, per=periodic)
|
525 |
+
new_points = np.array(scipy.interpolate.splev(u, tck))
|
526 |
+
new_points = np.reshape(new_points.T, (len(u), sh[1], sh[2]))
|
527 |
+
return new_points, u_keyframes
|
528 |
+
|
529 |
+
|
530 |
+
if n_buffer is not None:
|
531 |
+
poses = insert_buffer_poses(poses, n_buffer)
|
532 |
+
points = poses_to_points(poses, dist=rot_weight)
|
533 |
+
if n_interp_as_total:
|
534 |
+
n_frames = n_interp + 1 # Add extra since final pose is discarded.
|
535 |
+
else:
|
536 |
+
n_frames = n_interp * (points.shape[0] - 1)
|
537 |
+
u = np.linspace(0, 1, n_frames, endpoint=True)
|
538 |
+
new_points, u_keyframes = interp(points, u=u, k=spline_degree, s=smoothness)
|
539 |
+
poses = points_to_poses(new_points)
|
540 |
+
if n_buffer is not None:
|
541 |
+
poses, u, n_frames, u_keyframes = remove_buffer_poses(
|
542 |
+
poses, u, n_frames, u_keyframes, n_buffer
|
543 |
+
)
|
544 |
+
# poses, transform = transform_poses_pca(poses)
|
545 |
+
if frames_per_colmap is not None:
|
546 |
+
# Recalculate the number of frames to achieve desired average velocity.
|
547 |
+
positions = poses[:, :3, -1]
|
548 |
+
lengths = np.linalg.norm(positions[1:] - positions[:-1], axis=-1)
|
549 |
+
total_length_colmap = lengths.sum()
|
550 |
+
print('old n_frames:', n_frames)
|
551 |
+
print('total_length_colmap:', total_length_colmap)
|
552 |
+
n_frames = int(total_length_colmap * frames_per_colmap)
|
553 |
+
print('new n_frames:', n_frames)
|
554 |
+
u = np.linspace(
|
555 |
+
np.min(u_keyframes), np.max(u_keyframes), n_frames, endpoint=True
|
556 |
+
)
|
557 |
+
new_points, _ = interp(points, u=u, k=spline_degree, s=smoothness)
|
558 |
+
poses = points_to_poses(new_points)
|
559 |
+
|
560 |
+
if const_speed:
|
561 |
+
# Resample timesteps so that the velocity is nearly constant.
|
562 |
+
positions = poses[:, :3, -1]
|
563 |
+
lengths = np.linalg.norm(positions[1:] - positions[:-1], axis=-1)
|
564 |
+
u = sample(None, u, np.log(lengths), n_frames + 1)
|
565 |
+
new_points, _ = interp(points, u=u, k=spline_degree, s=smoothness)
|
566 |
+
poses = points_to_poses(new_points)
|
567 |
+
|
568 |
+
# return poses[:-1], u[:-1], u_keyframes
|
569 |
+
return poses[:-1]
|
570 |
+
|
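A minimal usage sketch for the spline-path helpers above (illustrative, not part of the commit): `generate_interpolated_path` only needs view objects exposing `.R` and `.T`, so the `DummyView` container below is a hypothetical stand-in for the project's camera class, assuming the usual multinerf-style behaviour of the helper.

import numpy as np
from utils.pose_utils import generate_interpolated_path

class DummyView:  # hypothetical stand-in, not the repo's Camera type
    def __init__(self, R, T):
        self.R = R  # (3, 3) rotation
        self.T = T  # (3,) translation

# Four keyframes shifted along x, just to exercise the code path.
views = [DummyView(np.eye(3), np.array([float(i), 0.0, 0.0])) for i in range(4)]
render_poses = generate_interpolated_path(views, n_interp=30)
print(render_poses.shape)  # roughly (n_interp * (len(views) - 1) - 1, 3, 4)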
utils/sh_utils.py
ADDED
@@ -0,0 +1,118 @@
1 |
+
# Copyright 2021 The PlenOctree Authors.
|
2 |
+
# Redistribution and use in source and binary forms, with or without
|
3 |
+
# modification, are permitted provided that the following conditions are met:
|
4 |
+
#
|
5 |
+
# 1. Redistributions of source code must retain the above copyright notice,
|
6 |
+
# this list of conditions and the following disclaimer.
|
7 |
+
#
|
8 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
9 |
+
# this list of conditions and the following disclaimer in the documentation
|
10 |
+
# and/or other materials provided with the distribution.
|
11 |
+
#
|
12 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
13 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
14 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
15 |
+
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
|
16 |
+
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
|
17 |
+
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
|
18 |
+
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
19 |
+
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
|
20 |
+
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
|
21 |
+
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
|
22 |
+
# POSSIBILITY OF SUCH DAMAGE.
|
23 |
+
|
24 |
+
import torch
|
25 |
+
|
26 |
+
C0 = 0.28209479177387814
|
27 |
+
C1 = 0.4886025119029199
|
28 |
+
C2 = [
|
29 |
+
1.0925484305920792,
|
30 |
+
-1.0925484305920792,
|
31 |
+
0.31539156525252005,
|
32 |
+
-1.0925484305920792,
|
33 |
+
0.5462742152960396
|
34 |
+
]
|
35 |
+
C3 = [
|
36 |
+
-0.5900435899266435,
|
37 |
+
2.890611442640554,
|
38 |
+
-0.4570457994644658,
|
39 |
+
0.3731763325901154,
|
40 |
+
-0.4570457994644658,
|
41 |
+
1.445305721320277,
|
42 |
+
-0.5900435899266435
|
43 |
+
]
|
44 |
+
C4 = [
|
45 |
+
2.5033429417967046,
|
46 |
+
-1.7701307697799304,
|
47 |
+
0.9461746957575601,
|
48 |
+
-0.6690465435572892,
|
49 |
+
0.10578554691520431,
|
50 |
+
-0.6690465435572892,
|
51 |
+
0.47308734787878004,
|
52 |
+
-1.7701307697799304,
|
53 |
+
0.6258357354491761,
|
54 |
+
]
|
55 |
+
|
56 |
+
|
57 |
+
def eval_sh(deg, sh, dirs):
|
58 |
+
"""
|
59 |
+
Evaluate spherical harmonics at unit directions
|
60 |
+
using hardcoded SH polynomials.
|
61 |
+
Works with torch/np/jnp.
|
62 |
+
... Can be 0 or more batch dimensions.
|
63 |
+
Args:
|
64 |
+
deg: int SH deg. Currently, 0-4 supported
|
65 |
+
sh: jnp.ndarray SH coeffs [..., C, (deg + 1) ** 2]
|
66 |
+
dirs: jnp.ndarray unit directions [..., 3]
|
67 |
+
Returns:
|
68 |
+
[..., C]
|
69 |
+
"""
|
70 |
+
assert deg <= 4 and deg >= 0
|
71 |
+
coeff = (deg + 1) ** 2
|
72 |
+
assert sh.shape[-1] >= coeff
|
73 |
+
|
74 |
+
result = C0 * sh[..., 0]
|
75 |
+
if deg > 0:
|
76 |
+
x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3]
|
77 |
+
result = (result -
|
78 |
+
C1 * y * sh[..., 1] +
|
79 |
+
C1 * z * sh[..., 2] -
|
80 |
+
C1 * x * sh[..., 3])
|
81 |
+
|
82 |
+
if deg > 1:
|
83 |
+
xx, yy, zz = x * x, y * y, z * z
|
84 |
+
xy, yz, xz = x * y, y * z, x * z
|
85 |
+
result = (result +
|
86 |
+
C2[0] * xy * sh[..., 4] +
|
87 |
+
C2[1] * yz * sh[..., 5] +
|
88 |
+
C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] +
|
89 |
+
C2[3] * xz * sh[..., 7] +
|
90 |
+
C2[4] * (xx - yy) * sh[..., 8])
|
91 |
+
|
92 |
+
if deg > 2:
|
93 |
+
result = (result +
|
94 |
+
C3[0] * y * (3 * xx - yy) * sh[..., 9] +
|
95 |
+
C3[1] * xy * z * sh[..., 10] +
|
96 |
+
C3[2] * y * (4 * zz - xx - yy)* sh[..., 11] +
|
97 |
+
C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] +
|
98 |
+
C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] +
|
99 |
+
C3[5] * z * (xx - yy) * sh[..., 14] +
|
100 |
+
C3[6] * x * (xx - 3 * yy) * sh[..., 15])
|
101 |
+
|
102 |
+
if deg > 3:
|
103 |
+
result = (result + C4[0] * xy * (xx - yy) * sh[..., 16] +
|
104 |
+
C4[1] * yz * (3 * xx - yy) * sh[..., 17] +
|
105 |
+
C4[2] * xy * (7 * zz - 1) * sh[..., 18] +
|
106 |
+
C4[3] * yz * (7 * zz - 3) * sh[..., 19] +
|
107 |
+
C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] +
|
108 |
+
C4[5] * xz * (7 * zz - 3) * sh[..., 21] +
|
109 |
+
C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] +
|
110 |
+
C4[7] * xz * (xx - 3 * yy) * sh[..., 23] +
|
111 |
+
C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24])
|
112 |
+
return result
|
113 |
+
|
114 |
+
def RGB2SH(rgb):
|
115 |
+
return (rgb - 0.5) / C0
|
116 |
+
|
117 |
+
def SH2RGB(sh):
|
118 |
+
return sh * C0 + 0.5
|
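A short, self-contained sketch of how the spherical-harmonics helpers above can be exercised (illustrative only; the tensors here are random placeholders).

import torch
from utils.sh_utils import eval_sh, RGB2SH, SH2RGB

deg = 2
sh = torch.randn(1024, 3, (deg + 1) ** 2)                            # [N, C, (deg+1)^2]
dirs = torch.nn.functional.normalize(torch.randn(1024, 3), dim=-1)   # unit directions
rgb = eval_sh(deg, sh, dirs)                                         # [N, 3]

# RGB2SH and SH2RGB invert each other on the DC coefficient.
x = torch.rand(5, 3)
assert SH2RGB(RGB2SH(x)).allclose(x)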
utils/stepfun.py
ADDED
@@ -0,0 +1,403 @@
1 |
+
# from internal import math
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
|
5 |
+
|
6 |
+
def searchsorted(a, v):
|
7 |
+
"""Find indices where v should be inserted into a to maintain order.
|
8 |
+
|
9 |
+
Args:
|
10 |
+
a: tensor, the sorted reference points that we are scanning to see where v
|
11 |
+
should lie.
|
12 |
+
v: tensor, the query points that we are pretending to insert into a. Does
|
13 |
+
not need to be sorted. All but the last dimensions should match or expand
|
14 |
+
to those of a, the last dimension can differ.
|
15 |
+
|
16 |
+
Returns:
|
17 |
+
(idx_lo, idx_hi), where a[idx_lo] <= v < a[idx_hi], unless v is out of the
|
18 |
+
range [a[0], a[-1]] in which case idx_lo and idx_hi are both the first or
|
19 |
+
last index of a.
|
20 |
+
"""
|
21 |
+
i = torch.arange(a.shape[-1], device=a.device)
|
22 |
+
v_ge_a = v[..., None, :] >= a[..., :, None]
|
23 |
+
idx_lo = torch.max(torch.where(v_ge_a, i[..., :, None], i[..., :1, None]), -2).values
|
24 |
+
idx_hi = torch.min(torch.where(~v_ge_a, i[..., :, None], i[..., -1:, None]), -2).values
|
25 |
+
return idx_lo, idx_hi
|
26 |
+
|
27 |
+
|
28 |
+
def query(tq, t, y, outside_value=0):
|
29 |
+
"""Look up the values of the step function (t, y) at locations tq."""
|
30 |
+
idx_lo, idx_hi = searchsorted(t, tq)
|
31 |
+
yq = torch.where(idx_lo == idx_hi, torch.full_like(idx_hi, outside_value),
|
32 |
+
torch.take_along_dim(y, idx_lo, dim=-1))
|
33 |
+
return yq
|
34 |
+
|
35 |
+
|
36 |
+
def inner_outer(t0, t1, y1):
|
37 |
+
"""Construct inner and outer measures on (t1, y1) for t0."""
|
38 |
+
cy1 = torch.cat([torch.zeros_like(y1[..., :1]),
|
39 |
+
torch.cumsum(y1, dim=-1)],
|
40 |
+
dim=-1)
|
41 |
+
idx_lo, idx_hi = searchsorted(t1, t0)
|
42 |
+
|
43 |
+
cy1_lo = torch.take_along_dim(cy1, idx_lo, dim=-1)
|
44 |
+
cy1_hi = torch.take_along_dim(cy1, idx_hi, dim=-1)
|
45 |
+
|
46 |
+
y0_outer = cy1_hi[..., 1:] - cy1_lo[..., :-1]
|
47 |
+
y0_inner = torch.where(idx_hi[..., :-1] <= idx_lo[..., 1:],
|
48 |
+
cy1_lo[..., 1:] - cy1_hi[..., :-1], torch.zeros_like(idx_lo[..., 1:]))
|
49 |
+
return y0_inner, y0_outer
|
50 |
+
|
51 |
+
|
52 |
+
def lossfun_outer(t, w, t_env, w_env):
|
53 |
+
"""The proposal weight should be an upper envelope on the nerf weight."""
|
54 |
+
eps = torch.finfo(t.dtype).eps
|
55 |
+
# eps = 1e-3
|
56 |
+
|
57 |
+
_, w_outer = inner_outer(t, t_env, w_env)
|
58 |
+
# We assume w_inner <= w <= w_outer. We don't penalize w_inner because it's
|
59 |
+
# more effective to pull w_outer up than it is to push w_inner down.
|
60 |
+
# Scaled half-quadratic loss that gives a constant gradient at w_outer = 0.
|
61 |
+
return (w - w_outer).clamp_min(0) ** 2 / (w + eps)
|
62 |
+
|
63 |
+
|
64 |
+
def weight_to_pdf(t, w):
|
65 |
+
"""Turn a vector of weights that sums to 1 into a PDF that integrates to 1."""
|
66 |
+
eps = torch.finfo(t.dtype).eps
|
67 |
+
return w / (t[..., 1:] - t[..., :-1]).clamp_min(eps)
|
68 |
+
|
69 |
+
|
70 |
+
def pdf_to_weight(t, p):
|
71 |
+
"""Turn a PDF that integrates to 1 into a vector of weights that sums to 1."""
|
72 |
+
return p * (t[..., 1:] - t[..., :-1])
|
73 |
+
|
74 |
+
|
75 |
+
def max_dilate(t, w, dilation, domain=(-torch.inf, torch.inf)):
|
76 |
+
"""Dilate (via max-pooling) a non-negative step function."""
|
77 |
+
t0 = t[..., :-1] - dilation
|
78 |
+
t1 = t[..., 1:] + dilation
|
79 |
+
t_dilate, _ = torch.sort(torch.cat([t, t0, t1], dim=-1), dim=-1)
|
80 |
+
t_dilate = torch.clip(t_dilate, *domain)
|
81 |
+
w_dilate = torch.max(
|
82 |
+
torch.where(
|
83 |
+
(t0[..., None, :] <= t_dilate[..., None])
|
84 |
+
& (t1[..., None, :] > t_dilate[..., None]),
|
85 |
+
w[..., None, :],
|
86 |
+
torch.zeros_like(w[..., None, :]),
|
87 |
+
), dim=-1).values[..., :-1]
|
88 |
+
return t_dilate, w_dilate
|
89 |
+
|
90 |
+
|
91 |
+
def max_dilate_weights(t,
|
92 |
+
w,
|
93 |
+
dilation,
|
94 |
+
domain=(-torch.inf, torch.inf),
|
95 |
+
renormalize=False):
|
96 |
+
"""Dilate (via max-pooling) a set of weights."""
|
97 |
+
eps = torch.finfo(w.dtype).eps
|
98 |
+
# eps = 1e-3
|
99 |
+
|
100 |
+
p = weight_to_pdf(t, w)
|
101 |
+
t_dilate, p_dilate = max_dilate(t, p, dilation, domain=domain)
|
102 |
+
w_dilate = pdf_to_weight(t_dilate, p_dilate)
|
103 |
+
if renormalize:
|
104 |
+
w_dilate /= torch.sum(w_dilate, dim=-1, keepdim=True).clamp_min(eps)
|
105 |
+
return t_dilate, w_dilate
|
106 |
+
|
107 |
+
|
108 |
+
def integrate_weights(w):
|
109 |
+
"""Compute the cumulative sum of w, assuming all weight vectors sum to 1.
|
110 |
+
|
111 |
+
The output's size on the last dimension is one greater than that of the input,
|
112 |
+
because we're computing the integral corresponding to the endpoints of a step
|
113 |
+
function, not the integral of the interior/bin values.
|
114 |
+
|
115 |
+
Args:
|
116 |
+
w: Tensor, which will be integrated along the last axis. This is assumed to
|
117 |
+
sum to 1 along the last axis, and this function will (silently) break if
|
118 |
+
that is not the case.
|
119 |
+
|
120 |
+
Returns:
|
121 |
+
cw0: Tensor, the integral of w, where cw0[..., 0] = 0 and cw0[..., -1] = 1
|
122 |
+
"""
|
123 |
+
cw = torch.cumsum(w[..., :-1], dim=-1).clamp_max(1)
|
124 |
+
shape = cw.shape[:-1] + (1,)
|
125 |
+
# Ensure that the CDF starts with exactly 0 and ends with exactly 1.
|
126 |
+
cw0 = torch.cat([torch.zeros(shape, device=cw.device), cw,
|
127 |
+
torch.ones(shape, device=cw.device)], dim=-1)
|
128 |
+
return cw0
|
129 |
+
|
130 |
+
|
131 |
+
def integrate_weights_np(w):
|
132 |
+
"""Compute the cumulative sum of w, assuming all weight vectors sum to 1.
|
133 |
+
|
134 |
+
The output's size on the last dimension is one greater than that of the input,
|
135 |
+
because we're computing the integral corresponding to the endpoints of a step
|
136 |
+
function, not the integral of the interior/bin values.
|
137 |
+
|
138 |
+
Args:
|
139 |
+
w: Tensor, which will be integrated along the last axis. This is assumed to
|
140 |
+
sum to 1 along the last axis, and this function will (silently) break if
|
141 |
+
that is not the case.
|
142 |
+
|
143 |
+
Returns:
|
144 |
+
cw0: Tensor, the integral of w, where cw0[..., 0] = 0 and cw0[..., -1] = 1
|
145 |
+
"""
|
146 |
+
cw = np.minimum(1, np.cumsum(w[..., :-1], axis=-1))
|
147 |
+
shape = cw.shape[:-1] + (1,)
|
148 |
+
# Ensure that the CDF starts with exactly 0 and ends with exactly 1.
|
149 |
+
cw0 = np.concatenate([np.zeros(shape), cw,
|
150 |
+
np.ones(shape)], axis=-1)
|
151 |
+
return cw0
|
152 |
+
|
153 |
+
|
154 |
+
def invert_cdf(u, t, w_logits):
|
155 |
+
"""Invert the CDF defined by (t, w) at the points specified by u in [0, 1)."""
|
156 |
+
# Compute the PDF and CDF for each weight vector.
|
157 |
+
w = torch.softmax(w_logits, dim=-1)
|
158 |
+
cw = integrate_weights(w)
|
159 |
+
# Interpolate into the inverse CDF.
|
160 |
+
t_new = math.sorted_interp(u, cw, t)
|
161 |
+
return t_new
|
162 |
+
|
163 |
+
|
164 |
+
def invert_cdf_np(u, t, w_logits):
|
165 |
+
"""Invert the CDF defined by (t, w) at the points specified by u in [0, 1)."""
|
166 |
+
# Compute the PDF and CDF for each weight vector.
|
167 |
+
w = np.exp(w_logits) / np.exp(w_logits).sum(axis=-1, keepdims=True)
|
168 |
+
cw = integrate_weights_np(w)
|
169 |
+
# Interpolate into the inverse CDF.
|
170 |
+
interp_fn = np.interp
|
171 |
+
t_new = interp_fn(u, cw, t)
|
172 |
+
return t_new
|
173 |
+
|
174 |
+
|
175 |
+
def sample(rand,
|
176 |
+
t,
|
177 |
+
w_logits,
|
178 |
+
num_samples,
|
179 |
+
single_jitter=False,
|
180 |
+
deterministic_center=False):
|
181 |
+
"""Piecewise-Constant PDF sampling from a step function.
|
182 |
+
|
183 |
+
Args:
|
184 |
+
rand: random number generator (or None for `linspace` sampling).
|
185 |
+
t: [..., num_bins + 1], bin endpoint coordinates (must be sorted)
|
186 |
+
w_logits: [..., num_bins], logits corresponding to bin weights
|
187 |
+
num_samples: int, the number of samples.
|
188 |
+
single_jitter: bool, if True, jitter every sample along each ray by the same
|
189 |
+
amount in the inverse CDF. Otherwise, jitter each sample independently.
|
190 |
+
deterministic_center: bool, if False, when `rand` is None return samples that
|
191 |
+
linspace the entire PDF. If True, skip the front and back of the linspace
|
192 |
+
so that the centers of each PDF interval are returned.
|
193 |
+
|
194 |
+
Returns:
|
195 |
+
t_samples: [batch_size, num_samples].
|
196 |
+
"""
|
197 |
+
eps = torch.finfo(t.dtype).eps
|
198 |
+
# eps = 1e-3
|
199 |
+
|
200 |
+
device = t.device
|
201 |
+
|
202 |
+
# Draw uniform samples.
|
203 |
+
if not rand:
|
204 |
+
if deterministic_center:
|
205 |
+
pad = 1 / (2 * num_samples)
|
206 |
+
u = torch.linspace(pad, 1. - pad - eps, num_samples, device=device)
|
207 |
+
else:
|
208 |
+
u = torch.linspace(0, 1. - eps, num_samples, device=device)
|
209 |
+
u = torch.broadcast_to(u, t.shape[:-1] + (num_samples,))
|
210 |
+
else:
|
211 |
+
# `u` is in [0, 1) --- it can be zero, but it can never be 1.
|
212 |
+
u_max = eps + (1 - eps) / num_samples
|
213 |
+
max_jitter = (1 - u_max) / (num_samples - 1) - eps
|
214 |
+
d = 1 if single_jitter else num_samples
|
215 |
+
u = torch.linspace(0, 1 - u_max, num_samples, device=device) + \
|
216 |
+
torch.rand(t.shape[:-1] + (d,), device=device) * max_jitter
|
217 |
+
|
218 |
+
return invert_cdf(u, t, w_logits)
|
219 |
+
|
220 |
+
|
221 |
+
def sample_np(rand,
|
222 |
+
t,
|
223 |
+
w_logits,
|
224 |
+
num_samples,
|
225 |
+
single_jitter=False,
|
226 |
+
deterministic_center=False):
|
227 |
+
"""
|
228 |
+
numpy version of sample()
|
229 |
+
"""
|
230 |
+
eps = np.finfo(np.float32).eps
|
231 |
+
|
232 |
+
# Draw uniform samples.
|
233 |
+
if not rand:
|
234 |
+
if deterministic_center:
|
235 |
+
pad = 1 / (2 * num_samples)
|
236 |
+
u = np.linspace(pad, 1. - pad - eps, num_samples)
|
237 |
+
else:
|
238 |
+
u = np.linspace(0, 1. - eps, num_samples)
|
239 |
+
u = np.broadcast_to(u, t.shape[:-1] + (num_samples,))
|
240 |
+
else:
|
241 |
+
# `u` is in [0, 1) --- it can be zero, but it can never be 1.
|
242 |
+
u_max = eps + (1 - eps) / num_samples
|
243 |
+
max_jitter = (1 - u_max) / (num_samples - 1) - eps
|
244 |
+
d = 1 if single_jitter else num_samples
|
245 |
+
u = np.linspace(0, 1 - u_max, num_samples) + \
|
246 |
+
np.random.rand(*t.shape[:-1], d) * max_jitter
|
247 |
+
|
248 |
+
return invert_cdf_np(u, t, w_logits)
|
249 |
+
|
250 |
+
|
251 |
+
def sample_intervals(rand,
|
252 |
+
t,
|
253 |
+
w_logits,
|
254 |
+
num_samples,
|
255 |
+
single_jitter=False,
|
256 |
+
domain=(-torch.inf, torch.inf)):
|
257 |
+
"""Sample *intervals* (rather than points) from a step function.
|
258 |
+
|
259 |
+
Args:
|
260 |
+
rand: random number generator (or None for `linspace` sampling).
|
261 |
+
t: [..., num_bins + 1], bin endpoint coordinates (must be sorted)
|
262 |
+
w_logits: [..., num_bins], logits corresponding to bin weights
|
263 |
+
num_samples: int, the number of intervals to sample.
|
264 |
+
single_jitter: bool, if True, jitter every sample along each ray by the same
|
265 |
+
amount in the inverse CDF. Otherwise, jitter each sample independently.
|
266 |
+
domain: (minval, maxval), the range of valid values for `t`.
|
267 |
+
|
268 |
+
Returns:
|
269 |
+
t_samples: [batch_size, num_samples].
|
270 |
+
"""
|
271 |
+
if num_samples <= 1:
|
272 |
+
raise ValueError(f'num_samples must be > 1, is {num_samples}.')
|
273 |
+
|
274 |
+
# Sample a set of points from the step function.
|
275 |
+
centers = sample(
|
276 |
+
rand,
|
277 |
+
t,
|
278 |
+
w_logits,
|
279 |
+
num_samples,
|
280 |
+
single_jitter,
|
281 |
+
deterministic_center=True)
|
282 |
+
|
283 |
+
# The intervals we return will span the midpoints of each adjacent sample.
|
284 |
+
mid = (centers[..., 1:] + centers[..., :-1]) / 2
|
285 |
+
|
286 |
+
# Each first/last fencepost is the reflection of the first/last midpoint
|
287 |
+
# around the first/last sampled center. We clamp to the limits of the input
|
288 |
+
# domain, provided by the caller.
|
289 |
+
minval, maxval = domain
|
290 |
+
first = (2 * centers[..., :1] - mid[..., :1]).clamp_min(minval)
|
291 |
+
last = (2 * centers[..., -1:] - mid[..., -1:]).clamp_max(maxval)
|
292 |
+
|
293 |
+
t_samples = torch.cat([first, mid, last], dim=-1)
|
294 |
+
return t_samples
|
295 |
+
|
296 |
+
|
297 |
+
def lossfun_distortion(t, w):
|
298 |
+
"""Compute iint w[i] w[j] |t[i] - t[j]| di dj."""
|
299 |
+
# The loss incurred between all pairs of intervals.
|
300 |
+
ut = (t[..., 1:] + t[..., :-1]) / 2
|
301 |
+
dut = torch.abs(ut[..., :, None] - ut[..., None, :])
|
302 |
+
loss_inter = torch.sum(w * torch.sum(w[..., None, :] * dut, dim=-1), dim=-1)
|
303 |
+
|
304 |
+
# The loss incurred within each individual interval with itself.
|
305 |
+
loss_intra = torch.sum(w ** 2 * (t[..., 1:] - t[..., :-1]), dim=-1) / 3
|
306 |
+
|
307 |
+
return loss_inter + loss_intra
|
308 |
+
|
309 |
+
|
310 |
+
def interval_distortion(t0_lo, t0_hi, t1_lo, t1_hi):
|
311 |
+
"""Compute mean(abs(x-y); x in [t0_lo, t0_hi], y in [t1_lo, t1_hi])."""
|
312 |
+
# Distortion when the intervals do not overlap.
|
313 |
+
d_disjoint = torch.abs((t1_lo + t1_hi) / 2 - (t0_lo + t0_hi) / 2)
|
314 |
+
|
315 |
+
# Distortion when the intervals overlap.
|
316 |
+
d_overlap = (2 *
|
317 |
+
(torch.minimum(t0_hi, t1_hi) ** 3 - torch.maximum(t0_lo, t1_lo) ** 3) +
|
318 |
+
3 * (t1_hi * t0_hi * torch.abs(t1_hi - t0_hi) +
|
319 |
+
t1_lo * t0_lo * torch.abs(t1_lo - t0_lo) + t1_hi * t0_lo *
|
320 |
+
(t0_lo - t1_hi) + t1_lo * t0_hi *
|
321 |
+
(t1_lo - t0_hi))) / (6 * (t0_hi - t0_lo) * (t1_hi - t1_lo))
|
322 |
+
|
323 |
+
# Are the two intervals not overlapping?
|
324 |
+
are_disjoint = (t0_lo > t1_hi) | (t1_lo > t0_hi)
|
325 |
+
|
326 |
+
return torch.where(are_disjoint, d_disjoint, d_overlap)
|
327 |
+
|
328 |
+
|
329 |
+
def weighted_percentile(t, w, ps):
|
330 |
+
"""Compute the weighted percentiles of a step function. w's must sum to 1."""
|
331 |
+
cw = integrate_weights(w)
|
332 |
+
# We want to interpolate into the integrated weights according to `ps`.
|
333 |
+
fn = lambda cw_i, t_i: math.sorted_interp(torch.tensor(ps, device=t.device) / 100, cw_i, t_i)
|
334 |
+
# Vmap fn to an arbitrary number of leading dimensions.
|
335 |
+
cw_mat = cw.reshape([-1, cw.shape[-1]])
|
336 |
+
t_mat = t.reshape([-1, t.shape[-1]])
|
337 |
+
wprctile_mat = fn(cw_mat, t_mat) # TODO
|
338 |
+
wprctile = wprctile_mat.reshape(cw.shape[:-1] + (len(ps),))
|
339 |
+
return wprctile
|
340 |
+
|
341 |
+
|
342 |
+
def resample(t, tp, vp, use_avg=False):
|
343 |
+
"""Resample a step function defined by (tp, vp) into intervals t.
|
344 |
+
|
345 |
+
Args:
|
346 |
+
t: tensor with shape (..., n+1), the endpoints to resample into.
|
347 |
+
tp: tensor with shape (..., m+1), the endpoints of the step function being
|
348 |
+
resampled.
|
349 |
+
vp: tensor with shape (..., m), the values of the step function being
|
350 |
+
resampled.
|
351 |
+
use_avg: bool, if False, return the sum of the step function for each
|
352 |
+
interval in `t`. If True, return the average, weighted by the width of
|
353 |
+
each interval in `t`.
|
354 |
+
eps: float, a small value to prevent division by zero when use_avg=True.
|
355 |
+
|
356 |
+
Returns:
|
357 |
+
v: tensor with shape (..., n), the values of the resampled step function.
|
358 |
+
"""
|
359 |
+
eps = torch.finfo(t.dtype).eps
|
360 |
+
# eps = 1e-3
|
361 |
+
|
362 |
+
if use_avg:
|
363 |
+
wp = torch.diff(tp, dim=-1)
|
364 |
+
v_numer = resample(t, tp, vp * wp, use_avg=False)
|
365 |
+
v_denom = resample(t, tp, wp, use_avg=False)
|
366 |
+
v = v_numer / v_denom.clamp_min(eps)
|
367 |
+
return v
|
368 |
+
|
369 |
+
acc = torch.cumsum(vp, dim=-1)
|
370 |
+
acc0 = torch.cat([torch.zeros(acc.shape[:-1] + (1,), device=acc.device), acc], dim=-1)
|
371 |
+
acc0_resampled = math.sorted_interp(t, tp, acc0) # TODO
|
372 |
+
v = torch.diff(acc0_resampled, dim=-1)
|
373 |
+
return v
|
374 |
+
|
375 |
+
|
376 |
+
def resample_np(t, tp, vp, use_avg=False):
|
377 |
+
"""
|
378 |
+
numpy version of resample
|
379 |
+
"""
|
380 |
+
eps = np.finfo(t.dtype).eps
|
381 |
+
if use_avg:
|
382 |
+
wp = np.diff(tp, axis=-1)
|
383 |
+
v_numer = resample_np(t, tp, vp * wp, use_avg=False)
|
384 |
+
v_denom = resample_np(t, tp, wp, use_avg=False)
|
385 |
+
v = v_numer / np.maximum(eps, v_denom)
|
386 |
+
return v
|
387 |
+
|
388 |
+
acc = np.cumsum(vp, axis=-1)
|
389 |
+
acc0 = np.concatenate([np.zeros(acc.shape[:-1] + (1,)), acc], axis=-1)
|
390 |
+
acc0_resampled = np.vectorize(np.interp, signature='(n),(m),(m)->(n)')(t, tp, acc0)
|
391 |
+
v = np.diff(acc0_resampled, axis=-1)
|
392 |
+
return v
|
393 |
+
|
394 |
+
|
395 |
+
def blur_stepfun(x, y, r):
|
396 |
+
xr, xr_idx = torch.sort(torch.cat([x - r, x + r], dim=-1))
|
397 |
+
y1 = (torch.cat([y, torch.zeros_like(y[..., :1])], dim=-1) -
|
398 |
+
torch.cat([torch.zeros_like(y[..., :1]), y], dim=-1)) / (2 * r)
|
399 |
+
y2 = torch.cat([y1, -y1], dim=-1).take_along_dim(xr_idx[..., :-1], dim=-1)
|
400 |
+
yr = torch.cumsum((xr[..., 1:] - xr[..., :-1]) *
|
401 |
+
torch.cumsum(y2, dim=-1), dim=-1).clamp_min(0)
|
402 |
+
yr = torch.cat([torch.zeros_like(yr[..., :1]), yr], dim=-1)
|
403 |
+
return xr, yr
|
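A small sketch of piecewise-constant PDF sampling with the helpers above. The numpy variant is used because the torch `sample`/`invert_cdf` path relies on a `math.sorted_interp` helper whose import is commented out at the top of this file; the inputs are toy values.

import numpy as np
from utils.stepfun import sample_np

t = np.linspace(0.0, 1.0, 9)          # 9 bin endpoints -> 8 bins
w_logits = np.zeros(8)                # uniform weights after softmax
samples = sample_np(None, t, w_logits, num_samples=16, deterministic_center=True)
print(samples.shape)                  # (16,), evenly spread under a uniform PDF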
utils/system_utils.py
ADDED
@@ -0,0 +1,28 @@
1 |
+
#
|
2 |
+
# Copyright (C) 2023, Inria
|
3 |
+
# GRAPHDECO research group, https://team.inria.fr/graphdeco
|
4 |
+
# All rights reserved.
|
5 |
+
#
|
6 |
+
# This software is free for non-commercial, research and evaluation use
|
7 |
+
# under the terms of the LICENSE.md file.
|
8 |
+
#
|
9 |
+
# For inquiries contact [email protected]
|
10 |
+
#
|
11 |
+
|
12 |
+
from errno import EEXIST
|
13 |
+
from os import makedirs, path
|
14 |
+
import os
|
15 |
+
|
16 |
+
def mkdir_p(folder_path):
|
17 |
+
# Creates a directory. equivalent to using mkdir -p on the command line
|
18 |
+
try:
|
19 |
+
makedirs(folder_path)
|
20 |
+
except OSError as exc: # Python >2.5
|
21 |
+
if exc.errno == EEXIST and path.isdir(folder_path):
|
22 |
+
pass
|
23 |
+
else:
|
24 |
+
raise
|
25 |
+
|
26 |
+
def searchForMaxIteration(folder):
|
27 |
+
saved_iters = [int(fname.split("_")[-1]) for fname in os.listdir(folder)]
|
28 |
+
return max(saved_iters)
|
utils/trajectories.py
ADDED
@@ -0,0 +1,243 @@
1 |
+
import numpy as np
|
2 |
+
import roma
|
3 |
+
import torch
|
4 |
+
import torch.nn.functional as F
|
5 |
+
|
6 |
+
|
7 |
+
def rt_to_mat4(
|
8 |
+
R: torch.Tensor, t: torch.Tensor, s: torch.Tensor | None = None
|
9 |
+
) -> torch.Tensor:
|
10 |
+
"""
|
11 |
+
Args:
|
12 |
+
R (torch.Tensor): (..., 3, 3).
|
13 |
+
t (torch.Tensor): (..., 3).
|
14 |
+
s (torch.Tensor): (...,).
|
15 |
+
|
16 |
+
Returns:
|
17 |
+
torch.Tensor: (..., 4, 4)
|
18 |
+
"""
|
19 |
+
mat34 = torch.cat([R, t[..., None]], dim=-1)
|
20 |
+
if s is None:
|
21 |
+
bottom = (
|
22 |
+
mat34.new_tensor([[0.0, 0.0, 0.0, 1.0]])
|
23 |
+
.reshape((1,) * (mat34.dim() - 2) + (1, 4))
|
24 |
+
.expand(mat34.shape[:-2] + (1, 4))
|
25 |
+
)
|
26 |
+
else:
|
27 |
+
bottom = F.pad(1.0 / s[..., None, None], (3, 0), value=0.0)
|
28 |
+
mat4 = torch.cat([mat34, bottom], dim=-2)
|
29 |
+
return mat4
|
30 |
+
|
31 |
+
def get_avg_w2c(w2cs: torch.Tensor):
|
32 |
+
c2ws = torch.linalg.inv(w2cs)
|
33 |
+
# 1. Compute the center
|
34 |
+
center = c2ws[:, :3, -1].mean(0)
|
35 |
+
# 2. Compute the z axis
|
36 |
+
z = F.normalize(c2ws[:, :3, 2].mean(0), dim=-1)
|
37 |
+
# 3. Compute axis y' (no need to normalize as it's not the final output)
|
38 |
+
y_ = c2ws[:, :3, 1].mean(0) # (3)
|
39 |
+
# 4. Compute the x axis
|
40 |
+
x = F.normalize(torch.cross(y_, z, dim=-1), dim=-1) # (3)
|
41 |
+
# 5. Compute the y axis (as z and x are normalized, y is already of norm 1)
|
42 |
+
y = torch.cross(z, x, dim=-1) # (3)
|
43 |
+
avg_c2w = rt_to_mat4(torch.stack([x, y, z], 1), center)
|
44 |
+
avg_w2c = torch.linalg.inv(avg_c2w)
|
45 |
+
return avg_w2c
|
46 |
+
|
47 |
+
|
48 |
+
# def get_lookat(origins: torch.Tensor, viewdirs: torch.Tensor) -> torch.Tensor:
|
49 |
+
# """Calculate the intersection point of multiple camera rays as the lookat point.
|
50 |
+
|
51 |
+
# Use the center of camera positions as a reference point for the lookat,
|
52 |
+
# then move forward along the average view direction by a certain distance.
|
53 |
+
# """
|
54 |
+
# # Calculate the center of camera positions
|
55 |
+
# center = origins.mean(dim=0)
|
56 |
+
|
57 |
+
# # Calculate average view direction
|
58 |
+
# mean_dir = F.normalize(viewdirs.mean(dim=0), dim=-1)
|
59 |
+
|
60 |
+
# # Calculate average distance to the center point
|
61 |
+
# avg_dist = torch.norm(origins - center, dim=-1).mean()
|
62 |
+
|
63 |
+
# # Move forward along the average view direction
|
64 |
+
# lookat = center + mean_dir * avg_dist
|
65 |
+
|
66 |
+
# return lookat
|
67 |
+
|
68 |
+
def get_lookat(origins: torch.Tensor, viewdirs: torch.Tensor) -> torch.Tensor:
|
69 |
+
"""Triangulate a set of rays to find a single lookat point.
|
70 |
+
|
71 |
+
Args:
|
72 |
+
origins (torch.Tensor): A (N, 3) array of ray origins.
|
73 |
+
viewdirs (torch.Tensor): A (N, 3) array of ray view directions.
|
74 |
+
|
75 |
+
Returns:
|
76 |
+
torch.Tensor: A (3,) lookat point.
|
77 |
+
"""
|
78 |
+
|
79 |
+
viewdirs = torch.nn.functional.normalize(viewdirs, dim=-1)
|
80 |
+
eye = torch.eye(3, device=origins.device, dtype=origins.dtype)[None]
|
81 |
+
# Calculate projection matrix I - rr^T
|
82 |
+
I_min_cov = eye - (viewdirs[..., None] * viewdirs[..., None, :])
|
83 |
+
# Compute sum of projections
|
84 |
+
sum_proj = I_min_cov.matmul(origins[..., None]).sum(dim=-3)
|
85 |
+
# Solve for the intersection point using least squares
|
86 |
+
lookat = torch.linalg.lstsq(I_min_cov.sum(dim=-3), sum_proj).solution[..., 0]
|
87 |
+
# Check NaNs.
|
88 |
+
assert not torch.any(torch.isnan(lookat))
|
89 |
+
return lookat
|
90 |
+
|
91 |
+
|
92 |
+
def get_lookat_w2cs(positions: torch.Tensor, lookat: torch.Tensor, up: torch.Tensor):
|
93 |
+
"""
|
94 |
+
Args:
|
95 |
+
positions: (N, 3) tensor of camera positions
|
96 |
+
lookat: (3,) tensor of lookat point
|
97 |
+
up: (3,) tensor of up vector
|
98 |
+
|
99 |
+
Returns:
|
100 |
+
w2cs: (N, 4, 4) tensor of world-to-camera transforms
|
101 |
+
"""
|
102 |
+
forward_vectors = F.normalize(lookat - positions, dim=-1)
|
103 |
+
right_vectors = F.normalize(torch.cross(forward_vectors, up[None], dim=-1), dim=-1)
|
104 |
+
down_vectors = F.normalize(
|
105 |
+
torch.cross(forward_vectors, right_vectors, dim=-1), dim=-1
|
106 |
+
)
|
107 |
+
Rs = torch.stack([right_vectors, down_vectors, forward_vectors], dim=-1)
|
108 |
+
w2cs = torch.linalg.inv(rt_to_mat4(Rs, positions))
|
109 |
+
return w2cs
|
110 |
+
|
111 |
+
|
112 |
+
def get_arc_w2cs(
|
113 |
+
ref_w2c: torch.Tensor,
|
114 |
+
lookat: torch.Tensor,
|
115 |
+
up: torch.Tensor,
|
116 |
+
num_frames: int,
|
117 |
+
degree: float,
|
118 |
+
**_,
|
119 |
+
) -> torch.Tensor:
|
120 |
+
ref_position = torch.linalg.inv(ref_w2c)[:3, 3]
|
121 |
+
thetas = (
|
122 |
+
torch.sin(
|
123 |
+
torch.linspace(0.0, torch.pi * 2.0, num_frames + 1, device=ref_w2c.device)[
|
124 |
+
:-1
|
125 |
+
]
|
126 |
+
)
|
127 |
+
* (degree / 2.0)
|
128 |
+
/ 180.0
|
129 |
+
* torch.pi
|
130 |
+
)
|
131 |
+
positions = torch.einsum(
|
132 |
+
"nij,j->ni",
|
133 |
+
roma.rotvec_to_rotmat(thetas[:, None] * up[None]),
|
134 |
+
ref_position - lookat,
|
135 |
+
)
|
136 |
+
return get_lookat_w2cs(positions, lookat, up)
|
137 |
+
|
138 |
+
|
139 |
+
def get_lemniscate_w2cs(
|
140 |
+
ref_w2c: torch.Tensor,
|
141 |
+
lookat: torch.Tensor,
|
142 |
+
up: torch.Tensor,
|
143 |
+
num_frames: int,
|
144 |
+
degree: float,
|
145 |
+
**_,
|
146 |
+
) -> torch.Tensor:
|
147 |
+
ref_c2w = torch.linalg.inv(ref_w2c)
|
148 |
+
a = torch.linalg.norm(ref_c2w[:3, 3] - lookat) * np.tan(degree / 360 * np.pi)
|
149 |
+
# Lemniscate curve in camera space. Starting at the origin.
|
150 |
+
thetas = (
|
151 |
+
torch.linspace(0, 2 * torch.pi, num_frames + 1, device=ref_w2c.device)[:-1]
|
152 |
+
+ torch.pi / 2
|
153 |
+
)
|
154 |
+
|
155 |
+
positions = torch.stack(
|
156 |
+
[
|
157 |
+
a * torch.cos(thetas) / (1 + torch.sin(thetas) ** 2),
|
158 |
+
a * torch.cos(thetas) * torch.sin(thetas) / (1 + torch.sin(thetas) ** 2),
|
159 |
+
torch.zeros(num_frames, device=ref_w2c.device),
|
160 |
+
],
|
161 |
+
dim=-1,
|
162 |
+
)
|
163 |
+
# Transform to world space.
|
164 |
+
positions = torch.einsum(
|
165 |
+
"ij,nj->ni", ref_c2w[:3], F.pad(positions, (0, 1), value=1.0)
|
166 |
+
)
|
167 |
+
return get_lookat_w2cs(positions, lookat, up)
|
168 |
+
|
169 |
+
|
170 |
+
def get_spiral_w2cs(
|
171 |
+
ref_w2c: torch.Tensor,
|
172 |
+
lookat: torch.Tensor,
|
173 |
+
up: torch.Tensor,
|
174 |
+
num_frames: int,
|
175 |
+
rads: float | torch.Tensor,
|
176 |
+
zrate: float,
|
177 |
+
rots: int,
|
178 |
+
**_,
|
179 |
+
) -> torch.Tensor:
|
180 |
+
ref_c2w = torch.linalg.inv(ref_w2c)
|
181 |
+
thetas = torch.linspace(
|
182 |
+
0, 2 * torch.pi * rots, num_frames + 1, device=ref_w2c.device
|
183 |
+
)[:-1]
|
184 |
+
# Spiral curve in camera space. Starting at the origin.
|
185 |
+
if isinstance(rads, torch.Tensor):
|
186 |
+
rads = rads.reshape(-1, 3).to(ref_w2c.device)
|
187 |
+
positions = (
|
188 |
+
torch.stack(
|
189 |
+
[
|
190 |
+
torch.cos(thetas),
|
191 |
+
-torch.sin(thetas),
|
192 |
+
-torch.sin(thetas * zrate),
|
193 |
+
],
|
194 |
+
dim=-1,
|
195 |
+
)
|
196 |
+
* rads
|
197 |
+
)
|
198 |
+
# Transform to world space.
|
199 |
+
positions = torch.einsum(
|
200 |
+
"ij,nj->ni", ref_c2w[:3], F.pad(positions, (0, 1), value=1.0)
|
201 |
+
)
|
202 |
+
return get_lookat_w2cs(positions, lookat, up)
|
203 |
+
|
204 |
+
|
205 |
+
def get_wander_w2cs(ref_w2c, focal_length, num_frames, max_disp, **_):
|
206 |
+
device = ref_w2c.device
|
207 |
+
c2w = np.linalg.inv(ref_w2c.detach().cpu().numpy())
|
208 |
+
max_disp = max_disp
|
209 |
+
|
210 |
+
max_trans = max_disp / focal_length
|
211 |
+
output_poses = []
|
212 |
+
|
213 |
+
for i in range(num_frames):
|
214 |
+
x_trans = max_trans * np.sin(2.0 * np.pi * float(i) / float(num_frames))
|
215 |
+
y_trans = 0.0
|
216 |
+
z_trans = max_trans * np.cos(2.0 * np.pi * float(i) / float(num_frames)) / 2.0
|
217 |
+
|
218 |
+
i_pose = np.concatenate(
|
219 |
+
[
|
220 |
+
np.concatenate(
|
221 |
+
[
|
222 |
+
np.eye(3),
|
223 |
+
np.array([x_trans, y_trans, z_trans])[:, np.newaxis],
|
224 |
+
],
|
225 |
+
axis=1,
|
226 |
+
),
|
227 |
+
np.array([0.0, 0.0, 0.0, 1.0])[np.newaxis, :],
|
228 |
+
],
|
229 |
+
axis=0,
|
230 |
+
)
|
231 |
+
|
232 |
+
i_pose = np.linalg.inv(i_pose)
|
233 |
+
|
234 |
+
ref_pose = np.concatenate(
|
235 |
+
[c2w[:3, :4], np.array([0.0, 0.0, 0.0, 1.0])[np.newaxis, :]], axis=0
|
236 |
+
)
|
237 |
+
|
238 |
+
render_pose = np.dot(ref_pose, i_pose)
|
239 |
+
output_poses.append(render_pose)
|
240 |
+
output_poses = torch.from_numpy(np.array(output_poses, dtype=np.float32)).to(device)
|
241 |
+
w2cs = torch.linalg.inv(output_poses)
|
242 |
+
|
243 |
+
return w2cs
|
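An illustrative orbit built with the helpers above (a toy camera placed on the +z axis; the numbers are assumptions for the example, not project data).

import torch
from utils.trajectories import get_lookat, get_arc_w2cs

origins = torch.tensor([[2.0, 0.0, 0.0], [0.0, 2.0, 0.0], [-2.0, 0.0, 0.0]])
viewdirs = -origins                       # all rays pass through the origin
lookat = get_lookat(origins, viewdirs)    # ~ (0, 0, 0)

ref_w2c = torch.eye(4)
ref_w2c[:3, 3] = torch.tensor([0.0, 0.0, -3.0])   # camera at (0, 0, 3) in world
up = torch.tensor([0.0, 1.0, 0.0])
w2cs = get_arc_w2cs(ref_w2c, lookat=lookat, up=up, num_frames=60, degree=30.0)
print(w2cs.shape)                         # (60, 4, 4) world-to-camera matrices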
utils/utils_poses/ATE/align_trajectory.py
ADDED
@@ -0,0 +1,80 @@
1 |
+
#!/usr/bin/env python2
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
import utils.utils_poses.ATE.transformations as tfs
|
6 |
+
|
7 |
+
|
8 |
+
def get_best_yaw(C):
|
9 |
+
'''
|
10 |
+
maximize trace(Rz(theta) * C)
|
11 |
+
'''
|
12 |
+
assert C.shape == (3, 3)
|
13 |
+
|
14 |
+
A = C[0, 1] - C[1, 0]
|
15 |
+
B = C[0, 0] + C[1, 1]
|
16 |
+
theta = np.pi / 2 - np.arctan2(B, A)
|
17 |
+
|
18 |
+
return theta
|
19 |
+
|
20 |
+
|
21 |
+
def rot_z(theta):
|
22 |
+
R = tfs.rotation_matrix(theta, [0, 0, 1])
|
23 |
+
R = R[0:3, 0:3]
|
24 |
+
|
25 |
+
return R
|
26 |
+
|
27 |
+
|
28 |
+
def align_umeyama(model, data, known_scale=False, yaw_only=False):
|
29 |
+
"""Implementation of the paper: S. Umeyama, Least-Squares Estimation
|
30 |
+
of Transformation Parameters Between Two Point Patterns,
|
31 |
+
IEEE Trans. Pattern Anal. Mach. Intell., vol. 13, no. 4, 1991.
|
32 |
+
|
33 |
+
model = s * R * data + t
|
34 |
+
|
35 |
+
Input:
|
36 |
+
model -- first trajectory (nx3), numpy array type
|
37 |
+
data -- second trajectory (nx3), numpy array type
|
38 |
+
|
39 |
+
Output:
|
40 |
+
s -- scale factor (scalar)
|
41 |
+
R -- rotation matrix (3x3)
|
42 |
+
t -- translation vector (3x1)
|
43 |
+
t_error -- translational error per point (1xn)
|
44 |
+
|
45 |
+
"""
|
46 |
+
|
47 |
+
# subtract mean
|
48 |
+
mu_M = model.mean(0)
|
49 |
+
mu_D = data.mean(0)
|
50 |
+
model_zerocentered = model - mu_M
|
51 |
+
data_zerocentered = data - mu_D
|
52 |
+
n = np.shape(model)[0]
|
53 |
+
|
54 |
+
# correlation
|
55 |
+
C = 1.0/n*np.dot(model_zerocentered.transpose(), data_zerocentered)
|
56 |
+
sigma2 = 1.0/n*np.multiply(data_zerocentered, data_zerocentered).sum()
|
57 |
+
U_svd, D_svd, V_svd = np.linalg.linalg.svd(C)
|
58 |
+
|
59 |
+
D_svd = np.diag(D_svd)
|
60 |
+
V_svd = np.transpose(V_svd)
|
61 |
+
|
62 |
+
S = np.eye(3)
|
63 |
+
if(np.linalg.det(U_svd)*np.linalg.det(V_svd) < 0):
|
64 |
+
S[2, 2] = -1
|
65 |
+
|
66 |
+
if yaw_only:
|
67 |
+
rot_C = np.dot(data_zerocentered.transpose(), model_zerocentered)
|
68 |
+
theta = get_best_yaw(rot_C)
|
69 |
+
R = rot_z(theta)
|
70 |
+
else:
|
71 |
+
R = np.dot(U_svd, np.dot(S, np.transpose(V_svd)))
|
72 |
+
|
73 |
+
if known_scale:
|
74 |
+
s = 1
|
75 |
+
else:
|
76 |
+
s = 1.0/sigma2*np.trace(np.dot(D_svd, S))
|
77 |
+
|
78 |
+
t = mu_M-s*np.dot(R, mu_D)
|
79 |
+
|
80 |
+
return s, R, t
|
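A quick sanity-check sketch for the Umeyama alignment above: with noise-free synthetic points, the known similarity transform is recovered exactly (illustrative, not repo code).

import numpy as np
from utils.utils_poses.ATE.align_trajectory import align_umeyama

rng = np.random.default_rng(0)
data = rng.normal(size=(100, 3))
theta = np.deg2rad(30.0)
R_true = np.array([[np.cos(theta), -np.sin(theta), 0.0],
                   [np.sin(theta),  np.cos(theta), 0.0],
                   [0.0,            0.0,           1.0]])
s_true, t_true = 2.0, np.array([1.0, -2.0, 0.5])
model = s_true * data @ R_true.T + t_true     # model = s * R * data + t
s, R, t = align_umeyama(model, data)
assert np.isclose(s, s_true) and np.allclose(R, R_true) and np.allclose(t, t_true)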
utils/utils_poses/ATE/align_utils.py
ADDED
@@ -0,0 +1,144 @@
1 |
+
#!/usr/bin/env python2
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
|
6 |
+
import utils.utils_poses.ATE.transformations as tfs
|
7 |
+
import utils.utils_poses.ATE.align_trajectory as align
|
8 |
+
|
9 |
+
|
10 |
+
def _getIndices(n_aligned, total_n):
|
11 |
+
if n_aligned == -1:
|
12 |
+
idxs = np.arange(0, total_n)
|
13 |
+
else:
|
14 |
+
assert n_aligned <= total_n and n_aligned >= 1
|
15 |
+
idxs = np.arange(0, n_aligned)
|
16 |
+
return idxs
|
17 |
+
|
18 |
+
|
19 |
+
def alignPositionYawSingle(p_es, p_gt, q_es, q_gt):
|
20 |
+
'''
|
21 |
+
calculate the 4DOF transformation: yaw R and translation t so that:
|
22 |
+
gt = R * est + t
|
23 |
+
'''
|
24 |
+
|
25 |
+
p_es_0, q_es_0 = p_es[0, :], q_es[0, :]
|
26 |
+
p_gt_0, q_gt_0 = p_gt[0, :], q_gt[0, :]
|
27 |
+
g_rot = tfs.quaternion_matrix(q_gt_0)
|
28 |
+
g_rot = g_rot[0:3, 0:3]
|
29 |
+
est_rot = tfs.quaternion_matrix(q_es_0)
|
30 |
+
est_rot = est_rot[0:3, 0:3]
|
31 |
+
|
32 |
+
C_R = np.dot(est_rot, g_rot.transpose())
|
33 |
+
theta = align.get_best_yaw(C_R)
|
34 |
+
R = align.rot_z(theta)
|
35 |
+
t = p_gt_0 - np.dot(R, p_es_0)
|
36 |
+
|
37 |
+
return R, t
|
38 |
+
|
39 |
+
|
40 |
+
def alignPositionYaw(p_es, p_gt, q_es, q_gt, n_aligned=1):
|
41 |
+
if n_aligned == 1:
|
42 |
+
R, t = alignPositionYawSingle(p_es, p_gt, q_es, q_gt)
|
43 |
+
return R, t
|
44 |
+
else:
|
45 |
+
idxs = _getIndices(n_aligned, p_es.shape[0])
|
46 |
+
est_pos = p_es[idxs, 0:3]
|
47 |
+
gt_pos = p_gt[idxs, 0:3]
|
48 |
+
_, R, t = align.align_umeyama(gt_pos, est_pos, known_scale=True,
|
49 |
+
yaw_only=True) # note the order
|
50 |
+
t = np.array(t)
|
51 |
+
t = t.reshape((3, ))
|
52 |
+
R = np.array(R)
|
53 |
+
return R, t
|
54 |
+
|
55 |
+
|
56 |
+
# align by a SE3 transformation
|
57 |
+
def alignSE3Single(p_es, p_gt, q_es, q_gt):
|
58 |
+
'''
|
59 |
+
Calculate SE3 transformation R and t so that:
|
60 |
+
gt = R * est + t
|
61 |
+
Using only the first poses of est and gt
|
62 |
+
'''
|
63 |
+
|
64 |
+
p_es_0, q_es_0 = p_es[0, :], q_es[0, :]
|
65 |
+
p_gt_0, q_gt_0 = p_gt[0, :], q_gt[0, :]
|
66 |
+
|
67 |
+
g_rot = tfs.quaternion_matrix(q_gt_0)
|
68 |
+
g_rot = g_rot[0:3, 0:3]
|
69 |
+
est_rot = tfs.quaternion_matrix(q_es_0)
|
70 |
+
est_rot = est_rot[0:3, 0:3]
|
71 |
+
|
72 |
+
R = np.dot(g_rot, np.transpose(est_rot))
|
73 |
+
t = p_gt_0 - np.dot(R, p_es_0)
|
74 |
+
|
75 |
+
return R, t
|
76 |
+
|
77 |
+
|
78 |
+
def alignSE3(p_es, p_gt, q_es, q_gt, n_aligned=-1):
|
79 |
+
'''
|
80 |
+
Calculate SE3 transformation R and t so that:
|
81 |
+
gt = R * est + t
|
82 |
+
'''
|
83 |
+
if n_aligned == 1:
|
84 |
+
R, t = alignSE3Single(p_es, p_gt, q_es, q_gt)
|
85 |
+
return R, t
|
86 |
+
else:
|
87 |
+
idxs = _getIndices(n_aligned, p_es.shape[0])
|
88 |
+
est_pos = p_es[idxs, 0:3]
|
89 |
+
gt_pos = p_gt[idxs, 0:3]
|
90 |
+
s, R, t = align.align_umeyama(gt_pos, est_pos,
|
91 |
+
known_scale=True) # note the order
|
92 |
+
t = np.array(t)
|
93 |
+
t = t.reshape((3, ))
|
94 |
+
R = np.array(R)
|
95 |
+
return R, t
|
96 |
+
|
97 |
+
|
98 |
+
# align by similarity transformation
|
99 |
+
def alignSIM3(p_es, p_gt, q_es, q_gt, n_aligned=-1):
|
100 |
+
'''
|
101 |
+
calculate s, R, t so that:
|
102 |
+
gt = R * s * est + t
|
103 |
+
'''
|
104 |
+
idxs = _getIndices(n_aligned, p_es.shape[0])
|
105 |
+
est_pos = p_es[idxs, 0:3]
|
106 |
+
gt_pos = p_gt[idxs, 0:3]
|
107 |
+
s, R, t = align.align_umeyama(gt_pos, est_pos) # note the order
|
108 |
+
return s, R, t
|
109 |
+
|
110 |
+
|
111 |
+
# a general interface
|
112 |
+
def alignTrajectory(p_es, p_gt, q_es, q_gt, method, n_aligned=-1):
|
113 |
+
'''
|
114 |
+
calculate s, R, t so that:
|
115 |
+
gt = R * s * est + t
|
116 |
+
method can be: sim3, se3, posyaw, none;
|
117 |
+
n_aligned: -1 means using all the frames
|
118 |
+
'''
|
119 |
+
assert p_es.shape[1] == 3
|
120 |
+
assert p_gt.shape[1] == 3
|
121 |
+
assert q_es.shape[1] == 4
|
122 |
+
assert q_gt.shape[1] == 4
|
123 |
+
|
124 |
+
s = 1
|
125 |
+
R = None
|
126 |
+
t = None
|
127 |
+
if method == 'sim3':
|
128 |
+
assert n_aligned >= 2 or n_aligned == -1, "sim3 uses at least 2 frames"
|
129 |
+
s, R, t = alignSIM3(p_es, p_gt, q_es, q_gt, n_aligned)
|
130 |
+
elif method == 'se3':
|
131 |
+
R, t = alignSE3(p_es, p_gt, q_es, q_gt, n_aligned)
|
132 |
+
elif method == 'posyaw':
|
133 |
+
R, t = alignPositionYaw(p_es, p_gt, q_es, q_gt, n_aligned)
|
134 |
+
elif method == 'none':
|
135 |
+
R = np.identity(3)
|
136 |
+
t = np.zeros((3, ))
|
137 |
+
else:
|
138 |
+
assert False, 'unknown alignment method'
|
139 |
+
|
140 |
+
return s, R, t
|
141 |
+
|
142 |
+
|
143 |
+
if __name__ == '__main__':
|
144 |
+
pass
|
utils/utils_poses/ATE/compute_trajectory_errors.py
ADDED
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python2
|
2 |
+
|
3 |
+
import os
|
4 |
+
import numpy as np
|
5 |
+
|
6 |
+
import utils.utils_poses.ATE.trajectory_utils as tu
|
7 |
+
import utils.utils_poses.ATE.transformations as tf
|
8 |
+
|
9 |
+
|
10 |
+
def compute_relative_error(p_es, q_es, p_gt, q_gt, T_cm, dist, max_dist_diff,
|
11 |
+
accum_distances=[],
|
12 |
+
scale=1.0):
|
13 |
+
|
14 |
+
if len(accum_distances) == 0:
|
15 |
+
accum_distances = tu.get_distance_from_start(p_gt)
|
16 |
+
comparisons = tu.compute_comparison_indices_length(
|
17 |
+
accum_distances, dist, max_dist_diff)
|
18 |
+
|
19 |
+
n_samples = len(comparisons)
|
20 |
+
print('number of samples = {0} '.format(n_samples))
|
21 |
+
if n_samples < 2:
|
22 |
+
print("Too few samples! Will not compute.")
|
23 |
+
return np.array([]), np.array([]), np.array([]), np.array([]), np.array([]),\
|
24 |
+
np.array([]), np.array([])
|
25 |
+
|
26 |
+
T_mc = np.linalg.inv(T_cm)
|
27 |
+
errors = []
|
28 |
+
for idx, c in enumerate(comparisons):
|
29 |
+
if not c == -1:
|
30 |
+
T_c1 = tu.get_rigid_body_trafo(q_es[idx, :], p_es[idx, :])
|
31 |
+
T_c2 = tu.get_rigid_body_trafo(q_es[c, :], p_es[c, :])
|
32 |
+
T_c1_c2 = np.dot(np.linalg.inv(T_c1), T_c2)
|
33 |
+
T_c1_c2[:3, 3] *= scale
|
34 |
+
|
35 |
+
T_m1 = tu.get_rigid_body_trafo(q_gt[idx, :], p_gt[idx, :])
|
36 |
+
T_m2 = tu.get_rigid_body_trafo(q_gt[c, :], p_gt[c, :])
|
37 |
+
T_m1_m2 = np.dot(np.linalg.inv(T_m1), T_m2)
|
38 |
+
|
39 |
+
T_m1_m2_in_c1 = np.dot(T_cm, np.dot(T_m1_m2, T_mc))
|
40 |
+
T_error_in_c2 = np.dot(np.linalg.inv(T_m1_m2_in_c1), T_c1_c2)
|
41 |
+
T_c2_rot = np.eye(4)
|
42 |
+
T_c2_rot[0:3, 0:3] = T_c2[0:3, 0:3]
|
43 |
+
T_error_in_w = np.dot(T_c2_rot, np.dot(
|
44 |
+
T_error_in_c2, np.linalg.inv(T_c2_rot)))
|
45 |
+
errors.append(T_error_in_w)
|
46 |
+
|
47 |
+
error_trans_norm = []
|
48 |
+
error_trans_perc = []
|
49 |
+
error_yaw = []
|
50 |
+
error_gravity = []
|
51 |
+
e_rot = []
|
52 |
+
e_rot_deg_per_m = []
|
53 |
+
for e in errors:
|
54 |
+
tn = np.linalg.norm(e[0:3, 3])
|
55 |
+
error_trans_norm.append(tn)
|
56 |
+
error_trans_perc.append(tn / dist * 100)
|
57 |
+
ypr_angles = tf.euler_from_matrix(e, 'rzyx')
|
58 |
+
e_rot.append(tu.compute_angle(e))
|
59 |
+
error_yaw.append(abs(ypr_angles[0])*180.0/np.pi)
|
60 |
+
error_gravity.append(
|
61 |
+
np.sqrt(ypr_angles[1]**2+ypr_angles[2]**2)*180.0/np.pi)
|
62 |
+
e_rot_deg_per_m.append(e_rot[-1] / dist)
|
63 |
+
return errors, np.array(error_trans_norm), np.array(error_trans_perc),\
|
64 |
+
np.array(error_yaw), np.array(error_gravity), np.array(e_rot),\
|
65 |
+
np.array(e_rot_deg_per_m)
|
66 |
+
|
67 |
+
|
68 |
+
def compute_absolute_error(p_es_aligned, q_es_aligned, p_gt, q_gt):
|
69 |
+
e_trans_vec = (p_gt-p_es_aligned)
|
70 |
+
e_trans = np.sqrt(np.sum(e_trans_vec**2, 1))
|
71 |
+
|
72 |
+
|
73 |
+
# orientation error
|
74 |
+
e_rot = np.zeros((len(e_trans,)))
|
75 |
+
e_ypr = np.zeros(np.shape(p_es_aligned))
|
76 |
+
for i in range(np.shape(p_es_aligned)[0]):
|
77 |
+
R_we = tf.matrix_from_quaternion(q_es_aligned[i, :])
|
78 |
+
R_wg = tf.matrix_from_quaternion(q_gt[i, :])
|
79 |
+
e_R = np.dot(R_we, np.linalg.inv(R_wg))
|
80 |
+
e_ypr[i, :] = tf.euler_from_matrix(e_R, 'rzyx')
|
81 |
+
e_rot[i] = np.rad2deg(np.linalg.norm(tf.logmap_so3(e_R[:3, :3])))
|
82 |
+
# scale drift
|
83 |
+
motion_gt = np.diff(p_gt, 0)
|
84 |
+
motion_es = np.diff(p_es_aligned, 0)
|
85 |
+
dist_gt = np.sqrt(np.sum(np.multiply(motion_gt, motion_gt), 1))
|
86 |
+
dist_es = np.sqrt(np.sum(np.multiply(motion_es, motion_es), 1))
|
87 |
+
e_scale_perc = np.abs((np.divide(dist_es, dist_gt)-1.0) * 100)
|
88 |
+
# ate = np.sqrt(np.mean(np.asarray(e_trans) ** 2))
|
89 |
+
return e_trans, e_trans_vec, e_rot, e_ypr, e_scale_perc
|
utils/utils_poses/ATE/results_writer.py
ADDED
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python2
|
2 |
+
import os
|
3 |
+
# import yaml
|
4 |
+
import numpy as np
|
5 |
+
|
6 |
+
|
7 |
+
def compute_statistics(data_vec):
|
8 |
+
stats = dict()
|
9 |
+
if len(data_vec) > 0:
|
10 |
+
stats['rmse'] = float(
|
11 |
+
np.sqrt(np.dot(data_vec, data_vec) / len(data_vec)))
|
12 |
+
stats['mean'] = float(np.mean(data_vec))
|
13 |
+
stats['median'] = float(np.median(data_vec))
|
14 |
+
stats['std'] = float(np.std(data_vec))
|
15 |
+
stats['min'] = float(np.min(data_vec))
|
16 |
+
stats['max'] = float(np.max(data_vec))
|
17 |
+
stats['num_samples'] = int(len(data_vec))
|
18 |
+
else:
|
19 |
+
stats['rmse'] = 0
|
20 |
+
stats['mean'] = 0
|
21 |
+
stats['median'] = 0
|
22 |
+
stats['std'] = 0
|
23 |
+
stats['min'] = 0
|
24 |
+
stats['max'] = 0
|
25 |
+
stats['num_samples'] = 0
|
26 |
+
|
27 |
+
return stats
|
28 |
+
|
29 |
+
|
30 |
+
# def update_and_save_stats(new_stats, label, yaml_filename):
|
31 |
+
# stats = dict()
|
32 |
+
# if os.path.exists(yaml_filename):
|
33 |
+
# stats = yaml.load(open(yaml_filename, 'r'), Loader=yaml.FullLoader)
|
34 |
+
# stats[label] = new_stats
|
35 |
+
#
|
36 |
+
# with open(yaml_filename, 'w') as outfile:
|
37 |
+
# outfile.write(yaml.dump(stats, default_flow_style=False))
|
38 |
+
#
|
39 |
+
# return
|
40 |
+
#
|
41 |
+
#
|
42 |
+
# def compute_and_save_statistics(data_vec, label, yaml_filename):
|
43 |
+
# new_stats = compute_statistics(data_vec)
|
44 |
+
# update_and_save_stats(new_stats, label, yaml_filename)
|
45 |
+
#
|
46 |
+
# return new_stats
|
47 |
+
#
|
48 |
+
#
|
49 |
+
# def write_tex_table(list_values, rows, cols, outfn):
|
50 |
+
# '''
|
51 |
+
# write list_values[row_idx][col_idx] to a table that is ready to be pasted
|
52 |
+
# into latex source
|
53 |
+
#
|
54 |
+
# list_values is a list of row values
|
55 |
+
#
|
56 |
+
# The value should be string of desired format
|
57 |
+
# '''
|
58 |
+
#
|
59 |
+
# assert len(rows) >= 1
|
60 |
+
# assert len(cols) >= 1
|
61 |
+
#
|
62 |
+
# with open(outfn, 'w') as f:
|
63 |
+
# # write header
|
64 |
+
# f.write(' & ')
|
65 |
+
# for col_i in cols[:-1]:
|
66 |
+
# f.write(col_i + ' & ')
|
67 |
+
# f.write(' ' + cols[-1]+'\n')
|
68 |
+
#
|
69 |
+
# # write each row
|
70 |
+
# for row_idx, row_i in enumerate(list_values):
|
71 |
+
# f.write(rows[row_idx] + ' & ')
|
72 |
+
# row_values = list_values[row_idx]
|
73 |
+
# for col_idx in range(len(row_values) - 1):
|
74 |
+
# f.write(row_values[col_idx] + ' & ')
|
75 |
+
# f.write(' ' + row_values[-1]+' \n')
|
utils/utils_poses/ATE/trajectory_utils.py
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python2
|
2 |
+
"""
|
3 |
+
@author: Christian Forster
|
4 |
+
"""
|
5 |
+
|
6 |
+
import os
|
7 |
+
import numpy as np
|
8 |
+
import utils.utils_poses.ATE.transformations as tf
|
9 |
+
|
10 |
+
|
11 |
+
def get_rigid_body_trafo(quat, trans):
|
12 |
+
T = tf.quaternion_matrix(quat)
|
13 |
+
T[0:3, 3] = trans
|
14 |
+
return T
|
15 |
+
|
16 |
+
|
17 |
+
def get_distance_from_start(gt_translation):
|
18 |
+
distances = np.diff(gt_translation[:, 0:3], axis=0)
|
19 |
+
distances = np.sqrt(np.sum(np.multiply(distances, distances), 1))
|
20 |
+
distances = np.cumsum(distances)
|
21 |
+
distances = np.concatenate(([0], distances))
|
22 |
+
return distances
|
23 |
+
|
24 |
+
|
25 |
+
def compute_comparison_indices_length(distances, dist, max_dist_diff):
|
26 |
+
max_idx = len(distances)
|
27 |
+
comparisons = []
|
28 |
+
for idx, d in enumerate(distances):
|
29 |
+
best_idx = -1
|
30 |
+
error = max_dist_diff
|
31 |
+
for i in range(idx, max_idx):
|
32 |
+
if np.abs(distances[i]-(d+dist)) < error:
|
33 |
+
best_idx = i
|
34 |
+
error = np.abs(distances[i] - (d+dist))
|
35 |
+
if best_idx != -1:
|
36 |
+
comparisons.append(best_idx)
|
37 |
+
return comparisons
|
38 |
+
|
39 |
+
|
40 |
+
def compute_angle(transform):
|
41 |
+
"""
|
42 |
+
Compute the rotation angle from a 4x4 homogeneous matrix.
|
43 |
+
"""
|
44 |
+
# an invitation to 3-d vision, p 27
|
45 |
+
return np.arccos(
|
46 |
+
min(1, max(-1, (np.trace(transform[0:3, 0:3]) - 1)/2)))*180.0/np.pi
|
utils/utils_poses/ATE/transformations.py
ADDED
@@ -0,0 +1,1974 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# transformations.py
|
3 |
+
|
4 |
+
# Copyright (c) 2006, Christoph Gohlke
|
5 |
+
# Copyright (c) 2006-2009, The Regents of the University of California
|
6 |
+
# All rights reserved.
|
7 |
+
#
|
8 |
+
# Redistribution and use in source and binary forms, with or without
|
9 |
+
# modification, are permitted provided that the following conditions are met:
|
10 |
+
#
|
11 |
+
# * Redistributions of source code must retain the above copyright
|
12 |
+
# notice, this list of conditions and the following disclaimer.
|
13 |
+
# * Redistributions in binary form must reproduce the above copyright
|
14 |
+
# notice, this list of conditions and the following disclaimer in the
|
15 |
+
# documentation and/or other materials provided with the distribution.
|
16 |
+
# * Neither the name of the copyright holders nor the names of any
|
17 |
+
# contributors may be used to endorse or promote products derived
|
18 |
+
# from this software without specific prior written permission.
|
19 |
+
#
|
20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
23 |
+
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
|
24 |
+
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
|
25 |
+
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
|
26 |
+
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
27 |
+
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
|
28 |
+
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
|
29 |
+
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
|
30 |
+
# POSSIBILITY OF SUCH DAMAGE.
|
31 |
+
|
32 |
+
"""Homogeneous Transformation Matrices and Quaternions.
|
33 |
+
|
34 |
+
A library for calculating 4x4 matrices for translating, rotating, reflecting,
|
35 |
+
scaling, shearing, projecting, orthogonalizing, and superimposing arrays of
|
36 |
+
3D homogeneous coordinates as well as for converting between rotation matrices,
|
37 |
+
Euler angles, and quaternions. Also includes an Arcball control object and
|
38 |
+
functions to decompose transformation matrices.
|
39 |
+
|
40 |
+
:Authors:
|
41 |
+
`Christoph Gohlke <http://www.lfd.uci.edu/~gohlke/>`__,
|
42 |
+
Laboratory for Fluorescence Dynamics, University of California, Irvine
|
43 |
+
|
44 |
+
:Version: 20090418
|
45 |
+
|
46 |
+
Requirements
|
47 |
+
------------
|
48 |
+
|
49 |
+
* `Python 2.6 <http://www.python.org>`__
|
50 |
+
* `Numpy 1.3 <http://numpy.scipy.org>`__
|
51 |
+
* `transformations.c 20090418 <http://www.lfd.uci.edu/~gohlke/>`__
|
52 |
+
(optional implementation of some functions in C)
|
53 |
+
|
54 |
+
Notes
|
55 |
+
-----
|
56 |
+
|
57 |
+
Matrices (M) can be inverted using numpy.linalg.inv(M), concatenated using
|
58 |
+
numpy.dot(M0, M1), or used to transform homogeneous coordinates (v) using
|
59 |
+
numpy.dot(M, v) for shape (4, \*) "point of arrays", respectively
|
60 |
+
numpy.dot(v, M.T) for shape (\*, 4) "array of points".
|
61 |
+
|
62 |
+
Calculations are carried out with numpy.float64 precision.
|
63 |
+
|
64 |
+
This Python implementation is not optimized for speed.
|
65 |
+
|
66 |
+
Vector, point, quaternion, and matrix function arguments are expected to be
|
67 |
+
"array like", i.e. tuple, list, or numpy arrays.
|
68 |
+
|
69 |
+
Return types are numpy arrays unless specified otherwise.
|
70 |
+
|
71 |
+
Angles are in radians unless specified otherwise.
|
72 |
+
|
73 |
+
Quaternions ix+jy+kz+w are represented as [x, y, z, w].
|
74 |
+
|
75 |
+
Use the transpose of transformation matrices for OpenGL glMultMatrixd().
|
76 |
+
|
77 |
+
A triple of Euler angles can be applied/interpreted in 24 ways, which can
|
78 |
+
be specified using a 4 character string or encoded 4-tuple:
|
79 |
+
|
80 |
+
*Axes 4-string*: e.g. 'sxyz' or 'ryxy'
|
81 |
+
|
82 |
+
- first character : rotations are applied to 's'tatic or 'r'otating frame
|
83 |
+
- remaining characters : successive rotation axis 'x', 'y', or 'z'
|
84 |
+
|
85 |
+
*Axes 4-tuple*: e.g. (0, 0, 0, 0) or (1, 1, 1, 1)
|
86 |
+
|
87 |
+
- inner axis: code of axis ('x':0, 'y':1, 'z':2) of rightmost matrix.
|
88 |
+
- parity : even (0) if inner axis 'x' is followed by 'y', 'y' is followed
|
89 |
+
by 'z', or 'z' is followed by 'x'. Otherwise odd (1).
|
90 |
+
- repetition : first and last axis are same (1) or different (0).
|
91 |
+
- frame : rotations are applied to static (0) or rotating (1) frame.
|
92 |
+
|
93 |
+
References
|
94 |
+
----------
|
95 |
+
|
96 |
+
(1) Matrices and transformations. Ronald Goldman.
|
97 |
+
In "Graphics Gems I", pp 472-475. Morgan Kaufmann, 1990.
|
98 |
+
(2) More matrices and transformations: shear and pseudo-perspective.
|
99 |
+
Ronald Goldman. In "Graphics Gems II", pp 320-323. Morgan Kaufmann, 1991.
|
100 |
+
(3) Decomposing a matrix into simple transformations. Spencer Thomas.
|
101 |
+
In "Graphics Gems II", pp 320-323. Morgan Kaufmann, 1991.
|
102 |
+
(4) Recovering the data from the transformation matrix. Ronald Goldman.
|
103 |
+
In "Graphics Gems II", pp 324-331. Morgan Kaufmann, 1991.
|
104 |
+
(5) Euler angle conversion. Ken Shoemake.
|
105 |
+
In "Graphics Gems IV", pp 222-229. Morgan Kaufmann, 1994.
|
106 |
+
(6) Arcball rotation control. Ken Shoemake.
|
107 |
+
In "Graphics Gems IV", pp 175-192. Morgan Kaufmann, 1994.
|
108 |
+
(7) Representing attitude: Euler angles, unit quaternions, and rotation
|
109 |
+
vectors. James Diebel. 2006.
|
110 |
+
(8) A discussion of the solution for the best rotation to relate two sets
|
111 |
+
of vectors. W Kabsch. Acta Cryst. 1978. A34, 827-828.
|
112 |
+
(9) Closed-form solution of absolute orientation using unit quaternions.
|
113 |
+
BKP Horn. J Opt Soc Am A. 1987. 4(4), 629-642.
|
114 |
+
(10) Quaternions. Ken Shoemake.
|
115 |
+
http://www.sfu.ca/~jwa3/cmpt461/files/quatut.pdf
|
116 |
+
(11) From quaternion to matrix and back. JMP van Waveren. 2005.
|
117 |
+
http://www.intel.com/cd/ids/developer/asmo-na/eng/293748.htm
|
118 |
+
(12) Uniform random rotations. Ken Shoemake.
|
119 |
+
In "Graphics Gems III", pp 124-132. Morgan Kaufmann, 1992.
|
120 |
+
|
121 |
+
|
122 |
+
Examples
|
123 |
+
--------
|
124 |
+
|
125 |
+
>>> alpha, beta, gamma = 0.123, -1.234, 2.345
|
126 |
+
>>> origin, xaxis, yaxis, zaxis = (0, 0, 0), (1, 0, 0), (0, 1, 0), (0, 0, 1)
|
127 |
+
>>> I = identity_matrix()
|
128 |
+
>>> Rx = rotation_matrix(alpha, xaxis)
|
129 |
+
>>> Ry = rotation_matrix(beta, yaxis)
|
130 |
+
>>> Rz = rotation_matrix(gamma, zaxis)
|
131 |
+
>>> R = concatenate_matrices(Rx, Ry, Rz)
|
132 |
+
>>> euler = euler_from_matrix(R, 'rxyz')
|
133 |
+
>>> numpy.allclose([alpha, beta, gamma], euler)
|
134 |
+
True
|
135 |
+
>>> Re = euler_matrix(alpha, beta, gamma, 'rxyz')
|
136 |
+
>>> is_same_transform(R, Re)
|
137 |
+
True
|
138 |
+
>>> al, be, ga = euler_from_matrix(Re, 'rxyz')
|
139 |
+
>>> is_same_transform(Re, euler_matrix(al, be, ga, 'rxyz'))
|
140 |
+
True
|
141 |
+
>>> qx = quaternion_about_axis(alpha, xaxis)
|
142 |
+
>>> qy = quaternion_about_axis(beta, yaxis)
|
143 |
+
>>> qz = quaternion_about_axis(gamma, zaxis)
|
144 |
+
>>> q = quaternion_multiply(qx, qy)
|
145 |
+
>>> q = quaternion_multiply(q, qz)
|
146 |
+
>>> Rq = quaternion_matrix(q)
|
147 |
+
>>> is_same_transform(R, Rq)
|
148 |
+
True
|
149 |
+
>>> S = scale_matrix(1.23, origin)
|
150 |
+
>>> T = translation_matrix((1, 2, 3))
|
151 |
+
>>> Z = shear_matrix(beta, xaxis, origin, zaxis)
|
152 |
+
>>> R = random_rotation_matrix(numpy.random.rand(3))
|
153 |
+
>>> M = concatenate_matrices(T, R, Z, S)
|
154 |
+
>>> scale, shear, angles, trans, persp = decompose_matrix(M)
|
155 |
+
>>> numpy.allclose(scale, 1.23)
|
156 |
+
True
|
157 |
+
>>> numpy.allclose(trans, (1, 2, 3))
|
158 |
+
True
|
159 |
+
>>> numpy.allclose(shear, (0, math.tan(beta), 0))
|
160 |
+
True
|
161 |
+
>>> is_same_transform(R, euler_matrix(axes='sxyz', *angles))
|
162 |
+
True
|
163 |
+
>>> M1 = compose_matrix(scale, shear, angles, trans, persp)
|
164 |
+
>>> is_same_transform(M, M1)
|
165 |
+
True
|
166 |
+
|
167 |
+
"""
|
168 |
+
|
169 |
+
from __future__ import division
|
170 |
+
|
171 |
+
import warnings
|
172 |
+
import math
|
173 |
+
|
174 |
+
import numpy
|
175 |
+
|
176 |
+
# Documentation in HTML format can be generated with Epydoc
|
177 |
+
__docformat__ = "restructuredtext en"
|
178 |
+
|
179 |
+
|
180 |
+
def skew(v):
|
181 |
+
"""Returns the skew-symmetric matrix of a vector
|
182 |
+
cfo, 2015/08/13
|
183 |
+
|
184 |
+
"""
|
185 |
+
return numpy.array([[0, -v[2], v[1]],
|
186 |
+
[v[2], 0, -v[0]],
|
187 |
+
[-v[1], v[0], 0]], dtype=numpy.float64)
|
188 |
+
|
189 |
+
|
190 |
+
def unskew(R):
|
191 |
+
"""Returns the coordinates of a skew-symmetric matrix
|
192 |
+
cfo, 2015/08/13
|
193 |
+
|
194 |
+
"""
|
195 |
+
return numpy.array([R[2, 1], R[0, 2], R[1, 0]], dtype=numpy.float64)
|
196 |
+
|
197 |
+
|
198 |
+
def first_order_rotation(rotvec):
|
199 |
+
"""First order approximation of a rotation: I + skew(rotvec)
|
200 |
+
cfo, 2015/08/13
|
201 |
+
|
202 |
+
"""
|
203 |
+
R = numpy.zeros((3, 3), dtype=numpy.float64)
|
204 |
+
R[0, 0] = 1.0
|
205 |
+
R[1, 0] = rotvec[2]
|
206 |
+
R[2, 0] = -rotvec[1]
|
207 |
+
R[0, 1] = -rotvec[2]
|
208 |
+
R[1, 1] = 1.0
|
209 |
+
R[2, 1] = rotvec[0]
|
210 |
+
R[0, 2] = rotvec[1]
|
211 |
+
R[1, 2] = -rotvec[0]
|
212 |
+
R[2, 2] = 1.0
|
213 |
+
return R
|
214 |
+
|
215 |
+
|
216 |
+
def axis_angle(axis, theta):
|
217 |
+
"""Compute a rotation matrix from an axis and an angle.
|
218 |
+
Returns 3x3 Matrix.
|
219 |
+
Is the same as transformations.rotation_matrix(theta, axis).
|
220 |
+
cfo, 2015/08/13
|
221 |
+
|
222 |
+
"""
|
223 |
+
if theta*theta > _EPS:
|
224 |
+
wx = axis[0]
|
225 |
+
wy = axis[1]
|
226 |
+
wz = axis[2]
|
227 |
+
costheta = numpy.cos(theta)
|
228 |
+
sintheta = numpy.sin(theta)
|
229 |
+
c_1 = 1.0 - costheta
|
230 |
+
wx_sintheta = wx * sintheta
|
231 |
+
wy_sintheta = wy * sintheta
|
232 |
+
wz_sintheta = wz * sintheta
|
233 |
+
C00 = c_1 * wx * wx
|
234 |
+
C01 = c_1 * wx * wy
|
235 |
+
C02 = c_1 * wx * wz
|
236 |
+
C11 = c_1 * wy * wy
|
237 |
+
C12 = c_1 * wy * wz
|
238 |
+
C22 = c_1 * wz * wz
|
239 |
+
R = numpy.zeros((3, 3), dtype=numpy.float64)
|
240 |
+
R[0, 0] = costheta + C00
|
241 |
+
R[1, 0] = wz_sintheta + C01
|
242 |
+
R[2, 0] = -wy_sintheta + C02
|
243 |
+
R[0, 1] = -wz_sintheta + C01
|
244 |
+
R[1, 1] = costheta + C11
|
245 |
+
R[2, 1] = wx_sintheta + C12
|
246 |
+
R[0, 2] = wy_sintheta + C02
|
247 |
+
R[1, 2] = -wx_sintheta + C12
|
248 |
+
R[2, 2] = costheta + C22
|
249 |
+
return R
|
250 |
+
else:
|
251 |
+
return first_order_rotation(axis*theta)
|
252 |
+
|
253 |
+
|
254 |
+
def expmap_so3(rotvec):
|
255 |
+
"""Exponential map at identity.
|
256 |
+
Create a rotation from canonical coordinates using Rodrigues' formula.
|
257 |
+
cfo, 2015/08/13
|
258 |
+
|
259 |
+
"""
|
260 |
+
theta = numpy.linalg.norm(rotvec)
|
261 |
+
axis = rotvec/theta
|
262 |
+
return axis_angle(axis, theta)
|
263 |
+
|
264 |
+
|
265 |
+
def logmap_so3(R):
|
266 |
+
"""Logmap at the identity.
|
267 |
+
Returns canonical coordinates of rotation.
|
268 |
+
cfo, 2015/08/13
|
269 |
+
|
270 |
+
"""
|
271 |
+
R11 = R[0, 0]
|
272 |
+
R12 = R[0, 1]
|
273 |
+
R13 = R[0, 2]
|
274 |
+
R21 = R[1, 0]
|
275 |
+
R22 = R[1, 1]
|
276 |
+
R23 = R[1, 2]
|
277 |
+
R31 = R[2, 0]
|
278 |
+
R32 = R[2, 1]
|
279 |
+
R33 = R[2, 2]
|
280 |
+
tr = numpy.trace(R)
|
281 |
+
omega = numpy.empty((3,), dtype=numpy.float64)
|
282 |
+
|
283 |
+
# when trace == -1, i.e., when theta = +-pi, +-3pi, +-5pi, we do something
|
284 |
+
# special
|
285 |
+
if(numpy.abs(tr + 1.0) < 1e-10):
|
286 |
+
if(numpy.abs(R33 + 1.0) > 1e-10):
|
287 |
+
omega = (numpy.pi / numpy.sqrt(2.0 + 2.0 * R33)) * \
|
288 |
+
numpy.array([R13, R23, 1.0+R33])
|
289 |
+
elif(numpy.abs(R22 + 1.0) > 1e-10):
|
290 |
+
omega = (numpy.pi / numpy.sqrt(2.0 + 2.0 * R22)) * \
|
291 |
+
numpy.array([R12, 1.0+R22, R32])
|
292 |
+
else:
|
293 |
+
omega = (numpy.pi / numpy.sqrt(2.0 + 2.0 * R11)) * \
|
294 |
+
numpy.array([1.0+R11, R21, R31])
|
295 |
+
else:
|
296 |
+
magnitude = 1.0
|
297 |
+
tr_3 = tr - 3.0
|
298 |
+
if tr_3 < -1e-7:
|
299 |
+
theta = numpy.arccos((tr - 1.0) / 2.0)
|
300 |
+
magnitude = theta / (2.0 * numpy.sin(theta))
|
301 |
+
else:
|
302 |
+
# when theta near 0, +-2pi, +-4pi, etc. (trace near 3.0)
|
303 |
+
# use Taylor expansion: theta \approx 1/2-(t-3)/12 + O((t-3)^2)
|
304 |
+
magnitude = 0.5 - tr_3 * tr_3 / 12.0
|
305 |
+
|
306 |
+
omega = magnitude * numpy.array([R32 - R23, R13 - R31, R21 - R12])
|
307 |
+
|
308 |
+
return omega
|
309 |
+
|
310 |
+
|
311 |
+
|
312 |
+
def right_jacobian_so3(rotvec):
|
313 |
+
"""Right Jacobian for Exponential map in SO(3)
|
314 |
+
Equation (10.86) and following equations in G.S. Chirikjian, "Stochastic
|
315 |
+
Models, Information Theory, and Lie Groups", Volume 2, 2008.
|
316 |
+
|
317 |
+
> expmap_so3(thetahat + omega) \approx expmap_so3(thetahat) * expmap_so3(Jr * omega)
|
318 |
+
where Jr = right_jacobian_so3(thetahat);
|
319 |
+
This maps a perturbation in the tangent space (omega) to a perturbation
|
320 |
+
on the manifold (expmap_so3(Jr * omega))
|
321 |
+
cfo, 2015/08/13
|
322 |
+
|
323 |
+
"""
|
324 |
+
|
325 |
+
theta2 = numpy.dot(rotvec, rotvec)
|
326 |
+
if theta2 <= _EPS:
|
327 |
+
return numpy.identity(3, dtype=numpy.float64)
|
328 |
+
else:
|
329 |
+
theta = numpy.sqrt(theta2)
|
330 |
+
Y = skew(rotvec) / theta
|
331 |
+
I_3x3 = numpy.identity(3, dtype=numpy.float64)
|
332 |
+
J_r = I_3x3 - ((1.0 - numpy.cos(theta)) / theta) * Y + \
|
333 |
+
(1.0 - numpy.sin(theta) / theta) * numpy.dot(Y, Y)
|
334 |
+
return J_r
|
335 |
+
|
336 |
+
|
337 |
+
def S_inv_eulerZYX_body(euler_coordinates):
|
338 |
+
""" Relates angular rates w to changes in eulerZYX coordinates.
|
339 |
+
dot(euler) = S^-1(euler_coordinates) * omega
|
340 |
+
Also called: rotation-rate matrix. (E in Lupton paper)
|
341 |
+
cfo, 2015/08/13
|
342 |
+
|
343 |
+
"""
|
344 |
+
y = euler_coordinates[1]
|
345 |
+
z = euler_coordinates[2]
|
346 |
+
E = numpy.zeros((3, 3))
|
347 |
+
E[0, 1] = numpy.sin(z)/numpy.cos(y)
|
348 |
+
E[0, 2] = numpy.cos(z)/numpy.cos(y)
|
349 |
+
E[1, 1] = numpy.cos(z)
|
350 |
+
E[1, 2] = -numpy.sin(z)
|
351 |
+
E[2, 0] = 1.0
|
352 |
+
E[2, 1] = numpy.sin(z)*numpy.sin(y)/numpy.cos(y)
|
353 |
+
E[2, 2] = numpy.cos(z)*numpy.sin(y)/numpy.cos(y)
|
354 |
+
return E
|
355 |
+
|
356 |
+
|
357 |
+
def S_inv_eulerZYX_body_deriv(euler_coordinates, omega):
|
358 |
+
""" Compute dE(euler_coordinates)*omega/deuler_coordinates
|
359 |
+
cfo, 2015/08/13
|
360 |
+
|
361 |
+
"""
|
362 |
+
|
363 |
+
y = euler_coordinates[1]
|
364 |
+
z = euler_coordinates[2]
|
365 |
+
|
366 |
+
"""
|
367 |
+
w1 = omega[0]; w2 = omega[1]; w3 = omega[2]
|
368 |
+
J = numpy.zeros((3,3))
|
369 |
+
J[0,0] = 0
|
370 |
+
J[0,1] = math.tan(y) / math.cos(y) * (math.sin(z) * w2 + math.cos(z) * w3)
|
371 |
+
J[0,2] = w2/math.cos(y)*math.cos(z) - w3/math.cos(y)*math.sin(z)
|
372 |
+
J[1,0] = 0
|
373 |
+
J[1,1] = 0
|
374 |
+
J[1,2] = -w2*math.sin(z) - w3*math.cos(z)
|
375 |
+
J[2,0] = w1
|
376 |
+
J[2,1] = 1.0/math.cos(y)**2 * (w2 * math.sin(z) + w3 * math.cos(z))
|
377 |
+
J[2,2] = w2*math.tan(y)*math.cos(z) - w3*math.tan(y)*math.sin(z)
|
378 |
+
|
379 |
+
"""
|
380 |
+
|
381 |
+
# second version, x = psi, y = theta, z = phi
|
382 |
+
# J_x = numpy.zeros((3,3))
|
383 |
+
J_y = numpy.zeros((3, 3))
|
384 |
+
J_z = numpy.zeros((3, 3))
|
385 |
+
|
386 |
+
# dE^-1/dtheta
|
387 |
+
J_y[0, 1] = math.tan(y)/math.cos(y)*math.sin(z)
|
388 |
+
J_y[0, 2] = math.tan(y)/math.cos(y)*math.cos(z)
|
389 |
+
J_y[2, 1] = math.sin(z)/(math.cos(y))**2
|
390 |
+
J_y[2, 2] = math.cos(z)/(math.cos(y))**2
|
391 |
+
|
392 |
+
# dE^-1/dphi
|
393 |
+
J_z[0, 1] = math.cos(z)/math.cos(y)
|
394 |
+
J_z[0, 2] = -math.sin(z)/math.cos(y)
|
395 |
+
J_z[1, 1] = -math.sin(z)
|
396 |
+
J_z[1, 2] = -math.cos(z)
|
397 |
+
J_z[2, 1] = math.cos(z)*math.tan(y)
|
398 |
+
J_z[2, 2] = -math.sin(z)*math.tan(y)
|
399 |
+
|
400 |
+
J = numpy.zeros((3, 3))
|
401 |
+
J[:, 1] = numpy.dot(J_y, omega)
|
402 |
+
J[:, 2] = numpy.dot(J_z, omega)
|
403 |
+
|
404 |
+
return J
|
405 |
+
|
406 |
+
|
407 |
+
def identity_matrix():
|
408 |
+
"""Return 4x4 identity/unit matrix.
|
409 |
+
|
410 |
+
>>> I = identity_matrix()
|
411 |
+
>>> numpy.allclose(I, numpy.dot(I, I))
|
412 |
+
True
|
413 |
+
>>> numpy.sum(I), numpy.trace(I)
|
414 |
+
(4.0, 4.0)
|
415 |
+
>>> numpy.allclose(I, numpy.identity(4, dtype=numpy.float64))
|
416 |
+
True
|
417 |
+
|
418 |
+
"""
|
419 |
+
return numpy.identity(4, dtype=numpy.float64)
|
420 |
+
|
421 |
+
|
422 |
+
def translation_matrix(direction):
|
423 |
+
"""Return matrix to translate by direction vector.
|
424 |
+
|
425 |
+
>>> v = numpy.random.random(3) - 0.5
|
426 |
+
>>> numpy.allclose(v, translation_matrix(v)[:3, 3])
|
427 |
+
True
|
428 |
+
|
429 |
+
"""
|
430 |
+
M = numpy.identity(4)
|
431 |
+
M[:3, 3] = direction[:3]
|
432 |
+
return M
|
433 |
+
|
434 |
+
|
435 |
+
def translation_from_matrix(matrix):
|
436 |
+
"""Return translation vector from translation matrix.
|
437 |
+
|
438 |
+
>>> v0 = numpy.random.random(3) - 0.5
|
439 |
+
>>> v1 = translation_from_matrix(translation_matrix(v0))
|
440 |
+
>>> numpy.allclose(v0, v1)
|
441 |
+
True
|
442 |
+
|
443 |
+
"""
|
444 |
+
return numpy.array(matrix, copy=False)[:3, 3].copy()
|
445 |
+
|
446 |
+
|
447 |
+
def convert_3x3_to_4x4(matrix_3x3):
|
448 |
+
M = numpy.identity(4)
|
449 |
+
M[:3, :3] = matrix_3x3
|
450 |
+
return M
|
451 |
+
|
452 |
+
|
453 |
+
def reflection_matrix(point, normal):
|
454 |
+
"""Return matrix to mirror at plane defined by point and normal vector.
|
455 |
+
|
456 |
+
>>> v0 = numpy.random.random(4) - 0.5
|
457 |
+
>>> v0[3] = 1.0
|
458 |
+
>>> v1 = numpy.random.random(3) - 0.5
|
459 |
+
>>> R = reflection_matrix(v0, v1)
|
460 |
+
>>> numpy.allclose(2., numpy.trace(R))
|
461 |
+
True
|
462 |
+
>>> numpy.allclose(v0, numpy.dot(R, v0))
|
463 |
+
True
|
464 |
+
>>> v2 = v0.copy()
|
465 |
+
>>> v2[:3] += v1
|
466 |
+
>>> v3 = v0.copy()
|
467 |
+
>>> v2[:3] -= v1
|
468 |
+
>>> numpy.allclose(v2, numpy.dot(R, v3))
|
469 |
+
True
|
470 |
+
|
471 |
+
"""
|
472 |
+
normal = unit_vector(normal[:3])
|
473 |
+
M = numpy.identity(4)
|
474 |
+
M[:3, :3] -= 2.0 * numpy.outer(normal, normal)
|
475 |
+
M[:3, 3] = (2.0 * numpy.dot(point[:3], normal)) * normal
|
476 |
+
return M
|
477 |
+
|
478 |
+
|
479 |
+
def reflection_from_matrix(matrix):
|
480 |
+
"""Return mirror plane point and normal vector from reflection matrix.
|
481 |
+
|
482 |
+
>>> v0 = numpy.random.random(3) - 0.5
|
483 |
+
>>> v1 = numpy.random.random(3) - 0.5
|
484 |
+
>>> M0 = reflection_matrix(v0, v1)
|
485 |
+
>>> point, normal = reflection_from_matrix(M0)
|
486 |
+
>>> M1 = reflection_matrix(point, normal)
|
487 |
+
>>> is_same_transform(M0, M1)
|
488 |
+
True
|
489 |
+
|
490 |
+
"""
|
491 |
+
M = numpy.array(matrix, dtype=numpy.float64, copy=False)
|
492 |
+
# normal: unit eigenvector corresponding to eigenvalue -1
|
493 |
+
l, V = numpy.linalg.eig(M[:3, :3])
|
494 |
+
i = numpy.where(abs(numpy.real(l) + 1.0) < 1e-8)[0]
|
495 |
+
if not len(i):
|
496 |
+
raise ValueError("no unit eigenvector corresponding to eigenvalue -1")
|
497 |
+
normal = numpy.real(V[:, i[0]]).squeeze()
|
498 |
+
# point: any unit eigenvector corresponding to eigenvalue 1
|
499 |
+
l, V = numpy.linalg.eig(M)
|
500 |
+
i = numpy.where(abs(numpy.real(l) - 1.0) < 1e-8)[0]
|
501 |
+
if not len(i):
|
502 |
+
raise ValueError("no unit eigenvector corresponding to eigenvalue 1")
|
503 |
+
point = numpy.real(V[:, i[-1]]).squeeze()
|
504 |
+
point /= point[3]
|
505 |
+
return point, normal
|
506 |
+
|
507 |
+
|
508 |
+
def rotation_matrix(angle, direction, point=None):
|
509 |
+
"""Return matrix to rotate about axis defined by point and direction.
|
510 |
+
|
511 |
+
>>> angle = (random.random() - 0.5) * (2*math.pi)
|
512 |
+
>>> direc = numpy.random.random(3) - 0.5
|
513 |
+
>>> point = numpy.random.random(3) - 0.5
|
514 |
+
>>> R0 = rotation_matrix(angle, direc, point)
|
515 |
+
>>> R1 = rotation_matrix(angle-2*math.pi, direc, point)
|
516 |
+
>>> is_same_transform(R0, R1)
|
517 |
+
True
|
518 |
+
>>> R0 = rotation_matrix(angle, direc, point)
|
519 |
+
>>> R1 = rotation_matrix(-angle, -direc, point)
|
520 |
+
>>> is_same_transform(R0, R1)
|
521 |
+
True
|
522 |
+
>>> I = numpy.identity(4, numpy.float64)
|
523 |
+
>>> numpy.allclose(I, rotation_matrix(math.pi*2, direc))
|
524 |
+
True
|
525 |
+
>>> numpy.allclose(2., numpy.trace(rotation_matrix(math.pi/2,
|
526 |
+
... direc, point)))
|
527 |
+
True
|
528 |
+
|
529 |
+
"""
|
530 |
+
sina = math.sin(angle)
|
531 |
+
cosa = math.cos(angle)
|
532 |
+
direction = unit_vector(direction[:3])
|
533 |
+
# rotation matrix around unit vector
|
534 |
+
R = numpy.array(((cosa, 0.0, 0.0),
|
535 |
+
(0.0, cosa, 0.0),
|
536 |
+
(0.0, 0.0, cosa)), dtype=numpy.float64)
|
537 |
+
R += numpy.outer(direction, direction) * (1.0 - cosa)
|
538 |
+
direction *= sina
|
539 |
+
R += numpy.array(((0.0, -direction[2], direction[1]),
|
540 |
+
(direction[2], 0.0, -direction[0]),
|
541 |
+
(-direction[1], direction[0], 0.0)),
|
542 |
+
dtype=numpy.float64)
|
543 |
+
M = numpy.identity(4)
|
544 |
+
M[:3, :3] = R
|
545 |
+
if point is not None:
|
546 |
+
# rotation not around origin
|
547 |
+
point = numpy.array(point[:3], dtype=numpy.float64, copy=False)
|
548 |
+
M[:3, 3] = point - numpy.dot(R, point)
|
549 |
+
return M
|
550 |
+
|
551 |
+
|
552 |
+
def rotation_from_matrix(matrix):
|
553 |
+
"""Return rotation angle and axis from rotation matrix.
|
554 |
+
|
555 |
+
>>> angle = (random.random() - 0.5) * (2*math.pi)
|
556 |
+
>>> direc = numpy.random.random(3) - 0.5
|
557 |
+
>>> point = numpy.random.random(3) - 0.5
|
558 |
+
>>> R0 = rotation_matrix(angle, direc, point)
|
559 |
+
>>> angle, direc, point = rotation_from_matrix(R0)
|
560 |
+
>>> R1 = rotation_matrix(angle, direc, point)
|
561 |
+
>>> is_same_transform(R0, R1)
|
562 |
+
True
|
563 |
+
|
564 |
+
"""
|
565 |
+
R = numpy.array(matrix, dtype=numpy.float64, copy=False)
|
566 |
+
R33 = R[:3, :3]
|
567 |
+
# direction: unit eigenvector of R33 corresponding to eigenvalue of 1
|
568 |
+
l, W = numpy.linalg.eig(R33.T)
|
569 |
+
i = numpy.where(abs(numpy.real(l) - 1.0) < 1e-8)[0]
|
570 |
+
if not len(i):
|
571 |
+
raise ValueError("no unit eigenvector corresponding to eigenvalue 1")
|
572 |
+
direction = numpy.real(W[:, i[-1]]).squeeze()
|
573 |
+
# point: unit eigenvector of R33 corresponding to eigenvalue of 1
|
574 |
+
l, Q = numpy.linalg.eig(R)
|
575 |
+
i = numpy.where(abs(numpy.real(l) - 1.0) < 1e-8)[0]
|
576 |
+
if not len(i):
|
577 |
+
raise ValueError("no unit eigenvector corresponding to eigenvalue 1")
|
578 |
+
point = numpy.real(Q[:, i[-1]]).squeeze()
|
579 |
+
point /= point[3]
|
580 |
+
# rotation angle depending on direction
|
581 |
+
cosa = (numpy.trace(R33) - 1.0) / 2.0
|
582 |
+
if abs(direction[2]) > 1e-8:
|
583 |
+
sina = (R[1, 0] + (cosa-1.0)*direction[0]*direction[1]) / direction[2]
|
584 |
+
elif abs(direction[1]) > 1e-8:
|
585 |
+
sina = (R[0, 2] + (cosa-1.0)*direction[0]*direction[2]) / direction[1]
|
586 |
+
else:
|
587 |
+
sina = (R[2, 1] + (cosa-1.0)*direction[1]*direction[2]) / direction[0]
|
588 |
+
angle = math.atan2(sina, cosa)
|
589 |
+
return angle, direction, point
|
590 |
+
|
591 |
+
|
592 |
+
def scale_matrix(factor, origin=None, direction=None):
|
593 |
+
"""Return matrix to scale by factor around origin in direction.
|
594 |
+
|
595 |
+
Use factor -1 for point symmetry.
|
596 |
+
|
597 |
+
>>> v = (numpy.random.rand(4, 5) - 0.5) * 20.0
|
598 |
+
>>> v[3] = 1.0
|
599 |
+
>>> S = scale_matrix(-1.234)
|
600 |
+
>>> numpy.allclose(numpy.dot(S, v)[:3], -1.234*v[:3])
|
601 |
+
True
|
602 |
+
>>> factor = random.random() * 10 - 5
|
603 |
+
>>> origin = numpy.random.random(3) - 0.5
|
604 |
+
>>> direct = numpy.random.random(3) - 0.5
|
605 |
+
>>> S = scale_matrix(factor, origin)
|
606 |
+
>>> S = scale_matrix(factor, origin, direct)
|
607 |
+
|
608 |
+
"""
|
609 |
+
if direction is None:
|
610 |
+
# uniform scaling
|
611 |
+
M = numpy.array(((factor, 0.0, 0.0, 0.0),
|
612 |
+
(0.0, factor, 0.0, 0.0),
|
613 |
+
(0.0, 0.0, factor, 0.0),
|
614 |
+
(0.0, 0.0, 0.0, 1.0)), dtype=numpy.float64)
|
615 |
+
if origin is not None:
|
616 |
+
M[:3, 3] = origin[:3]
|
617 |
+
M[:3, 3] *= 1.0 - factor
|
618 |
+
else:
|
619 |
+
# nonuniform scaling
|
620 |
+
direction = unit_vector(direction[:3])
|
621 |
+
factor = 1.0 - factor
|
622 |
+
M = numpy.identity(4)
|
623 |
+
M[:3, :3] -= factor * numpy.outer(direction, direction)
|
624 |
+
if origin is not None:
|
625 |
+
M[:3, 3] = (factor * numpy.dot(origin[:3], direction)) * direction
|
626 |
+
return M
|
627 |
+
|
628 |
+
|
629 |
+
def scale_from_matrix(matrix):
|
630 |
+
"""Return scaling factor, origin and direction from scaling matrix.
|
631 |
+
|
632 |
+
>>> factor = random.random() * 10 - 5
|
633 |
+
>>> origin = numpy.random.random(3) - 0.5
|
634 |
+
>>> direct = numpy.random.random(3) - 0.5
|
635 |
+
>>> S0 = scale_matrix(factor, origin)
|
636 |
+
>>> factor, origin, direction = scale_from_matrix(S0)
|
637 |
+
>>> S1 = scale_matrix(factor, origin, direction)
|
638 |
+
>>> is_same_transform(S0, S1)
|
639 |
+
True
|
640 |
+
>>> S0 = scale_matrix(factor, origin, direct)
|
641 |
+
>>> factor, origin, direction = scale_from_matrix(S0)
|
642 |
+
>>> S1 = scale_matrix(factor, origin, direction)
|
643 |
+
>>> is_same_transform(S0, S1)
|
644 |
+
True
|
645 |
+
|
646 |
+
"""
|
647 |
+
M = numpy.array(matrix, dtype=numpy.float64, copy=False)
|
648 |
+
M33 = M[:3, :3]
|
649 |
+
factor = numpy.trace(M33) - 2.0
|
650 |
+
try:
|
651 |
+
# direction: unit eigenvector corresponding to eigenvalue factor
|
652 |
+
l, V = numpy.linalg.eig(M33)
|
653 |
+
i = numpy.where(abs(numpy.real(l) - factor) < 1e-8)[0][0]
|
654 |
+
direction = numpy.real(V[:, i]).squeeze()
|
655 |
+
direction /= vector_norm(direction)
|
656 |
+
except IndexError:
|
657 |
+
# uniform scaling
|
658 |
+
factor = (factor + 2.0) / 3.0
|
659 |
+
direction = None
|
660 |
+
# origin: any eigenvector corresponding to eigenvalue 1
|
661 |
+
l, V = numpy.linalg.eig(M)
|
662 |
+
i = numpy.where(abs(numpy.real(l) - 1.0) < 1e-8)[0]
|
663 |
+
if not len(i):
|
664 |
+
raise ValueError("no eigenvector corresponding to eigenvalue 1")
|
665 |
+
origin = numpy.real(V[:, i[-1]]).squeeze()
|
666 |
+
origin /= origin[3]
|
667 |
+
return factor, origin, direction
|
668 |
+
|
669 |
+
|
670 |
+
def projection_matrix(point, normal, direction=None,
|
671 |
+
perspective=None, pseudo=False):
|
672 |
+
"""Return matrix to project onto plane defined by point and normal.
|
673 |
+
|
674 |
+
Using either perspective point, projection direction, or none of both.
|
675 |
+
|
676 |
+
If pseudo is True, perspective projections will preserve relative depth
|
677 |
+
such that Perspective = dot(Orthogonal, PseudoPerspective).
|
678 |
+
|
679 |
+
>>> P = projection_matrix((0, 0, 0), (1, 0, 0))
|
680 |
+
>>> numpy.allclose(P[1:, 1:], numpy.identity(4)[1:, 1:])
|
681 |
+
True
|
682 |
+
>>> point = numpy.random.random(3) - 0.5
|
683 |
+
>>> normal = numpy.random.random(3) - 0.5
|
684 |
+
>>> direct = numpy.random.random(3) - 0.5
|
685 |
+
>>> persp = numpy.random.random(3) - 0.5
|
686 |
+
>>> P0 = projection_matrix(point, normal)
|
687 |
+
>>> P1 = projection_matrix(point, normal, direction=direct)
|
688 |
+
>>> P2 = projection_matrix(point, normal, perspective=persp)
|
689 |
+
>>> P3 = projection_matrix(point, normal, perspective=persp, pseudo=True)
|
690 |
+
>>> is_same_transform(P2, numpy.dot(P0, P3))
|
691 |
+
True
|
692 |
+
>>> P = projection_matrix((3, 0, 0), (1, 1, 0), (1, 0, 0))
|
693 |
+
>>> v0 = (numpy.random.rand(4, 5) - 0.5) * 20.0
|
694 |
+
>>> v0[3] = 1.0
|
695 |
+
>>> v1 = numpy.dot(P, v0)
|
696 |
+
>>> numpy.allclose(v1[1], v0[1])
|
697 |
+
True
|
698 |
+
>>> numpy.allclose(v1[0], 3.0-v1[1])
|
699 |
+
True
|
700 |
+
|
701 |
+
"""
|
702 |
+
M = numpy.identity(4)
|
703 |
+
point = numpy.array(point[:3], dtype=numpy.float64, copy=False)
|
704 |
+
normal = unit_vector(normal[:3])
|
705 |
+
if perspective is not None:
|
706 |
+
# perspective projection
|
707 |
+
perspective = numpy.array(perspective[:3], dtype=numpy.float64,
|
708 |
+
copy=False)
|
709 |
+
M[0, 0] = M[1, 1] = M[2, 2] = numpy.dot(perspective-point, normal)
|
710 |
+
M[:3, :3] -= numpy.outer(perspective, normal)
|
711 |
+
if pseudo:
|
712 |
+
# preserve relative depth
|
713 |
+
M[:3, :3] -= numpy.outer(normal, normal)
|
714 |
+
M[:3, 3] = numpy.dot(point, normal) * (perspective+normal)
|
715 |
+
else:
|
716 |
+
M[:3, 3] = numpy.dot(point, normal) * perspective
|
717 |
+
M[3, :3] = -normal
|
718 |
+
M[3, 3] = numpy.dot(perspective, normal)
|
719 |
+
elif direction is not None:
|
720 |
+
# parallel projection
|
721 |
+
direction = numpy.array(direction[:3], dtype=numpy.float64, copy=False)
|
722 |
+
scale = numpy.dot(direction, normal)
|
723 |
+
M[:3, :3] -= numpy.outer(direction, normal) / scale
|
724 |
+
M[:3, 3] = direction * (numpy.dot(point, normal) / scale)
|
725 |
+
else:
|
726 |
+
# orthogonal projection
|
727 |
+
M[:3, :3] -= numpy.outer(normal, normal)
|
728 |
+
M[:3, 3] = numpy.dot(point, normal) * normal
|
729 |
+
return M
|
730 |
+
|
731 |
+
|
732 |
+
def projection_from_matrix(matrix, pseudo=False):
|
733 |
+
"""Return projection plane and perspective point from projection matrix.
|
734 |
+
|
735 |
+
Return values are same as arguments for projection_matrix function:
|
736 |
+
point, normal, direction, perspective, and pseudo.
|
737 |
+
|
738 |
+
>>> point = numpy.random.random(3) - 0.5
|
739 |
+
>>> normal = numpy.random.random(3) - 0.5
|
740 |
+
>>> direct = numpy.random.random(3) - 0.5
|
741 |
+
>>> persp = numpy.random.random(3) - 0.5
|
742 |
+
>>> P0 = projection_matrix(point, normal)
|
743 |
+
>>> result = projection_from_matrix(P0)
|
744 |
+
>>> P1 = projection_matrix(*result)
|
745 |
+
>>> is_same_transform(P0, P1)
|
746 |
+
True
|
747 |
+
>>> P0 = projection_matrix(point, normal, direct)
|
748 |
+
>>> result = projection_from_matrix(P0)
|
749 |
+
>>> P1 = projection_matrix(*result)
|
750 |
+
>>> is_same_transform(P0, P1)
|
751 |
+
True
|
752 |
+
>>> P0 = projection_matrix(point, normal, perspective=persp, pseudo=False)
|
753 |
+
>>> result = projection_from_matrix(P0, pseudo=False)
|
754 |
+
>>> P1 = projection_matrix(*result)
|
755 |
+
>>> is_same_transform(P0, P1)
|
756 |
+
True
|
757 |
+
>>> P0 = projection_matrix(point, normal, perspective=persp, pseudo=True)
|
758 |
+
>>> result = projection_from_matrix(P0, pseudo=True)
|
759 |
+
>>> P1 = projection_matrix(*result)
|
760 |
+
>>> is_same_transform(P0, P1)
|
761 |
+
True
|
762 |
+
|
763 |
+
"""
|
764 |
+
M = numpy.array(matrix, dtype=numpy.float64, copy=False)
|
765 |
+
M33 = M[:3, :3]
|
766 |
+
l, V = numpy.linalg.eig(M)
|
767 |
+
i = numpy.where(abs(numpy.real(l) - 1.0) < 1e-8)[0]
|
768 |
+
if not pseudo and len(i):
|
769 |
+
# point: any eigenvector corresponding to eigenvalue 1
|
770 |
+
point = numpy.real(V[:, i[-1]]).squeeze()
|
771 |
+
point /= point[3]
|
772 |
+
# direction: unit eigenvector corresponding to eigenvalue 0
|
773 |
+
l, V = numpy.linalg.eig(M33)
|
774 |
+
i = numpy.where(abs(numpy.real(l)) < 1e-8)[0]
|
775 |
+
if not len(i):
|
776 |
+
raise ValueError("no eigenvector corresponding to eigenvalue 0")
|
777 |
+
direction = numpy.real(V[:, i[0]]).squeeze()
|
778 |
+
direction /= vector_norm(direction)
|
779 |
+
# normal: unit eigenvector of M33.T corresponding to eigenvalue 0
|
780 |
+
l, V = numpy.linalg.eig(M33.T)
|
781 |
+
i = numpy.where(abs(numpy.real(l)) < 1e-8)[0]
|
782 |
+
if len(i):
|
783 |
+
# parallel projection
|
784 |
+
normal = numpy.real(V[:, i[0]]).squeeze()
|
785 |
+
normal /= vector_norm(normal)
|
786 |
+
return point, normal, direction, None, False
|
787 |
+
else:
|
788 |
+
# orthogonal projection, where normal equals direction vector
|
789 |
+
return point, direction, None, None, False
|
790 |
+
else:
|
791 |
+
# perspective projection
|
792 |
+
i = numpy.where(abs(numpy.real(l)) > 1e-8)[0]
|
793 |
+
if not len(i):
|
794 |
+
raise ValueError(
|
795 |
+
"no eigenvector not corresponding to eigenvalue 0")
|
796 |
+
point = numpy.real(V[:, i[-1]]).squeeze()
|
797 |
+
point /= point[3]
|
798 |
+
normal = - M[3, :3]
|
799 |
+
perspective = M[:3, 3] / numpy.dot(point[:3], normal)
|
800 |
+
if pseudo:
|
801 |
+
perspective -= normal
|
802 |
+
return point, normal, None, perspective, pseudo
|
803 |
+
|
804 |
+
|
805 |
+
def clip_matrix(left, right, bottom, top, near, far, perspective=False):
|
806 |
+
"""Return matrix to obtain normalized device coordinates from frustrum.
|
807 |
+
|
808 |
+
The frustrum bounds are axis-aligned along x (left, right),
|
809 |
+
y (bottom, top) and z (near, far).
|
810 |
+
|
811 |
+
Normalized device coordinates are in range [-1, 1] if coordinates are
|
812 |
+
inside the frustrum.
|
813 |
+
|
814 |
+
If perspective is True the frustrum is a truncated pyramid with the
|
815 |
+
perspective point at origin and direction along z axis, otherwise an
|
816 |
+
orthographic canonical view volume (a box).
|
817 |
+
|
818 |
+
Homogeneous coordinates transformed by the perspective clip matrix
|
819 |
+
need to be dehomogenized (devided by w coordinate).
|
820 |
+
|
821 |
+
>>> frustrum = numpy.random.rand(6)
|
822 |
+
>>> frustrum[1] += frustrum[0]
|
823 |
+
>>> frustrum[3] += frustrum[2]
|
824 |
+
>>> frustrum[5] += frustrum[4]
|
825 |
+
>>> M = clip_matrix(*frustrum, perspective=False)
|
826 |
+
>>> numpy.dot(M, [frustrum[0], frustrum[2], frustrum[4], 1.0])
|
827 |
+
array([-1., -1., -1., 1.])
|
828 |
+
>>> numpy.dot(M, [frustrum[1], frustrum[3], frustrum[5], 1.0])
|
829 |
+
array([ 1., 1., 1., 1.])
|
830 |
+
>>> M = clip_matrix(*frustrum, perspective=True)
|
831 |
+
>>> v = numpy.dot(M, [frustrum[0], frustrum[2], frustrum[4], 1.0])
|
832 |
+
>>> v / v[3]
|
833 |
+
array([-1., -1., -1., 1.])
|
834 |
+
>>> v = numpy.dot(M, [frustrum[1], frustrum[3], frustrum[4], 1.0])
|
835 |
+
>>> v / v[3]
|
836 |
+
array([ 1., 1., -1., 1.])
|
837 |
+
|
838 |
+
"""
|
839 |
+
if left >= right or bottom >= top or near >= far:
|
840 |
+
raise ValueError("invalid frustrum")
|
841 |
+
if perspective:
|
842 |
+
if near <= _EPS:
|
843 |
+
raise ValueError("invalid frustrum: near <= 0")
|
844 |
+
t = 2.0 * near
|
845 |
+
M = ((-t/(right-left), 0.0, (right+left)/(right-left), 0.0),
|
846 |
+
(0.0, -t/(top-bottom), (top+bottom)/(top-bottom), 0.0),
|
847 |
+
(0.0, 0.0, -(far+near)/(far-near), t*far/(far-near)),
|
848 |
+
(0.0, 0.0, -1.0, 0.0))
|
849 |
+
else:
|
850 |
+
M = ((2.0/(right-left), 0.0, 0.0, (right+left)/(left-right)),
|
851 |
+
(0.0, 2.0/(top-bottom), 0.0, (top+bottom)/(bottom-top)),
|
852 |
+
(0.0, 0.0, 2.0/(far-near), (far+near)/(near-far)),
|
853 |
+
(0.0, 0.0, 0.0, 1.0))
|
854 |
+
return numpy.array(M, dtype=numpy.float64)
|
855 |
+
|
856 |
+
|
857 |
+
def shear_matrix(angle, direction, point, normal):
|
858 |
+
"""Return matrix to shear by angle along direction vector on shear plane.
|
859 |
+
|
860 |
+
The shear plane is defined by a point and normal vector. The direction
|
861 |
+
vector must be orthogonal to the plane's normal vector.
|
862 |
+
|
863 |
+
A point P is transformed by the shear matrix into P" such that
|
864 |
+
the vector P-P" is parallel to the direction vector and its extent is
|
865 |
+
given by the angle of P-P'-P", where P' is the orthogonal projection
|
866 |
+
of P onto the shear plane.
|
867 |
+
|
868 |
+
>>> angle = (random.random() - 0.5) * 4*math.pi
|
869 |
+
>>> direct = numpy.random.random(3) - 0.5
|
870 |
+
>>> point = numpy.random.random(3) - 0.5
|
871 |
+
>>> normal = numpy.cross(direct, numpy.random.random(3))
|
872 |
+
>>> S = shear_matrix(angle, direct, point, normal)
|
873 |
+
>>> numpy.allclose(1.0, numpy.linalg.det(S))
|
874 |
+
True
|
875 |
+
|
876 |
+
"""
|
877 |
+
normal = unit_vector(normal[:3])
|
878 |
+
direction = unit_vector(direction[:3])
|
879 |
+
if abs(numpy.dot(normal, direction)) > 1e-6:
|
880 |
+
raise ValueError("direction and normal vectors are not orthogonal")
|
881 |
+
angle = math.tan(angle)
|
882 |
+
M = numpy.identity(4)
|
883 |
+
M[:3, :3] += angle * numpy.outer(direction, normal)
|
884 |
+
M[:3, 3] = -angle * numpy.dot(point[:3], normal) * direction
|
885 |
+
return M
|
886 |
+
|
887 |
+
|
888 |
+
def shear_from_matrix(matrix):
|
889 |
+
"""Return shear angle, direction and plane from shear matrix.
|
890 |
+
|
891 |
+
>>> angle = (random.random() - 0.5) * 4*math.pi
|
892 |
+
>>> direct = numpy.random.random(3) - 0.5
|
893 |
+
>>> point = numpy.random.random(3) - 0.5
|
894 |
+
>>> normal = numpy.cross(direct, numpy.random.random(3))
|
895 |
+
>>> S0 = shear_matrix(angle, direct, point, normal)
|
896 |
+
>>> angle, direct, point, normal = shear_from_matrix(S0)
|
897 |
+
>>> S1 = shear_matrix(angle, direct, point, normal)
|
898 |
+
>>> is_same_transform(S0, S1)
|
899 |
+
True
|
900 |
+
|
901 |
+
"""
|
902 |
+
M = numpy.array(matrix, dtype=numpy.float64, copy=False)
|
903 |
+
M33 = M[:3, :3]
|
904 |
+
# normal: cross independent eigenvectors corresponding to the eigenvalue 1
|
905 |
+
l, V = numpy.linalg.eig(M33)
|
906 |
+
i = numpy.where(abs(numpy.real(l) - 1.0) < 1e-4)[0]
|
907 |
+
if len(i) < 2:
|
908 |
+
raise ValueError("No two linear independent eigenvectors found %s" % l)
|
909 |
+
V = numpy.real(V[:, i]).squeeze().T
|
910 |
+
lenorm = -1.0
|
911 |
+
for i0, i1 in ((0, 1), (0, 2), (1, 2)):
|
912 |
+
n = numpy.cross(V[i0], V[i1])
|
913 |
+
l = vector_norm(n)
|
914 |
+
if l > lenorm:
|
915 |
+
lenorm = l
|
916 |
+
normal = n
|
917 |
+
normal /= lenorm
|
918 |
+
# direction and angle
|
919 |
+
direction = numpy.dot(M33 - numpy.identity(3), normal)
|
920 |
+
angle = vector_norm(direction)
|
921 |
+
direction /= angle
|
922 |
+
angle = math.atan(angle)
|
923 |
+
# point: eigenvector corresponding to eigenvalue 1
|
924 |
+
l, V = numpy.linalg.eig(M)
|
925 |
+
i = numpy.where(abs(numpy.real(l) - 1.0) < 1e-8)[0]
|
926 |
+
if not len(i):
|
927 |
+
raise ValueError("no eigenvector corresponding to eigenvalue 1")
|
928 |
+
point = numpy.real(V[:, i[-1]]).squeeze()
|
929 |
+
point /= point[3]
|
930 |
+
return angle, direction, point, normal
|
931 |
+
|
932 |
+
|
933 |
+
def decompose_matrix(matrix):
|
934 |
+
"""Return sequence of transformations from transformation matrix.
|
935 |
+
|
936 |
+
matrix : array_like
|
937 |
+
Non-degenerate homogeneous transformation matrix
|
938 |
+
|
939 |
+
Return tuple of:
|
940 |
+
scale : vector of 3 scaling factors
|
941 |
+
shear : list of shear factors for x-y, x-z, y-z axes
|
942 |
+
angles : list of Euler angles about static x, y, z axes
|
943 |
+
translate : translation vector along x, y, z axes
|
944 |
+
perspective : perspective partition of matrix
|
945 |
+
|
946 |
+
Raise ValueError if matrix is of wrong type or degenerate.
|
947 |
+
|
948 |
+
>>> T0 = translation_matrix((1, 2, 3))
|
949 |
+
>>> scale, shear, angles, trans, persp = decompose_matrix(T0)
|
950 |
+
>>> T1 = translation_matrix(trans)
|
951 |
+
>>> numpy.allclose(T0, T1)
|
952 |
+
True
|
953 |
+
>>> S = scale_matrix(0.123)
|
954 |
+
>>> scale, shear, angles, trans, persp = decompose_matrix(S)
|
955 |
+
>>> scale[0]
|
956 |
+
0.123
|
957 |
+
>>> R0 = euler_matrix(1, 2, 3)
|
958 |
+
>>> scale, shear, angles, trans, persp = decompose_matrix(R0)
|
959 |
+
>>> R1 = euler_matrix(*angles)
|
960 |
+
>>> numpy.allclose(R0, R1)
|
961 |
+
True
|
962 |
+
|
963 |
+
"""
|
964 |
+
M = numpy.array(matrix, dtype=numpy.float64, copy=True).T
|
965 |
+
if abs(M[3, 3]) < _EPS:
|
966 |
+
raise ValueError("M[3, 3] is zero")
|
967 |
+
M /= M[3, 3]
|
968 |
+
P = M.copy()
|
969 |
+
P[:, 3] = 0, 0, 0, 1
|
970 |
+
if not numpy.linalg.det(P):
|
971 |
+
raise ValueError("Matrix is singular")
|
972 |
+
|
973 |
+
scale = numpy.zeros((3, ), dtype=numpy.float64)
|
974 |
+
shear = [0, 0, 0]
|
975 |
+
angles = [0, 0, 0]
|
976 |
+
|
977 |
+
if any(abs(M[:3, 3]) > _EPS):
|
978 |
+
perspective = numpy.dot(M[:, 3], numpy.linalg.inv(P.T))
|
979 |
+
M[:, 3] = 0, 0, 0, 1
|
980 |
+
else:
|
981 |
+
perspective = numpy.array((0, 0, 0, 1), dtype=numpy.float64)
|
982 |
+
|
983 |
+
translate = M[3, :3].copy()
|
984 |
+
M[3, :3] = 0
|
985 |
+
|
986 |
+
row = M[:3, :3].copy()
|
987 |
+
scale[0] = vector_norm(row[0])
|
988 |
+
row[0] /= scale[0]
|
989 |
+
shear[0] = numpy.dot(row[0], row[1])
|
990 |
+
row[1] -= row[0] * shear[0]
|
991 |
+
scale[1] = vector_norm(row[1])
|
992 |
+
row[1] /= scale[1]
|
993 |
+
shear[0] /= scale[1]
|
994 |
+
shear[1] = numpy.dot(row[0], row[2])
|
995 |
+
row[2] -= row[0] * shear[1]
|
996 |
+
shear[2] = numpy.dot(row[1], row[2])
|
997 |
+
row[2] -= row[1] * shear[2]
|
998 |
+
scale[2] = vector_norm(row[2])
|
999 |
+
row[2] /= scale[2]
|
1000 |
+
shear[1:] /= scale[2]
|
1001 |
+
|
1002 |
+
if numpy.dot(row[0], numpy.cross(row[1], row[2])) < 0:
|
1003 |
+
scale *= -1
|
1004 |
+
row *= -1
|
1005 |
+
|
1006 |
+
angles[1] = math.asin(-row[0, 2])
|
1007 |
+
if math.cos(angles[1]):
|
1008 |
+
angles[0] = math.atan2(row[1, 2], row[2, 2])
|
1009 |
+
angles[2] = math.atan2(row[0, 1], row[0, 0])
|
1010 |
+
else:
|
1011 |
+
#angles[0] = math.atan2(row[1, 0], row[1, 1])
|
1012 |
+
angles[0] = math.atan2(-row[2, 1], row[1, 1])
|
1013 |
+
angles[2] = 0.0
|
1014 |
+
|
1015 |
+
return scale, shear, angles, translate, perspective
|
1016 |
+
|
1017 |
+
|
1018 |
+
def compose_matrix(scale=None, shear=None, angles=None, translate=None,
|
1019 |
+
perspective=None):
|
1020 |
+
"""Return transformation matrix from sequence of transformations.
|
1021 |
+
|
1022 |
+
This is the inverse of the decompose_matrix function.
|
1023 |
+
|
1024 |
+
Sequence of transformations:
|
1025 |
+
scale : vector of 3 scaling factors
|
1026 |
+
shear : list of shear factors for x-y, x-z, y-z axes
|
1027 |
+
angles : list of Euler angles about static x, y, z axes
|
1028 |
+
translate : translation vector along x, y, z axes
|
1029 |
+
perspective : perspective partition of matrix
|
1030 |
+
|
1031 |
+
>>> scale = numpy.random.random(3) - 0.5
|
1032 |
+
>>> shear = numpy.random.random(3) - 0.5
|
1033 |
+
>>> angles = (numpy.random.random(3) - 0.5) * (2*math.pi)
|
1034 |
+
>>> trans = numpy.random.random(3) - 0.5
|
1035 |
+
>>> persp = numpy.random.random(4) - 0.5
|
1036 |
+
>>> M0 = compose_matrix(scale, shear, angles, trans, persp)
|
1037 |
+
>>> result = decompose_matrix(M0)
|
1038 |
+
>>> M1 = compose_matrix(*result)
|
1039 |
+
>>> is_same_transform(M0, M1)
|
1040 |
+
True
|
1041 |
+
|
1042 |
+
"""
|
1043 |
+
M = numpy.identity(4)
|
1044 |
+
if perspective is not None:
|
1045 |
+
P = numpy.identity(4)
|
1046 |
+
P[3, :] = perspective[:4]
|
1047 |
+
M = numpy.dot(M, P)
|
1048 |
+
if translate is not None:
|
1049 |
+
T = numpy.identity(4)
|
1050 |
+
T[:3, 3] = translate[:3]
|
1051 |
+
M = numpy.dot(M, T)
|
1052 |
+
if angles is not None:
|
1053 |
+
R = euler_matrix(angles[0], angles[1], angles[2], 'sxyz')
|
1054 |
+
M = numpy.dot(M, R)
|
1055 |
+
if shear is not None:
|
1056 |
+
Z = numpy.identity(4)
|
1057 |
+
Z[1, 2] = shear[2]
|
1058 |
+
Z[0, 2] = shear[1]
|
1059 |
+
Z[0, 1] = shear[0]
|
1060 |
+
M = numpy.dot(M, Z)
|
1061 |
+
if scale is not None:
|
1062 |
+
S = numpy.identity(4)
|
1063 |
+
S[0, 0] = scale[0]
|
1064 |
+
S[1, 1] = scale[1]
|
1065 |
+
S[2, 2] = scale[2]
|
1066 |
+
M = numpy.dot(M, S)
|
1067 |
+
M /= M[3, 3]
|
1068 |
+
return M
|
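Note: a minimal round-trip sketch for the two routines above (the import path is assumed from this repo's layout; the numeric values are arbitrary):

import numpy as np
import utils.utils_poses.ATE.transformations as tf

M0 = tf.compose_matrix(scale=[1.0, 2.0, 0.5],
                       angles=[0.1, -0.2, 0.3],
                       translate=[4.0, 5.0, 6.0])
scale, shear, angles, translate, perspective = tf.decompose_matrix(M0)
M1 = tf.compose_matrix(scale, shear, angles, translate, perspective)
assert tf.is_same_transform(M0, M1)   # decompose followed by compose reproduces M0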
1069 |
+
|
1070 |
+
|
1071 |
+
def orthogonalization_matrix(lengths, angles):
|
1072 |
+
"""Return orthogonalization matrix for crystallographic cell coordinates.
|
1073 |
+
|
1074 |
+
Angles are expected in degrees.
|
1075 |
+
|
1076 |
+
The de-orthogonalization matrix is the inverse.
|
1077 |
+
|
1078 |
+
>>> O = orthogonalization_matrix((10., 10., 10.), (90., 90., 90.))
|
1079 |
+
>>> numpy.allclose(O[:3, :3], numpy.identity(3, float) * 10)
|
1080 |
+
True
|
1081 |
+
>>> O = orthogonalization_matrix([9.8, 12.0, 15.5], [87.2, 80.7, 69.7])
|
1082 |
+
>>> numpy.allclose(numpy.sum(O), 43.063229)
|
1083 |
+
True
|
1084 |
+
|
1085 |
+
"""
|
1086 |
+
a, b, c = lengths
|
1087 |
+
angles = numpy.radians(angles)
|
1088 |
+
sina, sinb, _ = numpy.sin(angles)
|
1089 |
+
cosa, cosb, cosg = numpy.cos(angles)
|
1090 |
+
co = (cosa * cosb - cosg) / (sina * sinb)
|
1091 |
+
return numpy.array((
|
1092 |
+
(a*sinb*math.sqrt(1.0-co*co), 0.0, 0.0, 0.0),
|
1093 |
+
(-a*sinb*co, b*sina, 0.0, 0.0),
|
1094 |
+
(a*cosb, b*cosa, c, 0.0),
|
1095 |
+
(0.0, 0.0, 0.0, 1.0)),
|
1096 |
+
dtype=numpy.float64)
|
1097 |
+
|
1098 |
+
|
1099 |
+
def superimposition_matrix(v0, v1, scaling=False, usesvd=True):
|
1100 |
+
"""Return matrix to transform given vector set into second vector set.
|
1101 |
+
|
1102 |
+
v0 and v1 are shape (3, \*) or (4, \*) arrays of at least 3 vectors.
|
1103 |
+
|
1104 |
+
If usesvd is True, the weighted sum of squared deviations (RMSD) is
|
1105 |
+
minimized according to the algorithm by W. Kabsch [8]. Otherwise the
|
1106 |
+
quaternion based algorithm by B. Horn [9] is used (slower when using
|
1107 |
+
this Python implementation).
|
1108 |
+
|
1109 |
+
The returned matrix performs rotation, translation and uniform scaling
|
1110 |
+
(if specified).
|
1111 |
+
|
1112 |
+
>>> v0 = numpy.random.rand(3, 10)
|
1113 |
+
>>> M = superimposition_matrix(v0, v0)
|
1114 |
+
>>> numpy.allclose(M, numpy.identity(4))
|
1115 |
+
True
|
1116 |
+
>>> R = random_rotation_matrix(numpy.random.random(3))
|
1117 |
+
>>> v0 = ((1,0,0), (0,1,0), (0,0,1), (1,1,1))
|
1118 |
+
>>> v1 = numpy.dot(R, v0)
|
1119 |
+
>>> M = superimposition_matrix(v0, v1)
|
1120 |
+
>>> numpy.allclose(v1, numpy.dot(M, v0))
|
1121 |
+
True
|
1122 |
+
>>> v0 = (numpy.random.rand(4, 100) - 0.5) * 20.0
|
1123 |
+
>>> v0[3] = 1.0
|
1124 |
+
>>> v1 = numpy.dot(R, v0)
|
1125 |
+
>>> M = superimposition_matrix(v0, v1)
|
1126 |
+
>>> numpy.allclose(v1, numpy.dot(M, v0))
|
1127 |
+
True
|
1128 |
+
>>> S = scale_matrix(random.random())
|
1129 |
+
>>> T = translation_matrix(numpy.random.random(3)-0.5)
|
1130 |
+
>>> M = concatenate_matrices(T, R, S)
|
1131 |
+
>>> v1 = numpy.dot(M, v0)
|
1132 |
+
>>> v0[:3] += numpy.random.normal(0.0, 1e-9, 300).reshape(3, -1)
|
1133 |
+
>>> M = superimposition_matrix(v0, v1, scaling=True)
|
1134 |
+
>>> numpy.allclose(v1, numpy.dot(M, v0))
|
1135 |
+
True
|
1136 |
+
>>> M = superimposition_matrix(v0, v1, scaling=True, usesvd=False)
|
1137 |
+
>>> numpy.allclose(v1, numpy.dot(M, v0))
|
1138 |
+
True
|
1139 |
+
>>> v = numpy.empty((4, 100, 3), dtype=numpy.float64)
|
1140 |
+
>>> v[:, :, 0] = v0
|
1141 |
+
>>> M = superimposition_matrix(v0, v1, scaling=True, usesvd=False)
|
1142 |
+
>>> numpy.allclose(v1, numpy.dot(M, v[:, :, 0]))
|
1143 |
+
True
|
1144 |
+
|
1145 |
+
"""
|
1146 |
+
v0 = numpy.array(v0, dtype=numpy.float64, copy=False)[:3]
|
1147 |
+
v1 = numpy.array(v1, dtype=numpy.float64, copy=False)[:3]
|
1148 |
+
|
1149 |
+
if v0.shape != v1.shape or v0.shape[1] < 3:
|
1150 |
+
raise ValueError("Vector sets are of wrong shape or type.")
|
1151 |
+
|
1152 |
+
# move centroids to origin
|
1153 |
+
t0 = numpy.mean(v0, axis=1)
|
1154 |
+
t1 = numpy.mean(v1, axis=1)
|
1155 |
+
v0 = v0 - t0.reshape(3, 1)
|
1156 |
+
v1 = v1 - t1.reshape(3, 1)
|
1157 |
+
|
1158 |
+
if usesvd:
|
1159 |
+
# Singular Value Decomposition of covariance matrix
|
1160 |
+
u, s, vh = numpy.linalg.svd(numpy.dot(v1, v0.T))
|
1161 |
+
# rotation matrix from SVD orthonormal bases
|
1162 |
+
R = numpy.dot(u, vh)
|
1163 |
+
if numpy.linalg.det(R) < 0.0:
|
1164 |
+
# R does not constitute right handed system
|
1165 |
+
R -= numpy.outer(u[:, 2], vh[2, :]*2.0)
|
1166 |
+
s[-1] *= -1.0
|
1167 |
+
# homogeneous transformation matrix
|
1168 |
+
M = numpy.identity(4)
|
1169 |
+
M[:3, :3] = R
|
1170 |
+
else:
|
1171 |
+
# compute symmetric matrix N
|
1172 |
+
xx, yy, zz = numpy.sum(v0 * v1, axis=1)
|
1173 |
+
xy, yz, zx = numpy.sum(v0 * numpy.roll(v1, -1, axis=0), axis=1)
|
1174 |
+
xz, yx, zy = numpy.sum(v0 * numpy.roll(v1, -2, axis=0), axis=1)
|
1175 |
+
N = ((xx+yy+zz, yz-zy, zx-xz, xy-yx),
|
1176 |
+
(yz-zy, xx-yy-zz, xy+yx, zx+xz),
|
1177 |
+
(zx-xz, xy+yx, -xx+yy-zz, yz+zy),
|
1178 |
+
(xy-yx, zx+xz, yz+zy, -xx-yy+zz))
|
1179 |
+
# quaternion: eigenvector corresponding to most positive eigenvalue
|
1180 |
+
l, V = numpy.linalg.eig(N)
|
1181 |
+
q = V[:, numpy.argmax(l)]
|
1182 |
+
q /= vector_norm(q) # unit quaternion
|
1183 |
+
q = numpy.roll(q, -1) # move w component to end
|
1184 |
+
# homogeneous transformation matrix
|
1185 |
+
M = quaternion_matrix(q)
|
1186 |
+
|
1187 |
+
# scale: ratio of rms deviations from centroid
|
1188 |
+
if scaling:
|
1189 |
+
v0 *= v0
|
1190 |
+
v1 *= v1
|
1191 |
+
M[:3, :3] *= math.sqrt(numpy.sum(v1) / numpy.sum(v0))
|
1192 |
+
|
1193 |
+
# translation
|
1194 |
+
M[:3, 3] = t1
|
1195 |
+
T = numpy.identity(4)
|
1196 |
+
T[:3, 3] = -t0
|
1197 |
+
M = numpy.dot(M, T)
|
1198 |
+
return M
|
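Note: a small usage sketch for superimposition_matrix, recovering a rigid transform between two point sets (assumes rotation_matrix/translation_matrix defined earlier in this file):

import numpy as np
import utils.utils_poses.ATE.transformations as tf

src = np.random.rand(3, 50)                                   # source points, (3, N)
M_true = tf.concatenate_matrices(tf.translation_matrix([0.5, -1.0, 2.0]),
                                 tf.rotation_matrix(0.7, [0, 0, 1]))
src_h = np.vstack([src, np.ones((1, src.shape[1]))])          # homogeneous (4, N)
dst_h = M_true @ src_h
M_est = tf.superimposition_matrix(src_h, dst_h)
assert np.allclose(M_est @ src_h, dst_h)                      # recovered transform maps src onto dst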
1199 |
+
|
1200 |
+
|
1201 |
+
def euler_matrix(ai, aj, ak, axes='sxyz'):
|
1202 |
+
"""Return homogeneous rotation matrix from Euler angles and axis sequence.
|
1203 |
+
|
1204 |
+
ai, aj, ak : Euler's roll, pitch and yaw angles
|
1205 |
+
axes : One of 24 axis sequences as string or encoded tuple
|
1206 |
+
|
1207 |
+
>>> R = euler_matrix(1, 2, 3, 'syxz')
|
1208 |
+
>>> numpy.allclose(numpy.sum(R[0]), -1.34786452)
|
1209 |
+
True
|
1210 |
+
>>> R = euler_matrix(1, 2, 3, (0, 1, 0, 1))
|
1211 |
+
>>> numpy.allclose(numpy.sum(R[0]), -0.383436184)
|
1212 |
+
True
|
1213 |
+
>>> ai, aj, ak = (4.0*math.pi) * (numpy.random.random(3) - 0.5)
|
1214 |
+
>>> for axes in _AXES2TUPLE.keys():
|
1215 |
+
... R = euler_matrix(ai, aj, ak, axes)
|
1216 |
+
>>> for axes in _TUPLE2AXES.keys():
|
1217 |
+
... R = euler_matrix(ai, aj, ak, axes)
|
1218 |
+
|
1219 |
+
"""
|
1220 |
+
try:
|
1221 |
+
firstaxis, parity, repetition, frame = _AXES2TUPLE[axes]
|
1222 |
+
except (AttributeError, KeyError):
|
1223 |
+
_ = _TUPLE2AXES[axes]
|
1224 |
+
firstaxis, parity, repetition, frame = axes
|
1225 |
+
|
1226 |
+
i = firstaxis
|
1227 |
+
j = _NEXT_AXIS[i+parity]
|
1228 |
+
k = _NEXT_AXIS[i-parity+1]
|
1229 |
+
|
1230 |
+
if frame:
|
1231 |
+
ai, ak = ak, ai
|
1232 |
+
if parity:
|
1233 |
+
ai, aj, ak = -ai, -aj, -ak
|
1234 |
+
|
1235 |
+
si, sj, sk = math.sin(ai), math.sin(aj), math.sin(ak)
|
1236 |
+
ci, cj, ck = math.cos(ai), math.cos(aj), math.cos(ak)
|
1237 |
+
cc, cs = ci*ck, ci*sk
|
1238 |
+
sc, ss = si*ck, si*sk
|
1239 |
+
|
1240 |
+
M = numpy.identity(4)
|
1241 |
+
if repetition:
|
1242 |
+
M[i, i] = cj
|
1243 |
+
M[i, j] = sj*si
|
1244 |
+
M[i, k] = sj*ci
|
1245 |
+
M[j, i] = sj*sk
|
1246 |
+
M[j, j] = -cj*ss+cc
|
1247 |
+
M[j, k] = -cj*cs-sc
|
1248 |
+
M[k, i] = -sj*ck
|
1249 |
+
M[k, j] = cj*sc+cs
|
1250 |
+
M[k, k] = cj*cc-ss
|
1251 |
+
else:
|
1252 |
+
M[i, i] = cj*ck
|
1253 |
+
M[i, j] = sj*sc-cs
|
1254 |
+
M[i, k] = sj*cc+ss
|
1255 |
+
M[j, i] = cj*sk
|
1256 |
+
M[j, j] = sj*ss+cc
|
1257 |
+
M[j, k] = sj*cs-sc
|
1258 |
+
M[k, i] = -sj
|
1259 |
+
M[k, j] = cj*si
|
1260 |
+
M[k, k] = cj*ci
|
1261 |
+
return M
|
1262 |
+
|
1263 |
+
|
1264 |
+
def euler_from_matrix(matrix, axes='sxyz'):
|
1265 |
+
"""Return Euler angles from rotation matrix for specified axis sequence.
|
1266 |
+
|
1267 |
+
axes : One of 24 axis sequences as string or encoded tuple
|
1268 |
+
|
1269 |
+
Note that many Euler angle triplets can describe one matrix.
|
1270 |
+
|
1271 |
+
>>> R0 = euler_matrix(1, 2, 3, 'syxz')
|
1272 |
+
>>> al, be, ga = euler_from_matrix(R0, 'syxz')
|
1273 |
+
>>> R1 = euler_matrix(al, be, ga, 'syxz')
|
1274 |
+
>>> numpy.allclose(R0, R1)
|
1275 |
+
True
|
1276 |
+
>>> angles = (4.0*math.pi) * (numpy.random.random(3) - 0.5)
|
1277 |
+
>>> for axes in _AXES2TUPLE.keys():
|
1278 |
+
... R0 = euler_matrix(axes=axes, *angles)
|
1279 |
+
... R1 = euler_matrix(axes=axes, *euler_from_matrix(R0, axes))
|
1280 |
+
... if not numpy.allclose(R0, R1): print(axes, "failed")
|
1281 |
+
|
1282 |
+
"""
|
1283 |
+
try:
|
1284 |
+
firstaxis, parity, repetition, frame = _AXES2TUPLE[axes.lower()]
|
1285 |
+
except (AttributeError, KeyError):
|
1286 |
+
_ = _TUPLE2AXES[axes]
|
1287 |
+
firstaxis, parity, repetition, frame = axes
|
1288 |
+
|
1289 |
+
i = firstaxis
|
1290 |
+
j = _NEXT_AXIS[i+parity]
|
1291 |
+
k = _NEXT_AXIS[i-parity+1]
|
1292 |
+
|
1293 |
+
M = numpy.array(matrix, dtype=numpy.float64, copy=False)[:3, :3]
|
1294 |
+
if repetition:
|
1295 |
+
sy = math.sqrt(M[i, j]*M[i, j] + M[i, k]*M[i, k])
|
1296 |
+
if sy > _EPS:
|
1297 |
+
ax = math.atan2(M[i, j], M[i, k])
|
1298 |
+
ay = math.atan2(sy, M[i, i])
|
1299 |
+
az = math.atan2(M[j, i], -M[k, i])
|
1300 |
+
else:
|
1301 |
+
ax = math.atan2(-M[j, k], M[j, j])
|
1302 |
+
ay = math.atan2(sy, M[i, i])
|
1303 |
+
az = 0.0
|
1304 |
+
else:
|
1305 |
+
cy = math.sqrt(M[i, i]*M[i, i] + M[j, i]*M[j, i])
|
1306 |
+
if cy > _EPS:
|
1307 |
+
ax = math.atan2(M[k, j], M[k, k])
|
1308 |
+
ay = math.atan2(-M[k, i], cy)
|
1309 |
+
az = math.atan2(M[j, i], M[i, i])
|
1310 |
+
else:
|
1311 |
+
ax = math.atan2(-M[j, k], M[j, j])
|
1312 |
+
ay = math.atan2(-M[k, i], cy)
|
1313 |
+
az = 0.0
|
1314 |
+
|
1315 |
+
if parity:
|
1316 |
+
ax, ay, az = -ax, -ay, -az
|
1317 |
+
if frame:
|
1318 |
+
ax, az = az, ax
|
1319 |
+
return ax, ay, az
|
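Note: euler_matrix/euler_from_matrix round trip for one axis convention (a minimal sketch; the recovered angles may differ from the inputs while describing the same rotation):

import numpy as np
import utils.utils_poses.ATE.transformations as tf

R0 = tf.euler_matrix(0.1, -0.4, 1.2, axes='szyx')
ai, aj, ak = tf.euler_from_matrix(R0, axes='szyx')
assert np.allclose(R0, tf.euler_matrix(ai, aj, ak, axes='szyx'))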
1320 |
+
|
1321 |
+
|
1322 |
+
def euler_from_quaternion(quaternion, axes='sxyz'):
|
1323 |
+
"""Return Euler angles from quaternion for specified axis sequence.
|
1324 |
+
|
1325 |
+
>>> angles = euler_from_quaternion([0.06146124, 0, 0, 0.99810947])
|
1326 |
+
>>> numpy.allclose(angles, [0.123, 0, 0])
|
1327 |
+
True
|
1328 |
+
|
1329 |
+
"""
|
1330 |
+
return euler_from_matrix(quaternion_matrix(quaternion), axes)
|
1331 |
+
|
1332 |
+
|
1333 |
+
def quaternion_from_euler(ai, aj, ak, axes='sxyz'):
|
1334 |
+
"""Return quaternion from Euler angles and axis sequence.
|
1335 |
+
|
1336 |
+
ai, aj, ak : Euler's roll, pitch and yaw angles
|
1337 |
+
axes : One of 24 axis sequences as string or encoded tuple
|
1338 |
+
|
1339 |
+
>>> q = quaternion_from_euler(1, 2, 3, 'ryxz')
|
1340 |
+
>>> numpy.allclose(q, [0.310622, -0.718287, 0.444435, 0.435953])
|
1341 |
+
True
|
1342 |
+
|
1343 |
+
"""
|
1344 |
+
try:
|
1345 |
+
firstaxis, parity, repetition, frame = _AXES2TUPLE[axes.lower()]
|
1346 |
+
except (AttributeError, KeyError):
|
1347 |
+
_ = _TUPLE2AXES[axes]
|
1348 |
+
firstaxis, parity, repetition, frame = axes
|
1349 |
+
|
1350 |
+
i = firstaxis
|
1351 |
+
j = _NEXT_AXIS[i+parity]
|
1352 |
+
k = _NEXT_AXIS[i-parity+1]
|
1353 |
+
|
1354 |
+
if frame:
|
1355 |
+
ai, ak = ak, ai
|
1356 |
+
if parity:
|
1357 |
+
aj = -aj
|
1358 |
+
|
1359 |
+
ai /= 2.0
|
1360 |
+
aj /= 2.0
|
1361 |
+
ak /= 2.0
|
1362 |
+
ci = math.cos(ai)
|
1363 |
+
si = math.sin(ai)
|
1364 |
+
cj = math.cos(aj)
|
1365 |
+
sj = math.sin(aj)
|
1366 |
+
ck = math.cos(ak)
|
1367 |
+
sk = math.sin(ak)
|
1368 |
+
cc = ci*ck
|
1369 |
+
cs = ci*sk
|
1370 |
+
sc = si*ck
|
1371 |
+
ss = si*sk
|
1372 |
+
|
1373 |
+
quaternion = numpy.empty((4, ), dtype=numpy.float64)
|
1374 |
+
if repetition:
|
1375 |
+
quaternion[i] = cj*(cs + sc)
|
1376 |
+
quaternion[j] = sj*(cc + ss)
|
1377 |
+
quaternion[k] = sj*(cs - sc)
|
1378 |
+
quaternion[3] = cj*(cc - ss)
|
1379 |
+
else:
|
1380 |
+
quaternion[i] = cj*sc - sj*cs
|
1381 |
+
quaternion[j] = cj*ss + sj*cc
|
1382 |
+
quaternion[k] = cj*cs - sj*sc
|
1383 |
+
quaternion[3] = cj*cc + sj*ss
|
1384 |
+
if parity:
|
1385 |
+
quaternion[j] *= -1
|
1386 |
+
|
1387 |
+
return quaternion
|
1388 |
+
|
1389 |
+
|
1390 |
+
def quaternion_about_axis(angle, axis):
|
1391 |
+
"""Return quaternion for rotation about axis.
|
1392 |
+
|
1393 |
+
>>> q = quaternion_about_axis(0.123, (1, 0, 0))
|
1394 |
+
>>> numpy.allclose(q, [0.06146124, 0, 0, 0.99810947])
|
1395 |
+
True
|
1396 |
+
|
1397 |
+
"""
|
1398 |
+
quaternion = numpy.zeros((4, ), dtype=numpy.float64)
|
1399 |
+
quaternion[:3] = axis[:3]
|
1400 |
+
qlen = vector_norm(quaternion)
|
1401 |
+
if qlen > _EPS:
|
1402 |
+
quaternion *= math.sin(angle/2.0) / qlen
|
1403 |
+
quaternion[3] = math.cos(angle/2.0)
|
1404 |
+
return quaternion
|
1405 |
+
|
1406 |
+
|
1407 |
+
def matrix_from_quaternion(quaternion):
|
1408 |
+
return quaternion_matrix(quaternion)
|
1409 |
+
|
1410 |
+
|
1411 |
+
def quaternion_matrix(quaternion):
|
1412 |
+
"""Return homogeneous rotation matrix from quaternion.
|
1413 |
+
|
1414 |
+
>>> R = quaternion_matrix([0.06146124, 0, 0, 0.99810947])
|
1415 |
+
>>> numpy.allclose(R, rotation_matrix(0.123, (1, 0, 0)))
|
1416 |
+
True
|
1417 |
+
|
1418 |
+
"""
|
1419 |
+
q = numpy.array(quaternion[:4], dtype=numpy.float64, copy=True)
|
1420 |
+
nq = numpy.dot(q, q)
|
1421 |
+
if nq < _EPS:
|
1422 |
+
return numpy.identity(4)
|
1423 |
+
q *= math.sqrt(2.0 / nq)
|
1424 |
+
q = numpy.outer(q, q)
|
1425 |
+
return numpy.array((
|
1426 |
+
(1.0-q[1, 1]-q[2, 2], q[0, 1]-q[2, 3], q[0, 2]+q[1, 3], 0.0),
|
1427 |
+
(q[0, 1]+q[2, 3], 1.0-q[0, 0]-q[2, 2], q[1, 2]-q[0, 3], 0.0),
|
1428 |
+
(q[0, 2]-q[1, 3], q[1, 2]+q[0, 3], 1.0-q[0, 0]-q[1, 1], 0.0),
|
1429 |
+
(0.0, 0.0, 0.0, 1.0)
|
1430 |
+
), dtype=numpy.float64)
|
1431 |
+
|
1432 |
+
|
1433 |
+
def quaternionJPL_matrix(quaternion):
|
1434 |
+
"""Return homogeneous rotation matrix from quaternion in JPL notation.
|
1435 |
+
quaternion = [x y z w]
|
1436 |
+
"""
|
1437 |
+
q0 = quaternion[0]
|
1438 |
+
q1 = quaternion[1]
|
1439 |
+
q2 = quaternion[2]
|
1440 |
+
q3 = quaternion[3]
|
1441 |
+
return numpy.array([
|
1442 |
+
[q0**2 - q1**2 - q2**2 + q3**2, 2.0*q0*q1 +
|
1443 |
+
2.0*q2*q3, 2.0*q0*q2 - 2.0*q1*q3, 0],
|
1444 |
+
[2.0*q0*q1 - 2.0*q2*q3, - q0**2 + q1**2 -
|
1445 |
+
q2**2 + q3**2, 2.0*q0*q3 + 2.0*q1*q2, 0],
|
1446 |
+
[2.0*q0*q2 + 2.0*q1*q3, 2.0*q1*q2 - 2.0*q0 *
|
1447 |
+
q3, - q0**2 - q1**2 + q2**2 + q3**2, 0],
|
1448 |
+
[0, 0, 0, 1.0]], dtype=numpy.float64)
|
1449 |
+
|
1450 |
+
|
1451 |
+
def quaternion_from_matrix(matrix):
|
1452 |
+
"""Return quaternion from rotation matrix.
|
1453 |
+
|
1454 |
+
>>> R = rotation_matrix(0.123, (1, 2, 3))
|
1455 |
+
>>> q = quaternion_from_matrix(R)
|
1456 |
+
>>> numpy.allclose(q, [0.0164262, 0.0328524, 0.0492786, 0.9981095])
|
1457 |
+
True
|
1458 |
+
|
1459 |
+
"""
|
1460 |
+
q = numpy.empty((4, ), dtype=numpy.float64)
|
1461 |
+
M = numpy.array(matrix, dtype=numpy.float64, copy=False)[:4, :4]
|
1462 |
+
t = numpy.trace(M)
|
1463 |
+
if t > M[3, 3]:
|
1464 |
+
q[3] = t
|
1465 |
+
q[2] = M[1, 0] - M[0, 1]
|
1466 |
+
q[1] = M[0, 2] - M[2, 0]
|
1467 |
+
q[0] = M[2, 1] - M[1, 2]
|
1468 |
+
else:
|
1469 |
+
i, j, k = 0, 1, 2
|
1470 |
+
if M[1, 1] > M[0, 0]:
|
1471 |
+
i, j, k = 1, 2, 0
|
1472 |
+
if M[2, 2] > M[i, i]:
|
1473 |
+
i, j, k = 2, 0, 1
|
1474 |
+
t = M[i, i] - (M[j, j] + M[k, k]) + M[3, 3]
|
1475 |
+
q[i] = t
|
1476 |
+
q[j] = M[i, j] + M[j, i]
|
1477 |
+
q[k] = M[k, i] + M[i, k]
|
1478 |
+
q[3] = M[k, j] - M[j, k]
|
1479 |
+
q *= 0.5 / math.sqrt(t * M[3, 3])
|
1480 |
+
return q
|
1481 |
+
|
1482 |
+
|
1483 |
+
def quaternion_multiply(quaternion1, quaternion0):
|
1484 |
+
"""Return multiplication of two quaternions.
|
1485 |
+
|
1486 |
+
>>> q = quaternion_multiply([1, -2, 3, 4], [-5, 6, 7, 8])
|
1487 |
+
>>> numpy.allclose(q, [-44, -14, 48, 28])
|
1488 |
+
True
|
1489 |
+
|
1490 |
+
"""
|
1491 |
+
x0, y0, z0, w0 = quaternion0
|
1492 |
+
x1, y1, z1, w1 = quaternion1
|
1493 |
+
return numpy.array((
|
1494 |
+
x1*w0 + y1*z0 - z1*y0 + w1*x0,
|
1495 |
+
-x1*z0 + y1*w0 + z1*x0 + w1*y0,
|
1496 |
+
x1*y0 - y1*x0 + z1*w0 + w1*z0,
|
1497 |
+
-x1*x0 - y1*y0 - z1*z0 + w1*w0), dtype=numpy.float64)
|
1498 |
+
|
1499 |
+
|
1500 |
+
def quaternion_conjugate(quaternion):
|
1501 |
+
"""Return conjugate of quaternion.
|
1502 |
+
|
1503 |
+
>>> q0 = random_quaternion()
|
1504 |
+
>>> q1 = quaternion_conjugate(q0)
|
1505 |
+
>>> q1[3] == q0[3] and all(q1[:3] == -q0[:3])
|
1506 |
+
True
|
1507 |
+
|
1508 |
+
"""
|
1509 |
+
return numpy.array((-quaternion[0], -quaternion[1],
|
1510 |
+
-quaternion[2], quaternion[3]), dtype=numpy.float64)
|
1511 |
+
|
1512 |
+
|
1513 |
+
def quaternion_inverse(quaternion):
|
1514 |
+
"""Return inverse of quaternion.
|
1515 |
+
|
1516 |
+
>>> q0 = random_quaternion()
|
1517 |
+
>>> q1 = quaternion_inverse(q0)
|
1518 |
+
>>> numpy.allclose(quaternion_multiply(q0, q1), [0, 0, 0, 1])
|
1519 |
+
True
|
1520 |
+
|
1521 |
+
"""
|
1522 |
+
return quaternion_conjugate(quaternion) / numpy.dot(quaternion, quaternion)
|
1523 |
+
|
1524 |
+
|
1525 |
+
def quaternion_slerp(quat0, quat1, fraction, spin=0, shortestpath=True):
|
1526 |
+
"""Return spherical linear interpolation between two quaternions.
|
1527 |
+
|
1528 |
+
>>> q0 = random_quaternion()
|
1529 |
+
>>> q1 = random_quaternion()
|
1530 |
+
>>> q = quaternion_slerp(q0, q1, 0.0)
|
1531 |
+
>>> numpy.allclose(q, q0)
|
1532 |
+
True
|
1533 |
+
>>> q = quaternion_slerp(q0, q1, 1.0, 1)
|
1534 |
+
>>> numpy.allclose(q, q1)
|
1535 |
+
True
|
1536 |
+
>>> q = quaternion_slerp(q0, q1, 0.5)
|
1537 |
+
>>> angle = math.acos(numpy.dot(q0, q))
|
1538 |
+
>>> numpy.allclose(2.0, math.acos(numpy.dot(q0, q1)) / angle) or \
|
1539 |
+
numpy.allclose(2.0, math.acos(-numpy.dot(q0, q1)) / angle)
|
1540 |
+
True
|
1541 |
+
|
1542 |
+
"""
|
1543 |
+
q0 = unit_vector(quat0[:4])
|
1544 |
+
q1 = unit_vector(quat1[:4])
|
1545 |
+
if fraction == 0.0:
|
1546 |
+
return q0
|
1547 |
+
elif fraction == 1.0:
|
1548 |
+
return q1
|
1549 |
+
d = numpy.dot(q0, q1)
|
1550 |
+
if abs(abs(d) - 1.0) < _EPS:
|
1551 |
+
return q0
|
1552 |
+
if shortestpath and d < 0.0:
|
1553 |
+
# invert rotation
|
1554 |
+
d = -d
|
1555 |
+
q1 *= -1.0
|
1556 |
+
angle = math.acos(d) + spin * math.pi
|
1557 |
+
if abs(angle) < _EPS:
|
1558 |
+
return q0
|
1559 |
+
isin = 1.0 / math.sin(angle)
|
1560 |
+
q0 *= math.sin((1.0 - fraction) * angle) * isin
|
1561 |
+
q1 *= math.sin(fraction * angle) * isin
|
1562 |
+
q0 += q1
|
1563 |
+
return q0
|
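Note: a short sketch interpolating between two random orientations with quaternion_slerp and converting each step to a rotation matrix:

import numpy as np
import utils.utils_poses.ATE.transformations as tf

q0, q1 = tf.random_quaternion(), tf.random_quaternion()
for t in np.linspace(0.0, 1.0, 5):
    q = tf.quaternion_slerp(q0, q1, t)
    R = tf.quaternion_matrix(q)                       # 4x4 homogeneous rotation
    assert np.allclose(np.linalg.det(R[:3, :3]), 1.0)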
1564 |
+
|
1565 |
+
|
1566 |
+
def random_quaternion(rand=None):
|
1567 |
+
"""Return uniform random unit quaternion.
|
1568 |
+
|
1569 |
+
rand: array like or None
|
1570 |
+
Three independent random variables that are uniformly distributed
|
1571 |
+
between 0 and 1.
|
1572 |
+
|
1573 |
+
>>> q = random_quaternion()
|
1574 |
+
>>> numpy.allclose(1.0, vector_norm(q))
|
1575 |
+
True
|
1576 |
+
>>> q = random_quaternion(numpy.random.random(3))
|
1577 |
+
>>> q.shape
|
1578 |
+
(4,)
|
1579 |
+
|
1580 |
+
"""
|
1581 |
+
if rand is None:
|
1582 |
+
rand = numpy.random.rand(3)
|
1583 |
+
else:
|
1584 |
+
assert len(rand) == 3
|
1585 |
+
r1 = numpy.sqrt(1.0 - rand[0])
|
1586 |
+
r2 = numpy.sqrt(rand[0])
|
1587 |
+
pi2 = math.pi * 2.0
|
1588 |
+
t1 = pi2 * rand[1]
|
1589 |
+
t2 = pi2 * rand[2]
|
1590 |
+
return numpy.array((numpy.sin(t1)*r1,
|
1591 |
+
numpy.cos(t1)*r1,
|
1592 |
+
numpy.sin(t2)*r2,
|
1593 |
+
numpy.cos(t2)*r2), dtype=numpy.float64)
|
1594 |
+
|
1595 |
+
|
1596 |
+
def random_rotation_matrix(rand=None):
|
1597 |
+
"""Return uniform random rotation matrix.
|
1598 |
+
|
1599 |
+
rand: array like or None
|
1600 |
+
Three independent random variables that are uniformly distributed
|
1601 |
+
between 0 and 1 for each returned quaternion.
|
1602 |
+
|
1603 |
+
>>> R = random_rotation_matrix()
|
1604 |
+
>>> numpy.allclose(numpy.dot(R.T, R), numpy.identity(4))
|
1605 |
+
True
|
1606 |
+
|
1607 |
+
"""
|
1608 |
+
return quaternion_matrix(random_quaternion(rand))
|
1609 |
+
|
1610 |
+
|
1611 |
+
def random_direction_3d():
|
1612 |
+
""" equal-area projection according to:
|
1613 |
+
https://math.stackexchange.com/questions/44689/how-to-find-a-random-axis-or-unit-vector-in-3d
|
1614 |
+
cfo, 2015/10/16
|
1615 |
+
"""
|
1616 |
+
z = numpy.random.rand() * 2.0 - 1.0
|
1617 |
+
t = numpy.random.rand() * 2.0 * numpy.pi
|
1618 |
+
r = numpy.sqrt(1.0 - z*z)
|
1619 |
+
x = r * numpy.cos(t)
|
1620 |
+
y = r * numpy.sin(t)
|
1621 |
+
return numpy.array([x, y, z], dtype=numpy.float64)
|
1622 |
+
|
1623 |
+
|
1624 |
+
class Arcball(object):
|
1625 |
+
"""Virtual Trackball Control.
|
1626 |
+
|
1627 |
+
>>> ball = Arcball()
|
1628 |
+
>>> ball = Arcball(initial=numpy.identity(4))
|
1629 |
+
>>> ball.place([320, 320], 320)
|
1630 |
+
>>> ball.down([500, 250])
|
1631 |
+
>>> ball.drag([475, 275])
|
1632 |
+
>>> R = ball.matrix()
|
1633 |
+
>>> numpy.allclose(numpy.sum(R), 3.90583455)
|
1634 |
+
True
|
1635 |
+
>>> ball = Arcball(initial=[0, 0, 0, 1])
|
1636 |
+
>>> ball.place([320, 320], 320)
|
1637 |
+
>>> ball.setaxes([1,1,0], [-1, 1, 0])
|
1638 |
+
>>> ball.setconstrain(True)
|
1639 |
+
>>> ball.down([400, 200])
|
1640 |
+
>>> ball.drag([200, 400])
|
1641 |
+
>>> R = ball.matrix()
|
1642 |
+
>>> numpy.allclose(numpy.sum(R), 0.2055924)
|
1643 |
+
True
|
1644 |
+
>>> ball.next()
|
1645 |
+
|
1646 |
+
"""
|
1647 |
+
|
1648 |
+
def __init__(self, initial=None):
|
1649 |
+
"""Initialize virtual trackball control.
|
1650 |
+
|
1651 |
+
initial : quaternion or rotation matrix
|
1652 |
+
|
1653 |
+
"""
|
1654 |
+
self._axis = None
|
1655 |
+
self._axes = None
|
1656 |
+
self._radius = 1.0
|
1657 |
+
self._center = [0.0, 0.0]
|
1658 |
+
self._vdown = numpy.array([0, 0, 1], dtype=numpy.float64)
|
1659 |
+
self._constrain = False
|
1660 |
+
|
1661 |
+
if initial is None:
|
1662 |
+
self._qdown = numpy.array([0, 0, 0, 1], dtype=numpy.float64)
|
1663 |
+
else:
|
1664 |
+
initial = numpy.array(initial, dtype=numpy.float64)
|
1665 |
+
if initial.shape == (4, 4):
|
1666 |
+
self._qdown = quaternion_from_matrix(initial)
|
1667 |
+
elif initial.shape == (4, ):
|
1668 |
+
initial /= vector_norm(initial)
|
1669 |
+
self._qdown = initial
|
1670 |
+
else:
|
1671 |
+
raise ValueError("initial not a quaternion or matrix.")
|
1672 |
+
|
1673 |
+
self._qnow = self._qpre = self._qdown
|
1674 |
+
|
1675 |
+
def place(self, center, radius):
|
1676 |
+
"""Place Arcball, e.g. when window size changes.
|
1677 |
+
|
1678 |
+
center : sequence[2]
|
1679 |
+
Window coordinates of trackball center.
|
1680 |
+
radius : float
|
1681 |
+
Radius of trackball in window coordinates.
|
1682 |
+
|
1683 |
+
"""
|
1684 |
+
self._radius = float(radius)
|
1685 |
+
self._center[0] = center[0]
|
1686 |
+
self._center[1] = center[1]
|
1687 |
+
|
1688 |
+
def setaxes(self, *axes):
|
1689 |
+
"""Set axes to constrain rotations."""
|
1690 |
+
if axes is None:
|
1691 |
+
self._axes = None
|
1692 |
+
else:
|
1693 |
+
self._axes = [unit_vector(axis) for axis in axes]
|
1694 |
+
|
1695 |
+
def setconstrain(self, constrain):
|
1696 |
+
"""Set state of constrain to axis mode."""
|
1697 |
+
self._constrain = constrain == True
|
1698 |
+
|
1699 |
+
def getconstrain(self):
|
1700 |
+
"""Return state of constrain to axis mode."""
|
1701 |
+
return self._constrain
|
1702 |
+
|
1703 |
+
def down(self, point):
|
1704 |
+
"""Set initial cursor window coordinates and pick constrain-axis."""
|
1705 |
+
self._vdown = arcball_map_to_sphere(point, self._center, self._radius)
|
1706 |
+
self._qdown = self._qpre = self._qnow
|
1707 |
+
|
1708 |
+
if self._constrain and self._axes is not None:
|
1709 |
+
self._axis = arcball_nearest_axis(self._vdown, self._axes)
|
1710 |
+
self._vdown = arcball_constrain_to_axis(self._vdown, self._axis)
|
1711 |
+
else:
|
1712 |
+
self._axis = None
|
1713 |
+
|
1714 |
+
def drag(self, point):
|
1715 |
+
"""Update current cursor window coordinates."""
|
1716 |
+
vnow = arcball_map_to_sphere(point, self._center, self._radius)
|
1717 |
+
|
1718 |
+
if self._axis is not None:
|
1719 |
+
vnow = arcball_constrain_to_axis(vnow, self._axis)
|
1720 |
+
|
1721 |
+
self._qpre = self._qnow
|
1722 |
+
|
1723 |
+
t = numpy.cross(self._vdown, vnow)
|
1724 |
+
if numpy.dot(t, t) < _EPS:
|
1725 |
+
self._qnow = self._qdown
|
1726 |
+
else:
|
1727 |
+
q = [t[0], t[1], t[2], numpy.dot(self._vdown, vnow)]
|
1728 |
+
self._qnow = quaternion_multiply(q, self._qdown)
|
1729 |
+
|
1730 |
+
def next(self, acceleration=0.0):
|
1731 |
+
"""Continue rotation in direction of last drag."""
|
1732 |
+
q = quaternion_slerp(self._qpre, self._qnow, 2.0+acceleration, False)
|
1733 |
+
self._qpre, self._qnow = self._qnow, q
|
1734 |
+
|
1735 |
+
def matrix(self):
|
1736 |
+
"""Return homogeneous rotation matrix."""
|
1737 |
+
return quaternion_matrix(self._qnow)
|
1738 |
+
|
1739 |
+
|
1740 |
+
def arcball_map_to_sphere(point, center, radius):
|
1741 |
+
"""Return unit sphere coordinates from window coordinates."""
|
1742 |
+
v = numpy.array(((point[0] - center[0]) / radius,
|
1743 |
+
(center[1] - point[1]) / radius,
|
1744 |
+
0.0), dtype=numpy.float64)
|
1745 |
+
n = v[0]*v[0] + v[1]*v[1]
|
1746 |
+
if n > 1.0:
|
1747 |
+
v /= math.sqrt(n) # position outside of sphere
|
1748 |
+
else:
|
1749 |
+
v[2] = math.sqrt(1.0 - n)
|
1750 |
+
return v
|
1751 |
+
|
1752 |
+
|
1753 |
+
def arcball_constrain_to_axis(point, axis):
|
1754 |
+
"""Return sphere point perpendicular to axis."""
|
1755 |
+
v = numpy.array(point, dtype=numpy.float64, copy=True)
|
1756 |
+
a = numpy.array(axis, dtype=numpy.float64, copy=True)
|
1757 |
+
v -= a * numpy.dot(a, v) # on plane
|
1758 |
+
n = vector_norm(v)
|
1759 |
+
if n > _EPS:
|
1760 |
+
if v[2] < 0.0:
|
1761 |
+
v *= -1.0
|
1762 |
+
v /= n
|
1763 |
+
return v
|
1764 |
+
if a[2] == 1.0:
|
1765 |
+
return numpy.array([1, 0, 0], dtype=numpy.float64)
|
1766 |
+
return unit_vector([-a[1], a[0], 0])
|
1767 |
+
|
1768 |
+
|
1769 |
+
def arcball_nearest_axis(point, axes):
|
1770 |
+
"""Return axis, which arc is nearest to point."""
|
1771 |
+
point = numpy.array(point, dtype=numpy.float64, copy=False)
|
1772 |
+
nearest = None
|
1773 |
+
mx = -1.0
|
1774 |
+
for axis in axes:
|
1775 |
+
t = numpy.dot(arcball_constrain_to_axis(point, axis), point)
|
1776 |
+
if t > mx:
|
1777 |
+
nearest = axis
|
1778 |
+
mx = t
|
1779 |
+
return nearest
|
1780 |
+
|
1781 |
+
|
1782 |
+
# epsilon for testing whether a number is close to zero
|
1783 |
+
_EPS = numpy.finfo(float).eps * 4.0
|
1784 |
+
|
1785 |
+
# axis sequences for Euler angles
|
1786 |
+
_NEXT_AXIS = [1, 2, 0, 1]
|
1787 |
+
|
1788 |
+
# map axes strings to/from tuples of inner axis, parity, repetition, frame
|
1789 |
+
_AXES2TUPLE = {
|
1790 |
+
'sxyz': (0, 0, 0, 0), 'sxyx': (0, 0, 1, 0), 'sxzy': (0, 1, 0, 0),
|
1791 |
+
'sxzx': (0, 1, 1, 0), 'syzx': (1, 0, 0, 0), 'syzy': (1, 0, 1, 0),
|
1792 |
+
'syxz': (1, 1, 0, 0), 'syxy': (1, 1, 1, 0), 'szxy': (2, 0, 0, 0),
|
1793 |
+
'szxz': (2, 0, 1, 0), 'szyx': (2, 1, 0, 0), 'szyz': (2, 1, 1, 0),
|
1794 |
+
'rzyx': (0, 0, 0, 1), 'rxyx': (0, 0, 1, 1), 'ryzx': (0, 1, 0, 1),
|
1795 |
+
'rxzx': (0, 1, 1, 1), 'rxzy': (1, 0, 0, 1), 'ryzy': (1, 0, 1, 1),
|
1796 |
+
'rzxy': (1, 1, 0, 1), 'ryxy': (1, 1, 1, 1), 'ryxz': (2, 0, 0, 1),
|
1797 |
+
'rzxz': (2, 0, 1, 1), 'rxyz': (2, 1, 0, 1), 'rzyz': (2, 1, 1, 1)}
|
1798 |
+
|
1799 |
+
_TUPLE2AXES = dict((v, k) for k, v in _AXES2TUPLE.items())
|
1800 |
+
|
1801 |
+
# helper functions
|
1802 |
+
|
1803 |
+
|
1804 |
+
def vector_norm(data, axis=None, out=None):
|
1805 |
+
"""Return length, i.e. eucledian norm, of ndarray along axis.
|
1806 |
+
|
1807 |
+
>>> v = numpy.random.random(3)
|
1808 |
+
>>> n = vector_norm(v)
|
1809 |
+
>>> numpy.allclose(n, numpy.linalg.norm(v))
|
1810 |
+
True
|
1811 |
+
>>> v = numpy.random.rand(6, 5, 3)
|
1812 |
+
>>> n = vector_norm(v, axis=-1)
|
1813 |
+
>>> numpy.allclose(n, numpy.sqrt(numpy.sum(v*v, axis=2)))
|
1814 |
+
True
|
1815 |
+
>>> n = vector_norm(v, axis=1)
|
1816 |
+
>>> numpy.allclose(n, numpy.sqrt(numpy.sum(v*v, axis=1)))
|
1817 |
+
True
|
1818 |
+
>>> v = numpy.random.rand(5, 4, 3)
|
1819 |
+
>>> n = numpy.empty((5, 3), dtype=numpy.float64)
|
1820 |
+
>>> vector_norm(v, axis=1, out=n)
|
1821 |
+
>>> numpy.allclose(n, numpy.sqrt(numpy.sum(v*v, axis=1)))
|
1822 |
+
True
|
1823 |
+
>>> vector_norm([])
|
1824 |
+
0.0
|
1825 |
+
>>> vector_norm([1.0])
|
1826 |
+
1.0
|
1827 |
+
|
1828 |
+
"""
|
1829 |
+
data = numpy.array(data, dtype=numpy.float64, copy=True)
|
1830 |
+
if out is None:
|
1831 |
+
if data.ndim == 1:
|
1832 |
+
return math.sqrt(numpy.dot(data, data))
|
1833 |
+
data *= data
|
1834 |
+
out = numpy.atleast_1d(numpy.sum(data, axis=axis))
|
1835 |
+
numpy.sqrt(out, out)
|
1836 |
+
return out
|
1837 |
+
else:
|
1838 |
+
data *= data
|
1839 |
+
numpy.sum(data, axis=axis, out=out)
|
1840 |
+
numpy.sqrt(out, out)
|
1841 |
+
|
1842 |
+
|
1843 |
+
def unit_vector(data, axis=None, out=None):
|
1844 |
+
"""Return ndarray normalized by length, i.e. eucledian norm, along axis.
|
1845 |
+
|
1846 |
+
>>> v0 = numpy.random.random(3)
|
1847 |
+
>>> v1 = unit_vector(v0)
|
1848 |
+
>>> numpy.allclose(v1, v0 / numpy.linalg.norm(v0))
|
1849 |
+
True
|
1850 |
+
>>> v0 = numpy.random.rand(5, 4, 3)
|
1851 |
+
>>> v1 = unit_vector(v0, axis=-1)
|
1852 |
+
>>> v2 = v0 / numpy.expand_dims(numpy.sqrt(numpy.sum(v0*v0, axis=2)), 2)
|
1853 |
+
>>> numpy.allclose(v1, v2)
|
1854 |
+
True
|
1855 |
+
>>> v1 = unit_vector(v0, axis=1)
|
1856 |
+
>>> v2 = v0 / numpy.expand_dims(numpy.sqrt(numpy.sum(v0*v0, axis=1)), 1)
|
1857 |
+
>>> numpy.allclose(v1, v2)
|
1858 |
+
True
|
1859 |
+
>>> v1 = numpy.empty((5, 4, 3), dtype=numpy.float64)
|
1860 |
+
>>> unit_vector(v0, axis=1, out=v1)
|
1861 |
+
>>> numpy.allclose(v1, v2)
|
1862 |
+
True
|
1863 |
+
>>> list(unit_vector([]))
|
1864 |
+
[]
|
1865 |
+
>>> list(unit_vector([1.0]))
|
1866 |
+
[1.0]
|
1867 |
+
|
1868 |
+
"""
|
1869 |
+
if out is None:
|
1870 |
+
data = numpy.array(data, dtype=numpy.float64, copy=True)
|
1871 |
+
if data.ndim == 1:
|
1872 |
+
data /= math.sqrt(numpy.dot(data, data))
|
1873 |
+
return data
|
1874 |
+
else:
|
1875 |
+
if out is not data:
|
1876 |
+
out[:] = numpy.array(data, copy=False)
|
1877 |
+
data = out
|
1878 |
+
length = numpy.atleast_1d(numpy.sum(data*data, axis))
|
1879 |
+
numpy.sqrt(length, length)
|
1880 |
+
if axis is not None:
|
1881 |
+
length = numpy.expand_dims(length, axis)
|
1882 |
+
data /= length
|
1883 |
+
if out is None:
|
1884 |
+
return data
|
1885 |
+
|
1886 |
+
|
1887 |
+
def random_vector(size):
|
1888 |
+
"""Return array of random doubles in the half-open interval [0.0, 1.0).
|
1889 |
+
|
1890 |
+
>>> v = random_vector(10000)
|
1891 |
+
>>> numpy.all(v >= 0.0) and numpy.all(v < 1.0)
|
1892 |
+
True
|
1893 |
+
>>> v0 = random_vector(10)
|
1894 |
+
>>> v1 = random_vector(10)
|
1895 |
+
>>> numpy.any(v0 == v1)
|
1896 |
+
False
|
1897 |
+
|
1898 |
+
"""
|
1899 |
+
return numpy.random.random(size)
|
1900 |
+
|
1901 |
+
|
1902 |
+
def inverse_matrix(matrix):
|
1903 |
+
"""Return inverse of square transformation matrix.
|
1904 |
+
|
1905 |
+
>>> M0 = random_rotation_matrix()
|
1906 |
+
>>> M1 = inverse_matrix(M0.T)
|
1907 |
+
>>> numpy.allclose(M1, numpy.linalg.inv(M0.T))
|
1908 |
+
True
|
1909 |
+
>>> for size in range(1, 7):
|
1910 |
+
... M0 = numpy.random.rand(size, size)
|
1911 |
+
... M1 = inverse_matrix(M0)
|
1912 |
+
... if not numpy.allclose(M1, numpy.linalg.inv(M0)): print(size)
|
1913 |
+
|
1914 |
+
"""
|
1915 |
+
return numpy.linalg.inv(matrix)
|
1916 |
+
|
1917 |
+
|
1918 |
+
def concatenate_matrices(*matrices):
|
1919 |
+
"""Return concatenation of series of transformation matrices.
|
1920 |
+
|
1921 |
+
>>> M = numpy.random.rand(16).reshape((4, 4)) - 0.5
|
1922 |
+
>>> numpy.allclose(M, concatenate_matrices(M))
|
1923 |
+
True
|
1924 |
+
>>> numpy.allclose(numpy.dot(M, M.T), concatenate_matrices(M, M.T))
|
1925 |
+
True
|
1926 |
+
|
1927 |
+
"""
|
1928 |
+
M = numpy.identity(4)
|
1929 |
+
for i in matrices:
|
1930 |
+
M = numpy.dot(M, i)
|
1931 |
+
return M
|
1932 |
+
|
1933 |
+
|
1934 |
+
def is_same_transform(matrix0, matrix1):
|
1935 |
+
"""Return True if two matrices perform same transformation.
|
1936 |
+
|
1937 |
+
>>> is_same_transform(numpy.identity(4), numpy.identity(4))
|
1938 |
+
True
|
1939 |
+
>>> is_same_transform(numpy.identity(4), random_rotation_matrix())
|
1940 |
+
False
|
1941 |
+
|
1942 |
+
"""
|
1943 |
+
matrix0 = numpy.array(matrix0, dtype=numpy.float64, copy=True)
|
1944 |
+
matrix0 /= matrix0[3, 3]
|
1945 |
+
matrix1 = numpy.array(matrix1, dtype=numpy.float64, copy=True)
|
1946 |
+
matrix1 /= matrix1[3, 3]
|
1947 |
+
return numpy.allclose(matrix0, matrix1)
|
1948 |
+
|
1949 |
+
|
1950 |
+
def _import_module(module_name, warn=True, prefix='_py_', ignore='_'):
|
1951 |
+
"""Try import all public attributes from module into global namespace.
|
1952 |
+
|
1953 |
+
Existing attributes with name clashes are renamed with prefix.
|
1954 |
+
Attributes starting with underscore are ignored by default.
|
1955 |
+
|
1956 |
+
Return True on successful import.
|
1957 |
+
|
1958 |
+
"""
|
1959 |
+
try:
|
1960 |
+
module = __import__(module_name)
|
1961 |
+
except ImportError:
|
1962 |
+
if warn:
|
1963 |
+
warnings.warn("Failed to import module " + module_name)
|
1964 |
+
else:
|
1965 |
+
for attr in dir(module):
|
1966 |
+
if ignore and attr.startswith(ignore):
|
1967 |
+
continue
|
1968 |
+
if prefix:
|
1969 |
+
if attr in globals():
|
1970 |
+
globals()[prefix + attr] = globals()[attr]
|
1971 |
+
elif warn:
|
1972 |
+
warnings.warn("No Python implementation of " + attr)
|
1973 |
+
globals()[attr] = getattr(module, attr)
|
1974 |
+
return True
|
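Note: this vendored copy stores quaternions scalar-last, i.e. [x, y, z, w], which matches what scipy's Rotation (used by the pose helpers below) expects. A quick consistency check, assuming scipy is installed:

import numpy as np
from scipy.spatial.transform import Rotation as RotLib
import utils.utils_poses.ATE.transformations as tf

q = tf.quaternion_about_axis(0.3, [0.0, 1.0, 0.0])   # [x, y, z, w]
R_tf = tf.quaternion_matrix(q)[:3, :3]
R_sp = RotLib.from_quat(q).as_matrix()                # scipy also reads [x, y, z, w]
assert np.allclose(R_tf, R_sp)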
utils/utils_poses/align_traj.py
ADDED
@@ -0,0 +1,97 @@
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
|
4 |
+
from utils.utils_poses.ATE.align_utils import alignTrajectory
|
5 |
+
from utils.utils_poses.lie_group_helper import SO3_to_quat, convert3x4_4x4
|
6 |
+
|
7 |
+
|
8 |
+
def pts_dist_max(pts):
|
9 |
+
"""
|
10 |
+
:param pts: (N, 3) torch or np
|
11 |
+
:return: scalar
|
12 |
+
"""
|
13 |
+
if torch.is_tensor(pts):
|
14 |
+
dist = pts.unsqueeze(0) - pts.unsqueeze(1) # (1, N, 3) - (N, 1, 3) -> (N, N, 3)
|
15 |
+
dist = dist[0] # (N, 3)
|
16 |
+
dist = dist.norm(dim=1) # (N, )
|
17 |
+
max_dist = dist.max()
|
18 |
+
else:
|
19 |
+
dist = pts[None, :, :] - pts[:, None, :] # (1, N, 3) - (N, 1, 3) -> (N, N, 3)
|
20 |
+
dist = dist[0] # (N, 3)
|
21 |
+
dist = np.linalg.norm(dist, axis=1) # (N, )
|
22 |
+
max_dist = dist.max()
|
23 |
+
return max_dist
|
24 |
+
|
25 |
+
|
26 |
+
def align_ate_c2b_use_a2b(traj_a, traj_b, traj_c=None, method='sim3'):
|
27 |
+
"""Align c to b using the sim3 from a to b.
|
28 |
+
:param traj_a: (N0, 3/4, 4) torch tensor
|
29 |
+
:param traj_b: (N0, 3/4, 4) torch tensor
|
30 |
+
:param traj_c: None or (N1, 3/4, 4) torch tensor
|
31 |
+
:return: (N1, 4, 4) torch tensor
|
32 |
+
"""
|
33 |
+
device = traj_a.device
|
34 |
+
if traj_c is None:
|
35 |
+
traj_c = traj_a.clone()
|
36 |
+
|
37 |
+
traj_a = traj_a.float().cpu().numpy()
|
38 |
+
traj_b = traj_b.float().cpu().numpy()
|
39 |
+
traj_c = traj_c.float().cpu().numpy()
|
40 |
+
|
41 |
+
R_a = traj_a[:, :3, :3] # (N0, 3, 3)
|
42 |
+
t_a = traj_a[:, :3, 3] # (N0, 3)
|
43 |
+
quat_a = SO3_to_quat(R_a) # (N0, 4)
|
44 |
+
|
45 |
+
R_b = traj_b[:, :3, :3] # (N0, 3, 3)
|
46 |
+
t_b = traj_b[:, :3, 3] # (N0, 3)
|
47 |
+
quat_b = SO3_to_quat(R_b) # (N0, 4)
|
48 |
+
|
49 |
+
# This function works in quaternion.
|
50 |
+
# scalar, (3, 3), (3, ) gt = R * s * est + t.
|
51 |
+
s, R, t = alignTrajectory(t_a, t_b, quat_a, quat_b, method=method)
|
52 |
+
|
53 |
+
# reshape tensors
|
54 |
+
R = R[None, :, :].astype(np.float32) # (1, 3, 3)
|
55 |
+
t = t[None, :, None].astype(np.float32) # (1, 3, 1)
|
56 |
+
s = float(s)
|
57 |
+
|
58 |
+
R_c = traj_c[:, :3, :3] # (N1, 3, 3)
|
59 |
+
t_c = traj_c[:, :3, 3:4] # (N1, 3, 1)
|
60 |
+
|
61 |
+
R_c_aligned = R @ R_c # (N1, 3, 3)
|
62 |
+
t_c_aligned = s * (R @ t_c) + t # (N1, 3, 1)
|
63 |
+
traj_c_aligned = np.concatenate([R_c_aligned, t_c_aligned], axis=2) # (N1, 3, 4)
|
64 |
+
|
65 |
+
# append the last row
|
66 |
+
traj_c_aligned = convert3x4_4x4(traj_c_aligned) # (N1, 4, 4)
|
67 |
+
|
68 |
+
traj_c_aligned = torch.from_numpy(traj_c_aligned).to(device)
|
69 |
+
return traj_c_aligned # (N1, 4, 4)
|
70 |
+
|
71 |
+
|
72 |
+
|
73 |
+
def align_scale_c2b_use_a2b(traj_a, traj_b, traj_c=None):
|
74 |
+
'''Scale c to b using the scale from a to b.
|
75 |
+
:param traj_a: (N0, 3/4, 4) torch tensor
|
76 |
+
:param traj_b: (N0, 3/4, 4) torch tensor
|
77 |
+
:param traj_c: None or (N1, 3/4, 4) torch tensor
|
78 |
+
:return:
|
79 |
+
scaled_traj_c (N1, 4, 4) torch tensor
|
80 |
+
scale scalar
|
81 |
+
'''
|
82 |
+
if traj_c is None:
|
83 |
+
traj_c = traj_a.clone()
|
84 |
+
|
85 |
+
t_a = traj_a[:, :3, 3] # (N, 3)
|
86 |
+
t_b = traj_b[:, :3, 3] # (N, 3)
|
87 |
+
|
88 |
+
# scale estimated poses to colmap scale
|
89 |
+
# s_a2b: a*s ~ b
|
90 |
+
scale_a2b = pts_dist_max(t_b) / pts_dist_max(t_a)
|
91 |
+
|
92 |
+
traj_c[:, :3, 3] *= scale_a2b
|
93 |
+
|
94 |
+
if traj_c.shape[1] == 3:
|
95 |
+
traj_c = convert3x4_4x4(traj_c) # (N, 4, 4)
|
96 |
+
|
97 |
+
return traj_c, scale_a2b # (N, 4, 4)
|
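Note: a minimal sketch of align_ate_c2b_use_a2b with synthetic camera-to-world poses; the "ground truth" here is a scaled and shifted copy of the estimate, so the Sim(3)-aligned residual should be near zero:

import torch
from utils.utils_poses.align_traj import align_ate_c2b_use_a2b

t = torch.linspace(0, 1, 20)
est = torch.eye(4).repeat(20, 1, 1)                        # estimated poses, (N, 4, 4)
est[:, 0, 3] = torch.cos(3.14 * t)
est[:, 1, 3] = torch.sin(3.14 * t)
est[:, 2, 3] = t
gt = est.clone()
gt[:, :3, 3] = 2.0 * est[:, :3, 3] + torch.tensor([0.5, 0.0, 0.0])
est_aligned = align_ate_c2b_use_a2b(est, gt)               # (N, 4, 4), method='sim3'
print((est_aligned[:, :3, 3] - gt[:, :3, 3]).norm(dim=1))  # ~0 after alignment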
utils/utils_poses/comp_ate.py
ADDED
@@ -0,0 +1,74 @@
1 |
+
|
2 |
+
import numpy as np
|
3 |
+
|
4 |
+
import utils.utils_poses.ATE.trajectory_utils as tu
|
5 |
+
import utils.utils_poses.ATE.transformations as tf
|
6 |
+
def rotation_error(pose_error):
|
7 |
+
"""Compute rotation error
|
8 |
+
Args:
|
9 |
+
pose_error (4x4 array): relative pose error
|
10 |
+
Returns:
|
11 |
+
rot_error (float): rotation error
|
12 |
+
"""
|
13 |
+
a = pose_error[0, 0]
|
14 |
+
b = pose_error[1, 1]
|
15 |
+
c = pose_error[2, 2]
|
16 |
+
d = 0.5*(a+b+c-1.0)
|
17 |
+
rot_error = np.arccos(max(min(d, 1.0), -1.0))
|
18 |
+
return rot_error
|
19 |
+
|
20 |
+
def translation_error(pose_error):
|
21 |
+
"""Compute translation error
|
22 |
+
Args:
|
23 |
+
pose_error (4x4 array): relative pose error
|
24 |
+
Returns:
|
25 |
+
trans_error (float): translation error
|
26 |
+
"""
|
27 |
+
dx = pose_error[0, 3]
|
28 |
+
dy = pose_error[1, 3]
|
29 |
+
dz = pose_error[2, 3]
|
30 |
+
trans_error = np.sqrt(dx**2+dy**2+dz**2)
|
31 |
+
return trans_error
|
32 |
+
|
33 |
+
def compute_rpe(gt, pred):
|
34 |
+
trans_errors = []
|
35 |
+
rot_errors = []
|
36 |
+
for i in range(len(gt)-1):
|
37 |
+
gt1 = gt[i]
|
38 |
+
gt2 = gt[i+1]
|
39 |
+
gt_rel = np.linalg.inv(gt1) @ gt2
|
40 |
+
|
41 |
+
pred1 = pred[i]
|
42 |
+
pred2 = pred[i+1]
|
43 |
+
pred_rel = np.linalg.inv(pred1) @ pred2
|
44 |
+
rel_err = np.linalg.inv(gt_rel) @ pred_rel
|
45 |
+
|
46 |
+
trans_errors.append(translation_error(rel_err))
|
47 |
+
rot_errors.append(rotation_error(rel_err))
|
48 |
+
rpe_trans = np.mean(np.asarray(trans_errors))
|
49 |
+
rpe_rot = np.mean(np.asarray(rot_errors))
|
50 |
+
return rpe_trans, rpe_rot
|
51 |
+
|
52 |
+
def compute_ATE(gt, pred):
|
53 |
+
"""Compute RMSE of ATE
|
54 |
+
Args:
|
55 |
+
gt: ground-truth poses
|
56 |
+
pred: predicted poses
|
57 |
+
"""
|
58 |
+
errors = []
|
59 |
+
|
60 |
+
for i in range(len(pred)):
|
61 |
+
# cur_gt = np.linalg.inv(gt_0) @ gt[i]
|
62 |
+
cur_gt = gt[i]
|
63 |
+
gt_xyz = cur_gt[:3, 3]
|
64 |
+
|
65 |
+
# cur_pred = np.linalg.inv(pred_0) @ pred[i]
|
66 |
+
cur_pred = pred[i]
|
67 |
+
pred_xyz = cur_pred[:3, 3]
|
68 |
+
|
69 |
+
align_err = gt_xyz - pred_xyz
|
70 |
+
|
71 |
+
errors.append(np.sqrt(np.sum(align_err ** 2)))
|
72 |
+
ate = np.sqrt(np.mean(np.asarray(errors) ** 2))
|
73 |
+
return ate
|
74 |
+
|
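Note: a small usage sketch for the ATE/RPE helpers above on synthetic 4x4 poses (a constant translation offset gives a non-zero ATE but a near-zero relative-pose error):

import numpy as np
from utils.utils_poses.comp_ate import compute_ATE, compute_rpe

def pose(tx):                              # hypothetical helper: translation-only pose
    T = np.eye(4)
    T[0, 3] = tx
    return T

gt   = [pose(x)        for x in np.linspace(0.0, 1.0, 11)]
pred = [pose(x + 0.01) for x in np.linspace(0.0, 1.0, 11)]
print(compute_ATE(gt, pred))               # ~0.01 (RMSE of absolute translation error)
print(compute_rpe(gt, pred))               # (rpe_trans, rpe_rot), both ~0 here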
utils/utils_poses/lie_group_helper.py
ADDED
@@ -0,0 +1,81 @@
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
from scipy.spatial.transform import Rotation as RotLib
|
4 |
+
|
5 |
+
|
6 |
+
def SO3_to_quat(R):
|
7 |
+
"""
|
8 |
+
:param R: (N, 3, 3) or (3, 3) np
|
9 |
+
:return: (N, 4, ) or (4, ) np
|
10 |
+
"""
|
11 |
+
x = RotLib.from_matrix(R)
|
12 |
+
quat = x.as_quat()
|
13 |
+
return quat
|
14 |
+
|
15 |
+
|
16 |
+
def quat_to_SO3(quat):
|
17 |
+
"""
|
18 |
+
:param quat: (N, 4, ) or (4, ) np
|
19 |
+
:return: (N, 3, 3) or (3, 3) np
|
20 |
+
"""
|
21 |
+
x = RotLib.from_quat(quat)
|
22 |
+
R = x.as_matrix()
|
23 |
+
return R
|
24 |
+
|
25 |
+
|
26 |
+
def convert3x4_4x4(input):
|
27 |
+
"""
|
28 |
+
:param input: (N, 3, 4) or (3, 4) torch or np
|
29 |
+
:return: (N, 4, 4) or (4, 4) torch or np
|
30 |
+
"""
|
31 |
+
if torch.is_tensor(input):
|
32 |
+
if len(input.shape) == 3:
|
33 |
+
output = torch.cat([input, torch.zeros_like(input[:, 0:1])], dim=1) # (N, 4, 4)
|
34 |
+
output[:, 3, 3] = 1.0
|
35 |
+
else:
|
36 |
+
output = torch.cat([input, torch.tensor([[0,0,0,1]], dtype=input.dtype, device=input.device)], dim=0) # (4, 4)
|
37 |
+
else:
|
38 |
+
if len(input.shape) == 3:
|
39 |
+
output = np.concatenate([input, np.zeros_like(input[:, 0:1])], axis=1) # (N, 4, 4)
|
40 |
+
output[:, 3, 3] = 1.0
|
41 |
+
else:
|
42 |
+
output = np.concatenate([input, np.array([[0,0,0,1]], dtype=input.dtype)], axis=0) # (4, 4)
|
43 |
+
output[3, 3] = 1.0
|
44 |
+
return output
|
45 |
+
|
46 |
+
|
47 |
+
def vec2skew(v):
|
48 |
+
"""
|
49 |
+
:param v: (3, ) torch tensor
|
50 |
+
:return: (3, 3)
|
51 |
+
"""
|
52 |
+
zero = torch.zeros(1, dtype=torch.float32, device=v.device)
|
53 |
+
skew_v0 = torch.cat([ zero, -v[2:3], v[1:2]]) # (3, 1)
|
54 |
+
skew_v1 = torch.cat([ v[2:3], zero, -v[0:1]])
|
55 |
+
skew_v2 = torch.cat([-v[1:2], v[0:1], zero])
|
56 |
+
skew_v = torch.stack([skew_v0, skew_v1, skew_v2], dim=0) # (3, 3)
|
57 |
+
return skew_v # (3, 3)
|
58 |
+
|
59 |
+
|
60 |
+
def Exp(r):
|
61 |
+
"""so(3) vector to SO(3) matrix
|
62 |
+
:param r: (3, ) axis-angle, torch tensor
|
63 |
+
:return: (3, 3)
|
64 |
+
"""
|
65 |
+
skew_r = vec2skew(r) # (3, 3)
|
66 |
+
norm_r = r.norm() + 1e-15
|
67 |
+
eye = torch.eye(3, dtype=torch.float32, device=r.device)
|
68 |
+
R = eye + (torch.sin(norm_r) / norm_r) * skew_r + ((1 - torch.cos(norm_r)) / norm_r**2) * (skew_r @ skew_r)
|
69 |
+
return R
|
70 |
+
|
71 |
+
|
72 |
+
def make_c2w(r, t):
|
73 |
+
"""
|
74 |
+
:param r: (3, ) axis-angle torch tensor
|
75 |
+
:param t: (3, ) translation vector torch tensor
|
76 |
+
:return: (4, 4)
|
77 |
+
"""
|
78 |
+
R = Exp(r) # (3, 3)
|
79 |
+
c2w = torch.cat([R, t.unsqueeze(1)], dim=1) # (3, 4)
|
80 |
+
c2w = convert3x4_4x4(c2w) # (4, 4)
|
81 |
+
return c2w
|
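Note: a minimal sketch of the axis-angle helpers above, building a camera-to-world matrix from an so(3) vector and a translation:

import torch
from utils.utils_poses.lie_group_helper import Exp, make_c2w

r = torch.tensor([0.0, 0.0, 0.3])                 # axis-angle: 0.3 rad about z
t = torch.tensor([1.0, 2.0, 3.0])
R = Exp(r)                                        # (3, 3) rotation via Rodrigues' formula
c2w = make_c2w(r, t)                              # (4, 4) camera-to-world
assert torch.allclose(c2w[:3, :3], R) and torch.allclose(c2w[:3, 3], t)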
utils/utils_poses/relative_pose.py
ADDED
@@ -0,0 +1,20 @@
1 |
+
import torch
|
2 |
+
import numpy as np
|
3 |
+
|
4 |
+
|
5 |
+
def compute_relative_world_to_camera(R1, t1, R2, t2):
|
6 |
+
zero_row = torch.tensor([[0, 0, 0, 1]], dtype=torch.float32, device="cuda") #, requires_grad=True
|
7 |
+
E1_inv = torch.cat([torch.transpose(R1, 0, 1), -torch.transpose(R1, 0, 1) @ t1.reshape(-1, 1)], dim=1)
|
8 |
+
E1_inv = torch.cat([E1_inv, zero_row], dim=0)
|
9 |
+
E2 = torch.cat([R2, -R2 @ t2.reshape(-1, 1)], dim=1)
|
10 |
+
E2 = torch.cat([E2, zero_row], dim=0)
|
11 |
+
|
12 |
+
# Compute relative transformation
|
13 |
+
E_rel = E2 @ E1_inv
|
14 |
+
|
15 |
+
# # Extract rotation and translation
|
16 |
+
# R_rel = E_rel[:3, :3]
|
17 |
+
# t_rel = E_rel[:3, 3]
|
18 |
+
# E_rel = torch.cat([E_rel, zero_row], dim=0)
|
19 |
+
|
20 |
+
return E_rel
|
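Note: a small usage sketch for compute_relative_world_to_camera; the function hard-codes device="cuda" for the homogeneous row, so a CUDA device is required (the values below are arbitrary):

import torch
from utils.utils_poses.relative_pose import compute_relative_world_to_camera

R1 = torch.eye(3, device="cuda")
t1 = torch.zeros(3, device="cuda")
R2 = torch.eye(3, device="cuda")
t2 = torch.tensor([0.0, 0.0, 1.0], device="cuda")
E_rel = compute_relative_world_to_camera(R1, t1, R2, t2)   # (4, 4) relative extrinsic
print(E_rel)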
utils/utils_poses/vis_cam_traj.py
ADDED
@@ -0,0 +1,138 @@
1 |
+
# This file is modified from NeRF++: https://github.com/Kai-46/nerfplusplus
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
|
5 |
+
try:
|
6 |
+
import open3d as o3d
|
7 |
+
except ImportError:
|
8 |
+
pass
|
9 |
+
|
10 |
+
|
11 |
+
def frustums2lineset(frustums):
|
12 |
+
N = len(frustums)
|
13 |
+
merged_points = np.zeros((N*5, 3)) # 5 vertices per frustum
|
14 |
+
merged_lines = np.zeros((N*8, 2)) # 8 lines per frustum
|
15 |
+
merged_colors = np.zeros((N*8, 3)) # each line gets a color
|
16 |
+
|
17 |
+
for i, (frustum_points, frustum_lines, frustum_colors) in enumerate(frustums):
|
18 |
+
merged_points[i*5:(i+1)*5, :] = frustum_points
|
19 |
+
merged_lines[i*8:(i+1)*8, :] = frustum_lines + i*5
|
20 |
+
merged_colors[i*8:(i+1)*8, :] = frustum_colors
|
21 |
+
|
22 |
+
lineset = o3d.geometry.LineSet()
|
23 |
+
lineset.points = o3d.utility.Vector3dVector(merged_points)
|
24 |
+
lineset.lines = o3d.utility.Vector2iVector(merged_lines)
|
25 |
+
lineset.colors = o3d.utility.Vector3dVector(merged_colors)
|
26 |
+
|
27 |
+
return lineset
|
28 |
+
|
29 |
+
|
30 |
+
def get_camera_frustum_opengl_coord(H, W, fx, fy, W2C, frustum_length=0.5, color=np.array([0., 1., 0.])):
|
31 |
+
'''X right, Y up, Z backward to the observer.
|
32 |
+
:param H, W:
|
33 |
+
:param fx, fy:
|
34 |
+
:param W2C: (4, 4) matrix
|
35 |
+
:param frustum_length: scalar: scale the frustum
|
36 |
+
:param color: (3,) list, frustum line color
|
37 |
+
:return:
|
38 |
+
frustum_points: (5, 3) frustum points in world coordinate
|
39 |
+
frustum_lines: (8, 2) 8 lines connect 5 frustum points, specified in line start/end index.
|
40 |
+
frustum_colors: (8, 3) colors for 8 lines.
|
41 |
+
'''
|
42 |
+
hfov = np.rad2deg(np.arctan(W / 2. / fx) * 2.)
|
43 |
+
vfov = np.rad2deg(np.arctan(H / 2. / fy) * 2.)
|
44 |
+
half_w = frustum_length * np.tan(np.deg2rad(hfov / 2.))
|
45 |
+
half_h = frustum_length * np.tan(np.deg2rad(vfov / 2.))
|
46 |
+
|
47 |
+
# build view frustum in camera space in homogeneous coordinates (5, 4)
|
48 |
+
frustum_points = np.array([[0., 0., 0., 1.0], # frustum origin
|
49 |
+
[-half_w, half_h, -frustum_length, 1.0], # top-left image corner
|
50 |
+
[half_w, half_h, -frustum_length, 1.0], # top-right image corner
|
51 |
+
[half_w, -half_h, -frustum_length, 1.0], # bottom-right image corner
|
52 |
+
[-half_w, -half_h, -frustum_length, 1.0]]) # bottom-left image corner
|
53 |
+
frustum_lines = np.array([[0, i] for i in range(1, 5)] + [[i, (i+1)] for i in range(1, 4)] + [[4, 1]]) # (8, 2)
|
54 |
+
frustum_colors = np.tile(color.reshape((1, 3)), (frustum_lines.shape[0], 1)) # (8, 3)
|
55 |
+
|
56 |
+
# transform view frustum from camera space to world space
|
57 |
+
C2W = np.linalg.inv(W2C)
|
58 |
+
frustum_points = np.matmul(C2W, frustum_points.T).T # (5, 4)
|
59 |
+
frustum_points = frustum_points[:, :3] / frustum_points[:, 3:4] # (5, 3) remove homogeneous coordinate
|
60 |
+
return frustum_points, frustum_lines, frustum_colors
|
61 |
+
|
62 |
+
def get_camera_frustum_opencv_coord(H, W, fx, fy, W2C, frustum_length=0.5, color=np.array([0., 1., 0.])):
|
63 |
+
'''X right, Y down, Z forward (OpenCV camera convention).
|
64 |
+
:param H, W:
|
65 |
+
:param fx, fy:
|
66 |
+
:param W2C: (4, 4) matrix
|
67 |
+
:param frustum_length: scalar: scale the frustum
|
68 |
+
:param color: (3,) list, frustum line color
|
69 |
+
:return:
|
70 |
+
frustum_points: (5, 3) frustum points in world coordinate
|
71 |
+
frustum_lines: (8, 2) 8 lines connect 5 frustum points, specified in line start/end index.
|
72 |
+
frustum_colors: (8, 3) colors for 8 lines.
|
73 |
+
'''
|
74 |
+
hfov = np.rad2deg(np.arctan(W / 2. / fx) * 2.)
|
75 |
+
vfov = np.rad2deg(np.arctan(H / 2. / fy) * 2.)
|
76 |
+
half_w = frustum_length * np.tan(np.deg2rad(hfov / 2.))
|
77 |
+
half_h = frustum_length * np.tan(np.deg2rad(vfov / 2.))
|
78 |
+
|
79 |
+
# build view frustum in camera space in homogeneous coordinates (5, 4)
|
80 |
+
frustum_points = np.array([[0., 0., 0., 1.0], # frustum origin
|
81 |
+
[-half_w, -half_h, frustum_length, 1.0], # top-left image corner
|
82 |
+
[ half_w, -half_h, frustum_length, 1.0], # top-right image corner
|
83 |
+
[ half_w, half_h, frustum_length, 1.0], # bottom-right image corner
|
84 |
+
[-half_w, +half_h, frustum_length, 1.0]]) # bottom-left image corner
|
85 |
+
frustum_lines = np.array([[0, i] for i in range(1, 5)] + [[i, (i+1)] for i in range(1, 4)] + [[4, 1]]) # (8, 2)
|
86 |
+
frustum_colors = np.tile(color.reshape((1, 3)), (frustum_lines.shape[0], 1)) # (8, 3)
|
87 |
+
|
88 |
+
# transform view frustum from camera space to world space
|
89 |
+
C2W = np.linalg.inv(W2C)
|
90 |
+
frustum_points = np.matmul(C2W, frustum_points.T).T # (5, 4)
|
91 |
+
frustum_points = frustum_points[:, :3] / frustum_points[:, 3:4] # (5, 3) remove homogeneous coordinate
|
92 |
+
return frustum_points, frustum_lines, frustum_colors
|
93 |
+
|
94 |
+
|
95 |
+
|
96 |
+
def draw_camera_frustum_geometry(c2ws, H, W, fx=600.0, fy=600.0, frustum_length=0.5,
|
97 |
+
color=np.array([29.0, 53.0, 87.0])/255.0, draw_now=False, coord='opengl'):
|
98 |
+
'''
|
99 |
+
:param c2ws: (N, 4, 4) np.array
|
100 |
+
:param H: scalar
|
101 |
+
:param W: scalar
|
102 |
+
:param fx: scalar
|
103 |
+
:param fy: scalar
|
104 |
+
:param frustum_length: scalar
|
105 |
+
:param color: None or (N, 3) or (3, ) or (1, 3) or (3, 1) np array
|
106 |
+
:param draw_now: True/False call o3d vis now
|
107 |
+
:return:
|
108 |
+
'''
|
109 |
+
N = c2ws.shape[0]
|
110 |
+
|
111 |
+
num_ele = color.flatten().shape[0]
|
112 |
+
if num_ele == 3:
|
113 |
+
color = color.reshape(1, 3)
|
114 |
+
color = np.tile(color, (N, 1))
|
115 |
+
|
116 |
+
frustum_list = []
|
117 |
+
if coord == 'opengl':
|
118 |
+
for i in range(N):
|
119 |
+
frustum_list.append(get_camera_frustum_opengl_coord(H, W, fx, fy,
|
120 |
+
W2C=np.linalg.inv(c2ws[i]),
|
121 |
+
frustum_length=frustum_length,
|
122 |
+
color=color[i]))
|
123 |
+
elif coord == 'opencv':
|
124 |
+
for i in range(N):
|
125 |
+
frustum_list.append(get_camera_frustum_opencv_coord(H, W, fx, fy,
|
126 |
+
W2C=np.linalg.inv(c2ws[i]),
|
127 |
+
frustum_length=frustum_length,
|
128 |
+
color=color[i]))
|
129 |
+
else:
|
130 |
+
print('Undefined coordinate system. Exit')
|
131 |
+
exit()
|
132 |
+
|
133 |
+
frustums_geometry = frustums2lineset(frustum_list)
|
134 |
+
|
135 |
+
if draw_now:
|
136 |
+
o3d.visualization.draw_geometries([frustums_geometry])
|
137 |
+
|
138 |
+
return frustums_geometry # this is an o3d geometry object.
|
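For context, a minimal usage sketch of draw_camera_frustum_geometry. The image size, intrinsics, and poses below are made-up placeholders; np, o3d, and frustums2lineset come from earlier in this file:

    import numpy as np
    import open3d as o3d

    # two hypothetical OpenGL-convention c2w poses, the second shifted 0.5 units along x
    c2ws = np.stack([np.eye(4), np.eye(4)])
    c2ws[1, 0, 3] = 0.5

    frustums = draw_camera_frustum_geometry(c2ws, H=480, W=640, fx=600.0, fy=600.0,
                                            frustum_length=0.3, coord='opengl')
    o3d.visualization.draw_geometries([frustums])  # or pass draw_now=True above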
utils/utils_poses/vis_pose_utils.py
ADDED
@@ -0,0 +1,270 @@
import os
import matplotlib
matplotlib.use('Agg')

from matplotlib import pyplot as plt
plt.ioff()

import copy
from evo.core.trajectory import PosePath3D, PoseTrajectory3D
from evo.main_ape import ape
from evo.tools import plot
from evo.core import sync
from evo.tools import file_interface
from evo.core import metrics
import evo
import torch
import numpy as np
from scipy.spatial.transform import Slerp
from scipy.spatial.transform import Rotation as R
import scipy.interpolate as si


def interp_poses(c2ws, N_views):
    N_inputs = c2ws.shape[0]
    trans = c2ws[:, :3, 3:].permute(2, 1, 0)
    rots = c2ws[:, :3, :3]
    render_poses = []
    rots = R.from_matrix(rots)
    slerp = Slerp(np.linspace(0, 1, N_inputs), rots)
    interp_rots = torch.tensor(
        slerp(np.linspace(0, 1, N_views)).as_matrix().astype(np.float32))
    interp_trans = torch.nn.functional.interpolate(
        trans, size=N_views, mode='linear').permute(2, 1, 0)
    render_poses = torch.cat([interp_rots, interp_trans], dim=2)
    render_poses = convert3x4_4x4(render_poses)
    return render_poses


def interp_poses_bspline(c2ws, N_novel_imgs, input_times, degree):
    target_trans = torch.tensor(scipy_bspline(
        c2ws[:, :3, 3], n=N_novel_imgs, degree=degree, periodic=False).astype(np.float32)).unsqueeze(2)
    rots = R.from_matrix(c2ws[:, :3, :3])
    slerp = Slerp(input_times, rots)
    target_times = np.linspace(input_times[0], input_times[-1], N_novel_imgs)
    target_rots = torch.tensor(
        slerp(target_times).as_matrix().astype(np.float32))
    target_poses = torch.cat([target_rots, target_trans], dim=2)
    target_poses = convert3x4_4x4(target_poses)
    return target_poses

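As a brief aside, a hypothetical call to interp_poses_bspline (the pose count, timestamps, and degree below are made up): it slerps the input rotations and fits a b-spline to the input translations to produce a dense, smooth render path.

    import numpy as np
    import torch

    # five placeholder (4, 4) camera-to-world poses, evenly spaced along x
    c2ws = torch.eye(4).repeat(5, 1, 1)
    c2ws[:, 0, 3] = torch.linspace(0., 1., 5)
    input_times = np.linspace(0., 1., 5)        # one timestamp per input pose

    render_poses = interp_poses_bspline(c2ws, N_novel_imgs=120,
                                        input_times=input_times, degree=2)
    # render_poses: (120, 4, 4) poses for novel-view / video rendering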
def poses_avg(poses):

    hwf = poses[0, :3, -1:]

    center = poses[:, :3, 3].mean(0)
    vec2 = normalize(poses[:, :3, 2].sum(0))
    up = poses[:, :3, 1].sum(0)
    c2w = np.concatenate([viewmatrix(vec2, up, center), hwf], 1)

    return c2w


def normalize(v):
    """Normalize a vector."""
    return v / np.linalg.norm(v)


def viewmatrix(z, up, pos):
    vec2 = normalize(z)
    vec1_avg = up
    vec0 = normalize(np.cross(vec1_avg, vec2))
    vec1 = normalize(np.cross(vec2, vec0))
    m = np.stack([vec0, vec1, vec2, pos], 1)
    return m


def render_path_spiral(c2w, up, rads, focal, zdelta, zrate, rots, N):
    render_poses = []
    rads = np.array(list(rads) + [1.])
    hwf = c2w[:, 4:5]

    for theta in np.linspace(0., 2. * np.pi * rots, N+1)[:-1]:
        # c = np.dot(c2w[:3,:4], np.array([0.7*np.cos(theta), -0.3*np.sin(theta), -np.sin(theta*zrate)*0.1, 1.]) * rads)
        # c = np.dot(c2w[:3,:4], np.array([0.3*np.cos(theta), -0.3*np.sin(theta), -np.sin(theta*zrate)*0.01, 1.]) * rads)
        c = np.dot(c2w[:3, :4], np.array(
            [0.2*np.cos(theta), -0.2*np.sin(theta), -np.sin(theta*zrate) * 0.1, 1.]) * rads)
        z = normalize(c - np.dot(c2w[:3, :4], np.array([0, 0, -focal, 1.])))
        render_poses.append(np.concatenate([viewmatrix(z, up, c), hwf], 1))
    return render_poses


def scipy_bspline(cv, n=100, degree=3, periodic=False):
    """Calculate n samples on a b-spline.

        cv :      array of control vertices
        n  :      number of samples to return
        degree:   curve degree
        periodic: True - curve is closed
    """
    cv = np.asarray(cv)
    count = cv.shape[0]

    # Closed curve
    if periodic:
        kv = np.arange(-degree, count+degree+1)
        factor, fraction = divmod(count+degree+1, count)
        cv = np.roll(np.concatenate(
            (cv,) * factor + (cv[:fraction],)), -1, axis=0)
        degree = np.clip(degree, 1, degree)

    # Opened curve
    else:
        degree = np.clip(degree, 1, count-1)
        kv = np.clip(np.arange(count+degree+1)-degree, 0, count-degree)

    # Return samples
    max_param = count - (degree * (1-periodic))
    spl = si.BSpline(kv, cv, degree)
    return spl(np.linspace(0, max_param, n))

+
def generate_spiral_nerf(learned_poses, bds, N_novel_views, hwf):
|
124 |
+
learned_poses_ = np.concatenate((learned_poses[:, :3, :4].detach(
|
125 |
+
).cpu().numpy(), hwf[:len(learned_poses)]), axis=-1)
|
126 |
+
c2w = poses_avg(learned_poses_)
|
127 |
+
print('recentered', c2w.shape)
|
128 |
+
# Get spiral
|
129 |
+
# Get average pose
|
130 |
+
up = normalize(learned_poses_[:, :3, 1].sum(0))
|
131 |
+
# Find a reasonable "focus depth" for this dataset
|
132 |
+
|
133 |
+
close_depth, inf_depth = bds.min()*.9, bds.max()*5.
|
134 |
+
dt = .75
|
135 |
+
mean_dz = 1./(((1.-dt)/close_depth + dt/inf_depth))
|
136 |
+
focal = mean_dz
|
137 |
+
|
138 |
+
# Get radii for spiral path
|
139 |
+
shrink_factor = .8
|
140 |
+
zdelta = close_depth * .2
|
141 |
+
tt = learned_poses_[:, :3, 3] # ptstocam(poses[:3,3,:].T, c2w).T
|
142 |
+
rads = np.percentile(np.abs(tt), 90, 0)
|
143 |
+
c2w_path = c2w
|
144 |
+
N_rots = 2
|
145 |
+
c2ws = render_path_spiral(
|
146 |
+
c2w_path, up, rads, focal, zdelta, zrate=.5, rots=N_rots, N=N_novel_views)
|
147 |
+
c2ws = torch.tensor(np.stack(c2ws).astype(np.float32))
|
148 |
+
c2ws = c2ws[:, :3, :4]
|
149 |
+
c2ws = convert3x4_4x4(c2ws)
|
150 |
+
return c2ws
|
151 |
+
|
152 |
+
|
153 |
+
def convert3x4_4x4(input):
|
154 |
+
"""
|
155 |
+
:param input: (N, 3, 4) or (3, 4) torch or np
|
156 |
+
:return: (N, 4, 4) or (4, 4) torch or np
|
157 |
+
"""
|
158 |
+
if torch.is_tensor(input):
|
159 |
+
if len(input.shape) == 3:
|
160 |
+
output = torch.cat([input, torch.zeros_like(
|
161 |
+
input[:, 0:1])], dim=1) # (N, 4, 4)
|
162 |
+
output[:, 3, 3] = 1.0
|
163 |
+
else:
|
164 |
+
output = torch.cat([input, torch.tensor(
|
165 |
+
[[0, 0, 0, 1]], dtype=input.dtype, device=input.device)], dim=0) # (4, 4)
|
166 |
+
else:
|
167 |
+
if len(input.shape) == 3:
|
168 |
+
output = np.concatenate(
|
169 |
+
[input, np.zeros_like(input[:, 0:1])], axis=1) # (N, 4, 4)
|
170 |
+
output[:, 3, 3] = 1.0
|
171 |
+
else:
|
172 |
+
output = np.concatenate(
|
173 |
+
[input, np.array([[0, 0, 0, 1]], dtype=input.dtype)], axis=0) # (4, 4)
|
174 |
+
output[3, 3] = 1.0
|
175 |
+
return output
|
176 |
+
|
177 |
+
|
178 |
+
plt.rc('legend', fontsize=20) # using a named size
|
179 |
+
|
180 |
+
|
181 |
+
def plot_pose(ref_poses, est_poses, output_path, args, vid=False):
|
182 |
+
ref_poses = [pose for pose in ref_poses]
|
183 |
+
if isinstance(est_poses, dict):
|
184 |
+
est_poses = [pose for k, pose in est_poses.items()]
|
185 |
+
else:
|
186 |
+
est_poses = [pose for pose in est_poses]
|
187 |
+
traj_ref = PosePath3D(poses_se3=ref_poses)
|
188 |
+
traj_est = PosePath3D(poses_se3=est_poses)
|
189 |
+
traj_est_aligned = copy.deepcopy(traj_est)
|
190 |
+
traj_est_aligned.align(traj_ref, correct_scale=True,
|
191 |
+
correct_only_scale=False)
|
192 |
+
if vid:
|
193 |
+
for p_idx in range(len(ref_poses)):
|
194 |
+
fig = plt.figure()
|
195 |
+
current_est_aligned = traj_est_aligned.poses_se3[:p_idx+1]
|
196 |
+
current_ref = traj_ref.poses_se3[:p_idx+1]
|
197 |
+
current_est_aligned = PosePath3D(poses_se3=current_est_aligned)
|
198 |
+
current_ref = PosePath3D(poses_se3=current_ref)
|
199 |
+
traj_by_label = {
|
200 |
+
# "estimate (not aligned)": traj_est,
|
201 |
+
"Ours (aligned)": current_est_aligned,
|
202 |
+
"Ground-truth": current_ref
|
203 |
+
}
|
204 |
+
plot_mode = plot.PlotMode.xyz
|
205 |
+
# ax = plot.prepare_axis(fig, plot_mode, 111)
|
206 |
+
ax = fig.add_subplot(111, projection="3d")
|
207 |
+
ax.xaxis.set_tick_params(labelbottom=False)
|
208 |
+
ax.yaxis.set_tick_params(labelleft=False)
|
209 |
+
ax.zaxis.set_tick_params(labelleft=False)
|
210 |
+
colors = ['r', 'b']
|
211 |
+
styles = ['-', '--']
|
212 |
+
|
213 |
+
for idx, (label, traj) in enumerate(traj_by_label.items()):
|
214 |
+
plot.traj(ax, plot_mode, traj,
|
215 |
+
styles[idx], colors[idx], label)
|
216 |
+
# break
|
217 |
+
# plot.trajectories(fig, traj_by_label, plot.PlotMode.xyz)
|
218 |
+
ax.view_init(elev=10., azim=45)
|
219 |
+
plt.tight_layout()
|
220 |
+
os.makedirs(os.path.join(os.path.dirname(
|
221 |
+
output_path), 'pose_vid'), exist_ok=True)
|
222 |
+
pose_vis_path = os.path.join(os.path.dirname(
|
223 |
+
output_path), 'pose_vid', 'pose_vis_{:03d}.png'.format(p_idx))
|
224 |
+
print(pose_vis_path)
|
225 |
+
fig.savefig(pose_vis_path)
|
226 |
+
|
227 |
+
# else:
|
228 |
+
|
229 |
+
fig = plt.figure()
|
230 |
+
fig.patch.set_facecolor('white') # 把背景设置为纯白色
|
231 |
+
traj_by_label = {
|
232 |
+
# "estimate (not aligned)": traj_est,
|
233 |
+
|
234 |
+
"Ours (aligned)": traj_est_aligned,
|
235 |
+
# "NoPe-NeRF (aligned)": traj_est_aligned,
|
236 |
+
# "CF-3DGS (aligned)": traj_est_aligned,
|
237 |
+
# "NeRFmm (aligned)": traj_est_aligned,
|
238 |
+
# args.method + " (aligned)": traj_est_aligned,
|
239 |
+
"COLMAP (GT)": traj_ref
|
240 |
+
# "Ground-truth": traj_ref
|
241 |
+
}
|
242 |
+
plot_mode = plot.PlotMode.xyz
|
243 |
+
# ax = plot.prepare_axis(fig, plot_mode, 111)
|
244 |
+
ax = fig.add_subplot(111, projection="3d")
|
245 |
+
ax.set_facecolor('white') # 把子图设置为纯白色
|
246 |
+
ax.xaxis.set_tick_params(labelbottom=True)
|
247 |
+
ax.yaxis.set_tick_params(labelleft=True)
|
248 |
+
ax.zaxis.set_tick_params(labelleft=True)
|
249 |
+
colors = ['#2c9e38', '#d12920'] #
|
250 |
+
# colors = ['#2c9e38', '#a72126'] #
|
251 |
+
|
252 |
+
# colors = ['r', 'b']
|
253 |
+
styles = ['-', '--']
|
254 |
+
|
255 |
+
for idx, (label, traj) in enumerate(traj_by_label.items()):
|
256 |
+
plot.traj(ax, plot_mode, traj,
|
257 |
+
styles[idx], colors[idx], label)
|
258 |
+
# break
|
259 |
+
# plot.trajectories(fig, traj_by_label, plot.PlotMode.xyz)
|
260 |
+
ax.view_init(elev=30., azim=45)
|
261 |
+
# ax.view_init(elev=10., azim=45)
|
262 |
+
plt.tight_layout()
|
263 |
+
pose_vis_path = output_path / f'pose_vis.png'
|
264 |
+
# pose_vis_path = os.path.join(os.path.dirname(output_path), f'pose_vis_{args.method}_{args.scene}.png')
|
265 |
+
fig.savefig(pose_vis_path)
|
266 |
+
|
267 |
+
# path_parts = args.pose_path.split('/')
|
268 |
+
# tmp_vis_path = '/'.join(path_parts[:-1]) + '/all_vis'
|
269 |
+
# tmp_vis_path2 = os.path.join(tmp_vis_path, f'pose_vis_{args.method}_{args.scene}.png')
|
270 |
+
# fig.savefig(tmp_vis_path2)
|
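Finally, a minimal, hypothetical invocation of plot_pose. The poses and output directory below are placeholders; evo and matplotlib must be installed, and output_path should be a pathlib.Path, since the figure is written to output_path / 'pose_vis.png':

    from pathlib import Path
    import numpy as np

    # ten placeholder ground-truth c2w poses along the x-axis; the "estimate" simply copies them
    gt_poses = []
    for i in range(10):
        T = np.eye(4)
        T[0, 3] = 0.1 * i
        gt_poses.append(T)
    est_poses = [T.copy() for T in gt_poses]

    out_dir = Path('./pose_plots')
    out_dir.mkdir(parents=True, exist_ok=True)
    plot_pose(gt_poses, est_poses, out_dir, args=None, vid=False)  # writes out_dir / 'pose_vis.png'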