faneggg commited on
Commit
123719b
·
1 Parent(s): d89af41
.gitignore ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /.idea/
2
+ /work_dirs*
3
+ .vscode/
4
+ /tmp
5
+ /data
6
+ # /checkpoints
7
+ *.so
8
+ *.patch
9
+ __pycache__/
10
+ *.egg-info/
11
+ /viz*
12
+ /submit*
13
+ build/
14
+ *.pyd
15
+ /cache*
16
+ *.stl
17
+ # *.pth
18
+ /venv/
19
+ .nk8s
20
+ *.mp4
21
+ .vs
22
+ /exp/
23
+ /dev/
README.md CHANGED
@@ -1,10 +1,11 @@
1
  ---
2
  title: Feat2GS
3
- emoji: 📈
4
- colorFrom: red
5
- colorTo: blue
6
  sdk: gradio
7
- sdk_version: 5.16.0
 
8
  app_file: app.py
9
  pinned: false
10
  license: apache-2.0
 
1
  ---
2
  title: Feat2GS
3
+ emoji:
4
+ colorFrom: yellow
5
+ colorTo: green
6
  sdk: gradio
7
+ sdk_version: 4.20.1
8
+ python_version: 3.10.13
9
  app_file: app.py
10
  pinned: false
11
  license: apache-2.0
arguments/__init__.py ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #
2
+ # Copyright (C) 2023, Inria
3
+ # GRAPHDECO research group, https://team.inria.fr/graphdeco
4
+ # All rights reserved.
5
+ #
6
+ # This software is free for non-commercial, research and evaluation use
7
+ # under the terms of the LICENSE.md file.
8
+ #
9
+ # For inquiries contact [email protected]
10
+ #
11
+
12
+ from argparse import ArgumentParser, Namespace
13
+ import sys
14
+ import os
15
+
16
+ class GroupParams:
17
+ pass
18
+
19
+ class ParamGroup:
20
+ def __init__(self, parser: ArgumentParser, name : str, fill_none = False):
21
+ group = parser.add_argument_group(name)
22
+ for key, value in vars(self).items():
23
+ shorthand = False
24
+ if key.startswith("_"):
25
+ shorthand = True
26
+ key = key[1:]
27
+ t = type(value)
28
+ value = value if not fill_none else None
29
+ if shorthand:
30
+ if t == bool:
31
+ group.add_argument("--" + key, ("-" + key[0:1]), default=value, action="store_true")
32
+ else:
33
+ group.add_argument("--" + key, ("-" + key[0:1]), default=value, type=t)
34
+ else:
35
+ if t == bool:
36
+ group.add_argument("--" + key, default=value, action="store_true")
37
+ else:
38
+ group.add_argument("--" + key, default=value, type=t)
39
+
40
+ def extract(self, args):
41
+ group = GroupParams()
42
+ for arg in vars(args).items():
43
+ if arg[0] in vars(self) or ("_" + arg[0]) in vars(self):
44
+ setattr(group, arg[0], arg[1])
45
+ return group
46
+
47
+ class ModelParams(ParamGroup):
48
+ def __init__(self, parser, sentinel=False):
49
+ self.sh_degree = 3
50
+ self._source_path = ""
51
+ self._model_path = ""
52
+ self._images = "images"
53
+ self._resolution = -1
54
+ self._white_background = False
55
+ self.data_device = "cuda"
56
+ self.eval = False
57
+ self.feat_default_dim = {
58
+ 'iuv': 3,
59
+ 'iuvrgb': 6,
60
+ 'mast3r': 1024,
61
+ 'dust3r': 1024,
62
+ 'dift': 1280,
63
+ 'dino_b16': 768,
64
+ 'dinov2_b14': 768,
65
+ 'radio': 1280,
66
+ 'clip_b16': 512,
67
+ 'mae_b16': 768,
68
+ 'midas_l16': 1024,
69
+ 'sam_base': 768,
70
+ # 'dino16': 384,
71
+ # 'dinov2': 384,
72
+ # 'clip': 512,
73
+ # 'maskclip': 512,
74
+ # 'vit': 384,
75
+ # 'resnet50': 2048,
76
+ # 'midas': 768,
77
+ # 'mae': 1024,
78
+ }
79
+ self.gs_params_group = {
80
+ 'G':{
81
+ 'head': ['xyz', 'scaling', 'rotation', 'opacity'],
82
+ 'opt':['f_dc', 'f_rest']
83
+ },
84
+ 'T':{
85
+ 'head': ['f_dc', 'f_rest'],
86
+ 'opt':['xyz', 'scaling', 'rotation', 'opacity']
87
+ },
88
+ 'A':{
89
+ 'head': ['xyz', 'scaling', 'rotation', 'opacity', 'f_dc', 'f_rest'],
90
+ 'opt':[]
91
+ },
92
+ 'Gft':{
93
+ 'head': ['xyz', 'scaling', 'rotation', 'opacity'],
94
+ 'opt':['f_dc', 'f_rest', 'pc_feat']
95
+ },
96
+ 'Tft':{
97
+ 'head': ['f_dc', 'f_rest'],
98
+ 'opt':['xyz', 'scaling', 'rotation', 'opacity', 'pc_feat']
99
+ },
100
+ 'Aft':{
101
+ 'head': ['xyz', 'scaling', 'rotation', 'opacity', 'f_dc', 'f_rest'],
102
+ 'opt':['pc_feat']
103
+ },
104
+ }
105
+ super().__init__(parser, "Loading Parameters", sentinel)
106
+
107
+ def extract(self, args):
108
+ g = super().extract(args)
109
+ g.source_path = os.path.abspath(g.source_path)
110
+ return g
111
+
112
+ class PipelineParams(ParamGroup):
113
+ def __init__(self, parser):
114
+ self.convert_SHs_python = False
115
+ self.compute_cov3D_python = False
116
+ self.debug = False
117
+ super().__init__(parser, "Pipeline Parameters")
118
+
119
+ class DefualtOptimizationParams(ParamGroup):
120
+ def __init__(self, parser):
121
+ self.lr_multiplier = 1.
122
+ self.iterations = 30_000
123
+ self.position_lr_init = 0.00016 * self.lr_multiplier
124
+ self.position_lr_final = 0.0000016 * self.lr_multiplier
125
+ self.position_lr_delay_mult = 0.01
126
+ self.position_lr_max_steps = 30_000
127
+ self.feature_lr = 0.0025 * self.lr_multiplier
128
+ self.opacity_lr = 0.05 * self.lr_multiplier
129
+ self.scaling_lr = 0.005 * self.lr_multiplier
130
+ self.rotation_lr = 0.001 * self.lr_multiplier
131
+ self.percent_dense = 0.01
132
+ self.lambda_dssim = 0.2
133
+ self.densification_interval = 100
134
+ self.opacity_reset_interval = 3000
135
+ self.densify_from_iter = 500
136
+ self.densify_until_iter = 15_000
137
+ self.densify_grad_threshold = 0.0002
138
+ self.random_background = False
139
+ super().__init__(parser, "Optimization Parameters")
140
+
141
+
142
+ class OptimizationParams(ParamGroup):
143
+ def __init__(self, parser):
144
+ self.lr_multiplier = 0.1
145
+ self.iterations = 30_000
146
+ self.pose_lr_init = 0.0001 #0.0001
147
+ self.pose_lr_final = 0.000001 #0.0001
148
+ self.position_lr_init = 0.00016 * self.lr_multiplier #0.000001
149
+ self.position_lr_final = 0.0000016 * self.lr_multiplier #0.000001
150
+ self.position_lr_delay_mult = 0.01
151
+ self.position_lr_max_steps = 30_000
152
+ self.feature_lr = 0.0025 * self.lr_multiplier #0.001
153
+ self.feature_sh_lr = (0.0025/20.) * self.lr_multiplier #0.000001
154
+ self.opacity_lr = 0.05 * self.lr_multiplier #0.0001
155
+ self.scaling_lr = 0.005 * self.lr_multiplier # 0.001
156
+ self.rotation_lr = 0.001 * self.lr_multiplier # 0.00001
157
+ self.percent_dense = 0.01
158
+ self.lambda_dssim = 0.2
159
+ self.densification_interval = 100
160
+ self.opacity_reset_interval = 3000
161
+ # self.densify_from_iter = 500
162
+ # self.densify_until_iter = 15_000
163
+ # self.densify_grad_threshold = 0.0002
164
+ self.random_background = False
165
+ super().__init__(parser, "Optimization Parameters")
166
+
167
+ def get_combined_args(parser : ArgumentParser):
168
+ cmdlne_string = sys.argv[1:]
169
+ cfgfile_string = "Namespace()"
170
+ args_cmdline = parser.parse_args(cmdlne_string)
171
+
172
+ try:
173
+ cfgfilepath = os.path.join(args_cmdline.model_path, "cfg_args")
174
+ print("Looking for config file in", cfgfilepath)
175
+ with open(cfgfilepath) as cfg_file:
176
+ print("Config file found: {}".format(cfgfilepath))
177
+ cfgfile_string = cfg_file.read()
178
+ except TypeError:
179
+ print("Config file not found at")
180
+ pass
181
+ args_cfgfile = eval(cfgfile_string)
182
+
183
+ merged_dict = vars(args_cfgfile).copy()
184
+ for k,v in vars(args_cmdline).items():
185
+ if v != None:
186
+ merged_dict[k] = v
187
+ return Namespace(**merged_dict)
command ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ conda activate feat2gs
2
+ cd Feat2GS/
3
+
4
+ bash scripts/run_feat2gs_eval_parallel.sh
5
+ bash scripts/run_feat2gs_eval.sh
6
+ bash scripts/run_instantsplat_eval_parallel.sh
7
+ bash scripts/run_feat2gs_eval_dtu_parallel.sh
8
+
9
+ python video/generate_video.py
10
+
11
+ bash scripts/run_all_trajectories.sh
12
+ bash scripts/run_video_render.sh
13
+ bash scripts/run_video_render_instantsplat.sh
14
+ bash scripts/run_video_render_dtu.sh
15
+
16
+ tensorboard --logdir=/home/chenyue/output/Feat2gs/output/eval/ --port=7001
17
+
18
+ cd /home/chenyue/output/Feat2gs/output/eval/Tanks/Train/6_views/feat2gs-G/dust3r/
19
+ tensorboard --logdir_spec \
20
+ radio:radio,\
21
+ dust3r:dust3r,\
22
+ dino_b16:dino_b16,\
23
+ mast3r:mast3r,\
24
+ dift:dift,\
25
+ dinov2:dinov2_b14,\
26
+ clip:clip_b16,\
27
+ mae:mae_b16,\
28
+ midas:midas_l16,\
29
+ sam:sam_base,\
30
+ iuvrgb:iuvrgb \
31
+ --port 7002
32
+
33
+ CUDA_VISIBLE_DEVICES=7 gradio demo.py
gaussian_renderer/__init__.py ADDED
@@ -0,0 +1,245 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #
2
+ # Copyright (C) 2023, Inria
3
+ # GRAPHDECO research group, https://team.inria.fr/graphdeco
4
+ # All rights reserved.
5
+ #
6
+ # This software is free for non-commercial, research and evaluation use
7
+ # under the terms of the LICENSE.md file.
8
+ #
9
+ # For inquiries contact [email protected]
10
+ #
11
+
12
+ import torch
13
+ import math
14
+ from scene.gaussian_model import GaussianModel
15
+ from utils.pose_utils import get_camera_from_tensor, quadmultiply
16
+ from utils.graphics_utils import depth_to_normal
17
+
18
+
19
+ ### if use [diff-gaussian-rasterization](https://github.com/graphdeco-inria/diff-gaussian-rasterization)
20
+
21
+ # from diff_gaussian_rasterization import (
22
+ # GaussianRasterizationSettings,
23
+ # GaussianRasterizer,
24
+ # )
25
+ # from utils.sh_utils import eval_sh
26
+
27
+ # def render(
28
+ # viewpoint_camera,
29
+ # pc: GaussianModel,
30
+ # pipe,
31
+ # bg_color: torch.Tensor,
32
+ # scaling_modifier=1.0,
33
+ # override_color=None,
34
+ # camera_pose=None,
35
+ # ):
36
+ # """
37
+ # Render the scene.
38
+
39
+ # Background tensor (bg_color) must be on GPU!
40
+ # """
41
+
42
+ # # Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means
43
+ # screenspace_points = (
44
+ # torch.zeros_like(
45
+ # pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device="cuda"
46
+ # )
47
+ # + 0
48
+ # )
49
+ # try:
50
+ # screenspace_points.retain_grad()
51
+ # except:
52
+ # pass
53
+
54
+ # # Set up rasterization configuration
55
+ # tanfovx = math.tan(viewpoint_camera.FoVx * 0.5)
56
+ # tanfovy = math.tan(viewpoint_camera.FoVy * 0.5)
57
+
58
+ # # Set camera pose as identity. Then, we will transform the Gaussians around camera_pose
59
+ # w2c = torch.eye(4).cuda()
60
+ # projmatrix = (
61
+ # w2c.unsqueeze(0).bmm(viewpoint_camera.projection_matrix.unsqueeze(0))
62
+ # ).squeeze(0)
63
+ # camera_pos = w2c.inverse()[3, :3]
64
+ # raster_settings = GaussianRasterizationSettings(
65
+ # image_height=int(viewpoint_camera.image_height),
66
+ # image_width=int(viewpoint_camera.image_width),
67
+ # tanfovx=tanfovx,
68
+ # tanfovy=tanfovy,
69
+ # bg=bg_color,
70
+ # scale_modifier=scaling_modifier,
71
+ # # viewmatrix=viewpoint_camera.world_view_transform,
72
+ # # projmatrix=viewpoint_camera.full_proj_transform,
73
+ # viewmatrix=w2c,
74
+ # projmatrix=projmatrix,
75
+ # sh_degree=pc.active_sh_degree,
76
+ # # campos=viewpoint_camera.camera_center,
77
+ # campos=camera_pos,
78
+ # prefiltered=False,
79
+ # debug=pipe.debug,
80
+ # )
81
+
82
+ # rasterizer = GaussianRasterizer(raster_settings=raster_settings)
83
+
84
+ # # means3D = pc.get_xyz
85
+ # rel_w2c = get_camera_from_tensor(camera_pose)
86
+ # # Transform mean and rot of Gaussians to camera frame
87
+ # gaussians_xyz = pc._xyz.clone()
88
+ # gaussians_rot = pc._rotation.clone()
89
+
90
+ # xyz_ones = torch.ones(gaussians_xyz.shape[0], 1).cuda().float()
91
+ # xyz_homo = torch.cat((gaussians_xyz, xyz_ones), dim=1)
92
+ # gaussians_xyz_trans = (rel_w2c @ xyz_homo.T).T[:, :3]
93
+ # gaussians_rot_trans = quadmultiply(camera_pose[:4], gaussians_rot)
94
+ # means3D = gaussians_xyz_trans
95
+ # means2D = screenspace_points
96
+ # opacity = pc.get_opacity
97
+
98
+ # # If precomputed 3d covariance is provided, use it. If not, then it will be computed from
99
+ # # scaling / rotation by the rasterizer.
100
+ # scales = None
101
+ # rotations = None
102
+ # cov3D_precomp = None
103
+ # if pipe.compute_cov3D_python:
104
+ # cov3D_precomp = pc.get_covariance(scaling_modifier)
105
+ # else:
106
+ # scales = pc.get_scaling
107
+ # rotations = gaussians_rot_trans # pc.get_rotation
108
+
109
+ # # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors
110
+ # # from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer.
111
+ # shs = None
112
+ # colors_precomp = None
113
+ # if override_color is None:
114
+ # if pipe.convert_SHs_python:
115
+ # shs_view = pc.get_features.transpose(1, 2).view(
116
+ # -1, 3, (pc.max_sh_degree + 1) ** 2
117
+ # )
118
+ # dir_pp = pc.get_xyz - viewpoint_camera.camera_center.repeat(
119
+ # pc.get_features.shape[0], 1
120
+ # )
121
+ # dir_pp_normalized = dir_pp / dir_pp.norm(dim=1, keepdim=True)
122
+ # sh2rgb = eval_sh(pc.active_sh_degree, shs_view, dir_pp_normalized)
123
+ # colors_precomp = torch.clamp_min(sh2rgb + 0.5, 0.0)
124
+ # else:
125
+ # shs = pc.get_features
126
+ # else:
127
+ # colors_precomp = override_color
128
+
129
+ # # Rasterize visible Gaussians to image, obtain their radii (on screen).
130
+ # rendered_image, radii = rasterizer(
131
+ # means3D=means3D,
132
+ # means2D=means2D,
133
+ # shs=shs,
134
+ # colors_precomp=colors_precomp,
135
+ # opacities=opacity,
136
+ # scales=scales,
137
+ # rotations=rotations,
138
+ # cov3D_precomp=cov3D_precomp,
139
+ # )
140
+
141
+ # # Those Gaussians that were frustum culled or had a radius of 0 were not visible.
142
+ # # They will be excluded from value updates used in the splitting criteria.
143
+ # return {
144
+ # "render": rendered_image,
145
+ # "viewspace_points": screenspace_points,
146
+ # "visibility_filter": radii > 0,
147
+ # "radii": radii,
148
+ # }
149
+
150
+
151
+ ### if use [gsplat](https://github.com/nerfstudio-project/gsplat)
152
+
153
+ from gsplat import rasterization
154
+
155
+ def render_gsplat(
156
+ viewpoint_camera,
157
+ pc : GaussianModel,
158
+ pipe,
159
+ bg_color : torch.Tensor,
160
+ scaling_modifier = 1.0,
161
+ override_color = None,
162
+ camera_pose = None,
163
+ fov = None,
164
+ render_mode="RGB"):
165
+ """
166
+ Render the scene.
167
+
168
+ Background tensor (bg_color) must be on GPU!
169
+ """
170
+ if fov is None:
171
+ FoVx = viewpoint_camera.FoVx
172
+ FoVy = viewpoint_camera.FoVy
173
+ else:
174
+ FoVx = fov[0]
175
+ FoVy = fov[1]
176
+ tanfovx = math.tan(FoVx * 0.5)
177
+ tanfovy = math.tan(FoVy * 0.5)
178
+ focal_length_x = viewpoint_camera.image_width / (2 * tanfovx)
179
+ focal_length_y = viewpoint_camera.image_height / (2 * tanfovy)
180
+ K = torch.tensor(
181
+ [
182
+ [focal_length_x, 0, viewpoint_camera.image_width / 2.0],
183
+ [0, focal_length_y, viewpoint_camera.image_height / 2.0],
184
+ [0, 0, 1],
185
+ ],
186
+ device="cuda",
187
+ )
188
+
189
+ means3D = pc.get_xyz
190
+ opacity = pc.get_opacity
191
+ scales = pc.get_scaling * scaling_modifier
192
+ rotations = pc.get_rotation
193
+ if override_color is not None:
194
+ colors = override_color # [N, 3]
195
+ sh_degree = None
196
+ else:
197
+ colors = pc.get_features # [N, K, 3]
198
+ sh_degree = pc.active_sh_degree
199
+
200
+ if camera_pose is None:
201
+ viewmat = viewpoint_camera.world_view_transform.transpose(0, 1) # [4, 4]
202
+ else:
203
+ viewmat = get_camera_from_tensor(camera_pose)
204
+ render_colors, render_alphas, info = rasterization(
205
+ means=means3D, # [N, 3]
206
+ quats=rotations, # [N, 4]
207
+ scales=scales, # [N, 3]
208
+ opacities=opacity.squeeze(-1), # [N,]
209
+ colors=colors,
210
+ viewmats=viewmat[None], # [1, 4, 4]
211
+ Ks=K[None], # [1, 3, 3]
212
+ backgrounds=bg_color[None],
213
+ width=int(viewpoint_camera.image_width),
214
+ height=int(viewpoint_camera.image_height),
215
+ packed=False,
216
+ sh_degree=sh_degree,
217
+ render_mode=render_mode,
218
+ )
219
+
220
+ if "D" in render_mode:
221
+ if "+" in render_mode:
222
+ depth_map = render_colors[..., -1:]
223
+ else:
224
+ depth_map = render_colors
225
+
226
+ normals_surf = depth_to_normal(
227
+ depth_map, torch.inverse(viewmat[None]), K[None])
228
+ normals_surf = normals_surf * (render_alphas).detach()
229
+ render_colors = torch.cat([render_colors, normals_surf], dim=-1)
230
+
231
+ # [1, H, W, 3] -> [3, H, W]
232
+ rendered_image = render_colors[0].permute(2, 0, 1)
233
+
234
+ radii = info["radii"].squeeze(0) # [N,]
235
+ try:
236
+ info["means2d"].retain_grad() # [1, N, 2]
237
+ except:
238
+ pass
239
+
240
+ # Those Gaussians that were frustum culled or had a radius of 0 were not visible.
241
+ # They will be excluded from value updates used in the splitting criteria.
242
+ return {"render": rendered_image,
243
+ "viewspace_points": info["means2d"],
244
+ "visibility_filter" : radii > 0,
245
+ "radii": radii}
gaussian_renderer/network_gui.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #
2
+ # Copyright (C) 2023, Inria
3
+ # GRAPHDECO research group, https://team.inria.fr/graphdeco
4
+ # All rights reserved.
5
+ #
6
+ # This software is free for non-commercial, research and evaluation use
7
+ # under the terms of the LICENSE.md file.
8
+ #
9
+ # For inquiries contact [email protected]
10
+ #
11
+
12
+ import torch
13
+ import traceback
14
+ import socket
15
+ import json
16
+ from scene.cameras import MiniCam
17
+
18
+ host = "127.0.0.1"
19
+ port = 6009
20
+
21
+ conn = None
22
+ addr = None
23
+
24
+ listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
25
+
26
+ def init(wish_host, wish_port):
27
+ global host, port, listener
28
+ host = wish_host
29
+ port = wish_port
30
+ listener.bind((host, port))
31
+ listener.listen()
32
+ listener.settimeout(0)
33
+
34
+ def try_connect():
35
+ global conn, addr, listener
36
+ try:
37
+ conn, addr = listener.accept()
38
+ print(f"\nConnected by {addr}")
39
+ conn.settimeout(None)
40
+ except Exception as inst:
41
+ pass
42
+
43
+ def read():
44
+ global conn
45
+ messageLength = conn.recv(4)
46
+ messageLength = int.from_bytes(messageLength, 'little')
47
+ message = conn.recv(messageLength)
48
+ return json.loads(message.decode("utf-8"))
49
+
50
+ def send(message_bytes, verify):
51
+ global conn
52
+ if message_bytes != None:
53
+ conn.sendall(message_bytes)
54
+ conn.sendall(len(verify).to_bytes(4, 'little'))
55
+ conn.sendall(bytes(verify, 'ascii'))
56
+
57
+ def receive():
58
+ message = read()
59
+
60
+ width = message["resolution_x"]
61
+ height = message["resolution_y"]
62
+
63
+ if width != 0 and height != 0:
64
+ try:
65
+ do_training = bool(message["train"])
66
+ fovy = message["fov_y"]
67
+ fovx = message["fov_x"]
68
+ znear = message["z_near"]
69
+ zfar = message["z_far"]
70
+ do_shs_python = bool(message["shs_python"])
71
+ do_rot_scale_python = bool(message["rot_scale_python"])
72
+ keep_alive = bool(message["keep_alive"])
73
+ scaling_modifier = message["scaling_modifier"]
74
+ world_view_transform = torch.reshape(torch.tensor(message["view_matrix"]), (4, 4)).cuda()
75
+ world_view_transform[:,1] = -world_view_transform[:,1]
76
+ world_view_transform[:,2] = -world_view_transform[:,2]
77
+ full_proj_transform = torch.reshape(torch.tensor(message["view_projection_matrix"]), (4, 4)).cuda()
78
+ full_proj_transform[:,1] = -full_proj_transform[:,1]
79
+ custom_cam = MiniCam(width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform)
80
+ except Exception as e:
81
+ print("")
82
+ traceback.print_exc()
83
+ raise e
84
+ return custom_cam, do_training, do_shs_python, do_rot_scale_python, keep_alive, scaling_modifier
85
+ else:
86
+ return None, None, None, None, None, None
lpipsPyTorch/__init__.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from .modules.lpips import LPIPS
4
+
5
+
6
+ def lpips(x: torch.Tensor,
7
+ y: torch.Tensor,
8
+ net_type: str = 'alex',
9
+ version: str = '0.1',
10
+ return_spatial_map=False):
11
+ r"""Function that measures
12
+ Learned Perceptual Image Patch Similarity (LPIPS).
13
+
14
+ Arguments:
15
+ x, y (torch.Tensor): the input tensors to compare.
16
+ net_type (str): the network type to compare the features:
17
+ 'alex' | 'squeeze' | 'vgg'. Default: 'alex'.
18
+ version (str): the version of LPIPS. Default: 0.1.
19
+ return_spatial_map (bool): whether to return the spatial map. Default: False.
20
+ """
21
+ device = x.device
22
+ criterion = LPIPS(net_type, version).to(device)
23
+ return criterion(x, y, return_spatial_map)
lpipsPyTorch/modules/lpips.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from .networks import get_network, LinLayers
5
+ from .utils import get_state_dict
6
+
7
+
8
+ class LPIPS(nn.Module):
9
+ r"""Creates a criterion that measures
10
+ Learned Perceptual Image Patch Similarity (LPIPS).
11
+
12
+ Arguments:
13
+ net_type (str): the network type to compare the features:
14
+ 'alex' | 'squeeze' | 'vgg'. Default: 'alex'.
15
+ version (str): the version of LPIPS. Default: 0.1.
16
+ """
17
+ def __init__(self, net_type: str = 'alex', version: str = '0.1'):
18
+
19
+ assert version in ['0.1'], 'v0.1 is only supported now'
20
+
21
+ super(LPIPS, self).__init__()
22
+
23
+ # pretrained network
24
+ self.net = get_network(net_type)
25
+
26
+ # linear layers
27
+ self.lin = LinLayers(self.net.n_channels_list)
28
+ self.lin.load_state_dict(get_state_dict(net_type, version))
29
+
30
+ def forward(self, x: torch.Tensor, y: torch.Tensor, return_spatial_map=False):
31
+ feat_x, feat_y = self.net(x), self.net(y)
32
+
33
+ diff = [(fx - fy) ** 2 for fx, fy in zip(feat_x, feat_y)]
34
+ res = [l(d) for d, l in zip(diff, self.lin)]
35
+
36
+ if return_spatial_map:
37
+ target_size = (x.shape[2], x.shape[3])
38
+ res_upsampled = [torch.nn.functional.interpolate(r, size=target_size, mode='bilinear', align_corners=False)
39
+ for r in res]
40
+ spatial_map = torch.sum(torch.cat(res_upsampled, 1), 1, keepdim=True)
41
+ return spatial_map
42
+ else:
43
+ res = [r.mean((2, 3), True) for r in res]
44
+ return torch.sum(torch.cat(res, 0), 0, True)
lpipsPyTorch/modules/networks.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Sequence
2
+
3
+ from itertools import chain
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ from torchvision import models
8
+
9
+ from .utils import normalize_activation
10
+
11
+
12
+ def get_network(net_type: str):
13
+ if net_type == 'alex':
14
+ return AlexNet()
15
+ elif net_type == 'squeeze':
16
+ return SqueezeNet()
17
+ elif net_type == 'vgg':
18
+ return VGG16()
19
+ else:
20
+ raise NotImplementedError('choose net_type from [alex, squeeze, vgg].')
21
+
22
+
23
+ class LinLayers(nn.ModuleList):
24
+ def __init__(self, n_channels_list: Sequence[int]):
25
+ super(LinLayers, self).__init__([
26
+ nn.Sequential(
27
+ nn.Identity(),
28
+ nn.Conv2d(nc, 1, 1, 1, 0, bias=False)
29
+ ) for nc in n_channels_list
30
+ ])
31
+
32
+ for param in self.parameters():
33
+ param.requires_grad = False
34
+
35
+
36
+ class BaseNet(nn.Module):
37
+ def __init__(self):
38
+ super(BaseNet, self).__init__()
39
+
40
+ # register buffer
41
+ self.register_buffer(
42
+ 'mean', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
43
+ self.register_buffer(
44
+ 'std', torch.Tensor([.458, .448, .450])[None, :, None, None])
45
+
46
+ def set_requires_grad(self, state: bool):
47
+ for param in chain(self.parameters(), self.buffers()):
48
+ param.requires_grad = state
49
+
50
+ def z_score(self, x: torch.Tensor):
51
+ return (x - self.mean) / self.std
52
+
53
+ def forward(self, x: torch.Tensor):
54
+ x = self.z_score(x)
55
+
56
+ output = []
57
+ for i, (_, layer) in enumerate(self.layers._modules.items(), 1):
58
+ x = layer(x)
59
+ if i in self.target_layers:
60
+ output.append(normalize_activation(x))
61
+ if len(output) == len(self.target_layers):
62
+ break
63
+ return output
64
+
65
+
66
+ class SqueezeNet(BaseNet):
67
+ def __init__(self):
68
+ super(SqueezeNet, self).__init__()
69
+
70
+ self.layers = models.squeezenet1_1(True).features
71
+ self.target_layers = [2, 5, 8, 10, 11, 12, 13]
72
+ self.n_channels_list = [64, 128, 256, 384, 384, 512, 512]
73
+
74
+ self.set_requires_grad(False)
75
+
76
+
77
+ class AlexNet(BaseNet):
78
+ def __init__(self):
79
+ super(AlexNet, self).__init__()
80
+
81
+ self.layers = models.alexnet(True).features
82
+ self.target_layers = [2, 5, 8, 10, 12]
83
+ self.n_channels_list = [64, 192, 384, 256, 256]
84
+
85
+ self.set_requires_grad(False)
86
+
87
+
88
+ class VGG16(BaseNet):
89
+ def __init__(self):
90
+ super(VGG16, self).__init__()
91
+
92
+ self.layers = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1).features
93
+ self.target_layers = [4, 9, 16, 23, 30]
94
+ self.n_channels_list = [64, 128, 256, 512, 512]
95
+
96
+ self.set_requires_grad(False)
lpipsPyTorch/modules/utils.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import OrderedDict
2
+
3
+ import torch
4
+
5
+
6
+ def normalize_activation(x, eps=1e-10):
7
+ norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True))
8
+ return x / (norm_factor + eps)
9
+
10
+
11
+ def get_state_dict(net_type: str = 'alex', version: str = '0.1'):
12
+ # build url
13
+ url = 'https://raw.githubusercontent.com/richzhang/PerceptualSimilarity/' \
14
+ + f'master/lpips/weights/v{version}/{net_type}.pth'
15
+
16
+ # download
17
+ old_state_dict = torch.hub.load_state_dict_from_url(
18
+ url, progress=True,
19
+ map_location=None if torch.cuda.is_available() else torch.device('cpu')
20
+ )
21
+
22
+ # rename keys
23
+ new_state_dict = OrderedDict()
24
+ for key, val in old_state_dict.items():
25
+ new_key = key
26
+ new_key = new_key.replace('lin', '')
27
+ new_key = new_key.replace('model.', '')
28
+ new_state_dict[new_key] = val
29
+
30
+ return new_state_dict
requirements.txt ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch==2.5.1
2
+ torchvision==0.20.1
3
+ roma
4
+ evo
5
+ gradio>=4,<5
6
+ matplotlib
7
+ tqdm
8
+ opencv-python
9
+ scipy
10
+ einops
11
+ trimesh
12
+ tensorboard
13
+ pyglet<2
14
+ imageio
15
+ gsplat
16
+ scikit-learn
17
+ hydra-submitit-launcher
18
+ huggingface-hub[torch]==0.24
19
+ plyfile
20
+ imageio[ffmpeg]
21
+ spaces
run_video.py ADDED
@@ -0,0 +1,275 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #
2
+ # Copyright (C) 2023, Inria
3
+ # GRAPHDECO research group, https://team.inria.fr/graphdeco
4
+ # All rights reserved.
5
+ #
6
+ # This software is free for non-commercial, research and evaluation use
7
+ # under the terms of the LICENSE.md file.
8
+ #
9
+ # For inquiries contact [email protected]
10
+ #
11
+
12
+ import matplotlib
13
+ matplotlib.use('Agg')
14
+
15
+ import math
16
+ import copy
17
+ import torch
18
+ from scene import Scene
19
+ import os
20
+ from tqdm import tqdm
21
+ from gaussian_renderer import render_gsplat
22
+ from argparse import ArgumentParser
23
+ from arguments import ModelParams, PipelineParams, get_combined_args
24
+ from gaussian_renderer import GaussianModel
25
+ from utils.pose_utils import get_tensor_from_camera
26
+ import numpy as np
27
+ import imageio.v3 as iio
28
+ from utils.graphics_utils import resize_render, make_video_divisble
29
+
30
+ from utils.trajectories import (
31
+ get_arc_w2cs,
32
+ get_avg_w2c,
33
+ get_lemniscate_w2cs,
34
+ get_spiral_w2cs,
35
+ get_wander_w2cs,
36
+ get_lookat,
37
+ )
38
+
39
+ from utils.camera_utils import generate_interpolated_path, generate_ellipse_path
40
+ from utils.camera_traj_config import trajectory_configs
41
+
42
+
43
+ def save_interpolated_pose(model_path, iter, n_views):
44
+
45
+ org_pose = np.load(model_path + f"pose/pose_{iter}.npy")
46
+ # visualizer(org_pose, ["green" for _ in org_pose], model_path + "pose/poses_optimized.png")
47
+ n_interp = int(10 * 30 / n_views) # 10second, fps=30
48
+ all_inter_pose = []
49
+ for i in range(n_views-1):
50
+ tmp_inter_pose = generate_interpolated_path(poses=org_pose[i:i+2], n_interp=n_interp)
51
+ all_inter_pose.append(tmp_inter_pose)
52
+ all_inter_pose = np.array(all_inter_pose).reshape(-1, 3, 4)
53
+
54
+ inter_pose_list = []
55
+ for p in all_inter_pose:
56
+ tmp_view = np.eye(4)
57
+ tmp_view[:3, :3] = p[:3, :3]
58
+ tmp_view[:3, 3] = p[:3, 3]
59
+ inter_pose_list.append(tmp_view)
60
+ inter_pose = np.stack(inter_pose_list, 0)
61
+ return inter_pose
62
+
63
+
64
+ def save_ellipse_pose(model_path, iter, n_views):
65
+
66
+ org_pose = np.load(model_path + f"pose/pose_{iter}.npy")
67
+ # visualizer(org_pose, ["green" for _ in org_pose], model_path + "pose/poses_optimized.png")
68
+ n_interp = int(10 * 30 / n_views) * (n_views-1) # 10second, fps=30
69
+ all_inter_pose = generate_ellipse_path(org_pose, n_interp)
70
+
71
+ inter_pose_list = []
72
+ for p in all_inter_pose:
73
+ c2w = np.eye(4)
74
+ c2w[:3, :4] = p
75
+ inter_pose_list.append(np.linalg.inv(c2w))
76
+ inter_pose = np.stack(inter_pose_list, 0)
77
+
78
+ return inter_pose
79
+
80
+ def save_traj_pose(dataset, iter, args):
81
+
82
+ traj_up = trajectory_configs.get(args.dataset, {}).get(args.scene, {}).get('up', [-1, 1]) # Use -y axis in camera space as up vector
83
+ traj_params = trajectory_configs.get(args.dataset, {}).get(args.scene, {}).get(args.cam_traj, {})
84
+
85
+ # 1. Get training camera poses and calculate trajectory
86
+ org_pose = np.load(dataset.model_path + f"pose/pose_{iter}.npy")
87
+ train_w2cs = torch.from_numpy(org_pose).cuda()
88
+
89
+ # Calculate reference camera pose
90
+ avg_w2c = get_avg_w2c(train_w2cs)
91
+ train_c2ws = torch.linalg.inv(train_w2cs)
92
+ lookat = get_lookat(train_c2ws[:, :3, -1], train_c2ws[:, :3, 2])
93
+ # up = torch.tensor([0.0, 0.0, 1.0], device="cuda")
94
+ avg_c2w = torch.linalg.inv(avg_w2c)
95
+ up = traj_up[0] * (avg_c2w[:3, traj_up[1]])
96
+ # up = traj_up[0] * (avg_c2w[:3, 0]+avg_c2w[:3, 1])/2
97
+
98
+ # Temporarily load a camera to get intrinsic parameters
99
+ tmp_args = copy.deepcopy(args)
100
+ tmp_args.get_video = False
101
+ tmp_dataset = copy.deepcopy(dataset)
102
+ tmp_dataset.eval = False
103
+ with torch.no_grad():
104
+ temp_gaussians = GaussianModel(dataset.sh_degree)
105
+ temp_scene = Scene(tmp_dataset, temp_gaussians, load_iteration=iter, opt=tmp_args, shuffle=False)
106
+
107
+ view = temp_scene.getTrainCameras()[0]
108
+ tanfovx = math.tan(view.FoVx * 0.5)
109
+ tanfovy = math.tan(view.FoVy * 0.5)
110
+ focal_length_x = view.image_width / (2 * tanfovx)
111
+ focal_length_y = view.image_height / (2 * tanfovy)
112
+
113
+ K = torch.tensor([[focal_length_x, 0, view.image_width/2],
114
+ [0, focal_length_y, view.image_height/2],
115
+ [0, 0, 1]], device="cuda")
116
+ img_wh = (view.image_width, view.image_height)
117
+
118
+ del temp_scene # Release temporary scene
119
+ del temp_gaussians # Release temporary gaussians
120
+
121
+ # Calculate bounding sphere radius
122
+ rc_train_c2ws = torch.einsum("ij,njk->nik", torch.linalg.inv(avg_w2c), train_c2ws)
123
+ rc_pos = rc_train_c2ws[:, :3, -1]
124
+ rads = (rc_pos.amax(0) - rc_pos.amin(0)) * 1.25
125
+
126
+ num_frames = int(10 * 30 / args.n_views) * (args.n_views-1)
127
+
128
+ # Generate camera poses based on trajectory type
129
+ if args.cam_traj == 'arc':
130
+ w2cs = get_arc_w2cs(
131
+ ref_w2c=avg_w2c,
132
+ lookat=lookat,
133
+ up=up,
134
+ focal_length=K[0, 0].item(),
135
+ rads=rads,
136
+ num_frames=num_frames,
137
+ degree=traj_params.get('degree', 180.0)
138
+ )
139
+ elif args.cam_traj == 'spiral':
140
+ w2cs = get_spiral_w2cs(
141
+ ref_w2c=avg_w2c,
142
+ lookat=lookat,
143
+ up=up,
144
+ focal_length=K[0, 0].item(),
145
+ rads=rads,
146
+ num_frames=num_frames,
147
+ zrate=traj_params.get('zrate', 0.5),
148
+ rots=traj_params.get('rots', 1)
149
+ )
150
+ elif args.cam_traj == 'lemniscate':
151
+ w2cs = get_lemniscate_w2cs(
152
+ ref_w2c=avg_w2c,
153
+ lookat=lookat,
154
+ up=up,
155
+ focal_length=K[0, 0].item(),
156
+ rads=rads,
157
+ num_frames=num_frames,
158
+ degree=traj_params.get('degree', 45.0)
159
+ )
160
+ elif args.cam_traj == 'wander':
161
+ w2cs = get_wander_w2cs(
162
+ ref_w2c=avg_w2c,
163
+ focal_length=K[0, 0].item(),
164
+ num_frames=num_frames,
165
+ max_disp=traj_params.get('max_disp', 48.0)
166
+ )
167
+ else:
168
+ raise ValueError(f"Unknown camera trajectory: {args.cam_traj}")
169
+
170
+ return w2cs.cpu().numpy()
171
+
172
+ def render_sets(dataset: ModelParams, iteration: int, pipeline: PipelineParams, args):
173
+ if args.cam_traj in ['interpolated', 'ellipse']:
174
+ w2cs = globals().get(f'save_{args.cam_traj}_pose')(dataset.model_path, iteration, args.n_views)
175
+ else:
176
+ w2cs = save_traj_pose(dataset, iteration, args)
177
+
178
+ # visualizer(org_pose, ["green" for _ in org_pose], dataset.model_path + f"pose/poses_optimized.png")
179
+ # visualizer(w2cs, ["blue" for _ in w2cs], dataset.model_path + f"pose/poses_{args.cam_traj}.png")
180
+ np.save(dataset.model_path + f"pose/pose_{args.cam_traj}.npy", w2cs)
181
+
182
+ # 2. Load model and scene
183
+ with torch.no_grad():
184
+ gaussians = GaussianModel(dataset.sh_degree)
185
+ scene = Scene(dataset, gaussians, load_iteration=iteration, opt=args, shuffle=False)
186
+
187
+ # bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0]
188
+ bg_color = [1, 1, 1]
189
+ background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
190
+
191
+ # 3. Rendering
192
+ # render_path = os.path.join(dataset.model_path, args.cam_traj, f"ours_{iteration}", "renders")
193
+ # if os.path.exists(render_path):
194
+ # shutil.rmtree(render_path)
195
+ # makedirs(render_path, exist_ok=True)
196
+
197
+ video = []
198
+ for idx, w2c in enumerate(tqdm(w2cs, desc="Rendering progress")):
199
+ camera_pose = get_tensor_from_camera(w2c.transpose(0, 1))
200
+ view = scene.getTrainCameras()[0] # Use parameters from the first camera as template
201
+ if args.resize:
202
+ view = resize_render(view)
203
+
204
+ rendering = render_gsplat(
205
+ view, gaussians, pipeline, background, camera_pose=camera_pose
206
+ )["render"]
207
+
208
+ # # Save single frame image
209
+ # torchvision.utils.save_image(
210
+ # rendering, os.path.join(render_path, "{0:05d}".format(idx) + ".png")
211
+ # )
212
+
213
+ # Add to video list
214
+ # img = (rendering.detach().cpu().numpy() * 255.0).astype(np.uint8)
215
+ img = (torch.clamp(rendering, 0, 1).detach().cpu().numpy() * 255.0).round().astype(np.uint8)
216
+ video.append(img)
217
+
218
+ video = np.stack(video, 0).transpose(0, 2, 3, 1)
219
+
220
+ # Save video
221
+ if args.get_video:
222
+ video_dir = os.path.join(dataset.model_path, 'videos')
223
+ os.makedirs(video_dir, exist_ok=True)
224
+ output_video_file = os.path.join(video_dir, f'{args.scene}_{args.n_views}_view_{args.cam_traj}.mp4')
225
+ # iio.imwrite(output_video_file, make_video_divisble(video), fps=30)
226
+ iio.imwrite(
227
+ output_video_file,
228
+ make_video_divisble(video),
229
+ fps=30,
230
+ codec='libx264',
231
+ quality=None,
232
+ output_params=[
233
+ '-crf', '28', # Good quality range between 18-28
234
+ '-preset', 'veryslow',
235
+ '-pix_fmt', 'yuv420p',
236
+ '-movflags', '+faststart'
237
+ ]
238
+ )
239
+
240
+ # if args.get_video:
241
+ # image_folder = os.path.join(dataset.model_path, f'{args.cam_traj}/ours_{args.iteration}/renders')
242
+ # output_video_file = os.path.join(dataset.model_path, f'{args.scene}_{args.n_views}_view_{args.cam_traj}.mp4')
243
+ # images_to_video(image_folder, output_video_file, fps=30)
244
+
245
+
246
+ if __name__ == "__main__":
247
+ # Set up command line argument parser
248
+ parser = ArgumentParser(description="Testing script parameters")
249
+ model = ModelParams(parser, sentinel=True)
250
+ pipeline = PipelineParams(parser)
251
+ parser.add_argument("--iteration", default=-1, type=int)
252
+ parser.add_argument("--quiet", action="store_true")
253
+ parser.add_argument("--get_video", action="store_true")
254
+ parser.add_argument("--n_views", default=120, type=int)
255
+ parser.add_argument("--dataset", default=None, type=str)
256
+ parser.add_argument("--scene", default=None, type=str)
257
+ parser.add_argument("--cam_traj", default='arc', type=str,
258
+ choices=['arc', 'spiral', 'lemniscate', 'wander', 'interpolated', 'ellipse'],
259
+ help="Camera trajectory type")
260
+ parser.add_argument("--resize", action="store_true", default=True,
261
+ help="If True, resize rendering to square")
262
+ parser.add_argument("--feat_type", type=str, nargs='*', default=None,
263
+ help="Feature type(s). Multiple types can be specified for combination.")
264
+ parser.add_argument("--method", type=str, default='dust3r',
265
+ help="Method of Initialization, e.g., 'dust3r' or 'mast3r'")
266
+
267
+ args = get_combined_args(parser)
268
+ print("Rendering " + args.model_path)
269
+
270
+ render_sets(
271
+ model.extract(args),
272
+ args.iteration,
273
+ pipeline.extract(args),
274
+ args,
275
+ )
scene/__init__.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #
2
+ # Copyright (C) 2023, Inria
3
+ # GRAPHDECO research group, https://team.inria.fr/graphdeco
4
+ # All rights reserved.
5
+ #
6
+ # This software is free for non-commercial, research and evaluation use
7
+ # under the terms of the LICENSE.md file.
8
+ #
9
+ # For inquiries contact [email protected]
10
+ #
11
+
12
+ import os
13
+ import random
14
+ import json
15
+ from utils.system_utils import searchForMaxIteration
16
+ from scene.dataset_readers import sceneLoadTypeCallbacks
17
+ from scene.gaussian_model import GaussianModel, Feat2GaussianModel
18
+ from arguments import ModelParams
19
+ from utils.camera_utils import cameraList_from_camInfos, camera_to_JSON
20
+
21
+ class Scene:
22
+
23
+ gaussian : GaussianModel
24
+
25
+ def __init__(self, args : ModelParams, gaussian : GaussianModel, load_iteration=None, opt=None, shuffle=True, resolution_scales=[1.0]):
26
+ """b
27
+ :param path: Path to colmap scene main folder.
28
+ """
29
+ self.model_path = args.model_path
30
+ self.loaded_iter = None
31
+ self.gaussians = gaussian
32
+
33
+ if load_iteration:
34
+ if load_iteration == -1:
35
+ self.loaded_iter = searchForMaxIteration(os.path.join(self.model_path, "point_cloud"))
36
+ else:
37
+ self.loaded_iter = load_iteration
38
+ print("Loading trained model at iteration {}".format(self.loaded_iter))
39
+
40
+ self.train_cameras = {}
41
+ self.test_cameras = {}
42
+ # self.render_cameras = {}
43
+
44
+ if os.path.exists(os.path.join(args.source_path, "sparse")):
45
+ scene_info = sceneLoadTypeCallbacks["Colmap"](args.source_path, args.images, args.eval, args, opt)
46
+ elif os.path.exists(os.path.join(args.source_path, "transforms_train.json")):
47
+ print("Found transforms_train.json file, assuming Blender data set!")
48
+ scene_info = sceneLoadTypeCallbacks["Blender"](args.source_path, args.white_background, args.eval)
49
+ else:
50
+ assert False, "Could not recognize scene type!"
51
+
52
+ if not self.loaded_iter:
53
+ with open(scene_info.ply_path, 'rb') as src_file, open(os.path.join(self.model_path, "input.ply") , 'wb') as dest_file:
54
+ dest_file.write(src_file.read())
55
+ json_cams = []
56
+ camlist = []
57
+ if scene_info.test_cameras:
58
+ camlist.extend(scene_info.test_cameras)
59
+ if scene_info.train_cameras:
60
+ camlist.extend(scene_info.train_cameras)
61
+ # if scene_info.render_cameras:
62
+ # camlist.extend(scene_info.render_cameras)
63
+ for id, cam in enumerate(camlist):
64
+ json_cams.append(camera_to_JSON(id, cam))
65
+ with open(os.path.join(self.model_path, "cameras.json"), 'w') as file:
66
+ json.dump(json_cams, file)
67
+
68
+ if shuffle:
69
+ random.shuffle(scene_info.train_cameras) # Multi-res consistent random shuffling
70
+ random.shuffle(scene_info.test_cameras) # Multi-res consistent random shuffling
71
+
72
+ self.cameras_extent = scene_info.nerf_normalization["radius"]
73
+
74
+ for resolution_scale in resolution_scales:
75
+ print("Loading Training Cameras")
76
+ self.train_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.train_cameras, resolution_scale, args)
77
+ print('train_camera_num: ', len(self.train_cameras[resolution_scale]))
78
+ print("Loading Test Cameras")
79
+ self.test_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.test_cameras, resolution_scale, args)
80
+ print('test_camera_num: ', len(self.test_cameras[resolution_scale]))
81
+ # print("Loading Render Cameras")
82
+ # self.render_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.render_cameras, resolution_scale, args)
83
+ # print('render_camera_num: ', len(self.render_cameras[resolution_scale]))
84
+
85
+ if self.loaded_iter:
86
+ self.gaussians.load_ply(os.path.join(self.model_path,
87
+ "point_cloud",
88
+ "iteration_" + str(self.loaded_iter),
89
+ "point_cloud.ply"))
90
+ else:
91
+ self.gaussians.create_from_pcd(scene_info.point_cloud, self.cameras_extent)
92
+ self.gaussians.init_RT_seq(self.train_cameras)
93
+
94
+ def save(self, iteration):
95
+ point_cloud_path = os.path.join(self.model_path, "point_cloud/iteration_{}".format(iteration))
96
+ self.gaussians.save_ply(os.path.join(point_cloud_path, "point_cloud.ply"))
97
+
98
+ def getTrainCameras(self, scale=1.0):
99
+ return self.train_cameras[scale]
100
+
101
+ def getTestCameras(self, scale=1.0):
102
+ return self.test_cameras[scale]
103
+
104
+ # def getRenderCameras(self, scale=1.0):
105
+ # return self.render_cameras[scale]
scene/cameras.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #
2
+ # Copyright (C) 2023, Inria
3
+ # GRAPHDECO research group, https://team.inria.fr/graphdeco
4
+ # All rights reserved.
5
+ #
6
+ # This software is free for non-commercial, research and evaluation use
7
+ # under the terms of the LICENSE.md file.
8
+ #
9
+ # For inquiries contact [email protected]
10
+ #
11
+
12
+ import torch
13
+ from torch import nn
14
+ import numpy as np
15
+ from utils.graphics_utils import getWorld2View2, getProjectionMatrix
16
+
17
+ class Camera(nn.Module):
18
+ def __init__(self, colmap_id, R, T, FoVx, FoVy, image, gt_alpha_mask,
19
+ image_name, uid,
20
+ trans=np.array([0.0, 0.0, 0.0]), scale=1.0, data_device = "cuda"
21
+ ):
22
+ super(Camera, self).__init__()
23
+
24
+ self.uid = uid
25
+ self.colmap_id = colmap_id
26
+ self.R = R
27
+ self.T = T
28
+ self.FoVx = FoVx
29
+ self.FoVy = FoVy
30
+ self.image_name = image_name
31
+
32
+ try:
33
+ self.data_device = torch.device(data_device)
34
+ except Exception as e:
35
+ print(e)
36
+ print(f"[Warning] Custom device {data_device} failed, fallback to default cuda device" )
37
+ self.data_device = torch.device("cuda")
38
+
39
+ self.original_image = image.clamp(0.0, 1.0).to(self.data_device)
40
+ self.image_width = self.original_image.shape[2]
41
+ self.image_height = self.original_image.shape[1]
42
+
43
+ if gt_alpha_mask is not None:
44
+ self.original_image *= gt_alpha_mask.to(self.data_device)
45
+ else:
46
+ self.original_image *= torch.ones((1, self.image_height, self.image_width), device=self.data_device)
47
+
48
+ self.zfar = 100.0
49
+ self.znear = 0.01
50
+
51
+ self.trans = trans
52
+ self.scale = scale
53
+
54
+ self.world_view_transform = torch.tensor(getWorld2View2(R, T, trans, scale)).transpose(0, 1).cuda()
55
+ self.projection_matrix = getProjectionMatrix(znear=self.znear, zfar=self.zfar, fovX=self.FoVx, fovY=self.FoVy).transpose(0,1).cuda()
56
+ self.full_proj_transform = (self.world_view_transform.unsqueeze(0).bmm(self.projection_matrix.unsqueeze(0))).squeeze(0)
57
+ self.camera_center = self.world_view_transform.inverse()[3, :3]
58
+
59
+ class MiniCam:
60
+ def __init__(self, width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform):
61
+ self.image_width = width
62
+ self.image_height = height
63
+ self.FoVy = fovy
64
+ self.FoVx = fovx
65
+ self.znear = znear
66
+ self.zfar = zfar
67
+ self.world_view_transform = world_view_transform
68
+ self.full_proj_transform = full_proj_transform
69
+ view_inv = torch.inverse(self.world_view_transform)
70
+ self.camera_center = view_inv[3][:3]
71
+
scene/colmap_loader.py ADDED
@@ -0,0 +1,294 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #
2
+ # Copyright (C) 2023, Inria
3
+ # GRAPHDECO research group, https://team.inria.fr/graphdeco
4
+ # All rights reserved.
5
+ #
6
+ # This software is free for non-commercial, research and evaluation use
7
+ # under the terms of the LICENSE.md file.
8
+ #
9
+ # For inquiries contact [email protected]
10
+ #
11
+
12
+ import numpy as np
13
+ import collections
14
+ import struct
15
+
16
+ CameraModel = collections.namedtuple(
17
+ "CameraModel", ["model_id", "model_name", "num_params"])
18
+ Camera = collections.namedtuple(
19
+ "Camera", ["id", "model", "width", "height", "params"])
20
+ BaseImage = collections.namedtuple(
21
+ "Image", ["id", "qvec", "tvec", "camera_id", "name", "xys", "point3D_ids"])
22
+ Point3D = collections.namedtuple(
23
+ "Point3D", ["id", "xyz", "rgb", "error", "image_ids", "point2D_idxs"])
24
+ CAMERA_MODELS = {
25
+ CameraModel(model_id=0, model_name="SIMPLE_PINHOLE", num_params=3),
26
+ CameraModel(model_id=1, model_name="PINHOLE", num_params=4),
27
+ CameraModel(model_id=2, model_name="SIMPLE_RADIAL", num_params=4),
28
+ CameraModel(model_id=3, model_name="RADIAL", num_params=5),
29
+ CameraModel(model_id=4, model_name="OPENCV", num_params=8),
30
+ CameraModel(model_id=5, model_name="OPENCV_FISHEYE", num_params=8),
31
+ CameraModel(model_id=6, model_name="FULL_OPENCV", num_params=12),
32
+ CameraModel(model_id=7, model_name="FOV", num_params=5),
33
+ CameraModel(model_id=8, model_name="SIMPLE_RADIAL_FISHEYE", num_params=4),
34
+ CameraModel(model_id=9, model_name="RADIAL_FISHEYE", num_params=5),
35
+ CameraModel(model_id=10, model_name="THIN_PRISM_FISHEYE", num_params=12)
36
+ }
37
+ CAMERA_MODEL_IDS = dict([(camera_model.model_id, camera_model)
38
+ for camera_model in CAMERA_MODELS])
39
+ CAMERA_MODEL_NAMES = dict([(camera_model.model_name, camera_model)
40
+ for camera_model in CAMERA_MODELS])
41
+
42
+
43
+ def qvec2rotmat(qvec):
44
+ return np.array([
45
+ [1 - 2 * qvec[2]**2 - 2 * qvec[3]**2,
46
+ 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3],
47
+ 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2]],
48
+ [2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3],
49
+ 1 - 2 * qvec[1]**2 - 2 * qvec[3]**2,
50
+ 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1]],
51
+ [2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2],
52
+ 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1],
53
+ 1 - 2 * qvec[1]**2 - 2 * qvec[2]**2]])
54
+
55
+ def rotmat2qvec(R):
56
+ Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat
57
+ K = np.array([
58
+ [Rxx - Ryy - Rzz, 0, 0, 0],
59
+ [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0],
60
+ [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0],
61
+ [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz]]) / 3.0
62
+ eigvals, eigvecs = np.linalg.eigh(K)
63
+ qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)]
64
+ if qvec[0] < 0:
65
+ qvec *= -1
66
+ return qvec
67
+
68
+ class Image(BaseImage):
69
+ def qvec2rotmat(self):
70
+ return qvec2rotmat(self.qvec)
71
+
72
+ def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"):
73
+ """Read and unpack the next bytes from a binary file.
74
+ :param fid:
75
+ :param num_bytes: Sum of combination of {2, 4, 8}, e.g. 2, 6, 16, 30, etc.
76
+ :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}.
77
+ :param endian_character: Any of {@, =, <, >, !}
78
+ :return: Tuple of read and unpacked values.
79
+ """
80
+ data = fid.read(num_bytes)
81
+ return struct.unpack(endian_character + format_char_sequence, data)
82
+
83
+ def read_points3D_text(path):
84
+ """
85
+ see: src/base/reconstruction.cc
86
+ void Reconstruction::ReadPoints3DText(const std::string& path)
87
+ void Reconstruction::WritePoints3DText(const std::string& path)
88
+ """
89
+ xyzs = None
90
+ rgbs = None
91
+ errors = None
92
+ num_points = 0
93
+ with open(path, "r") as fid:
94
+ while True:
95
+ line = fid.readline()
96
+ if not line:
97
+ break
98
+ line = line.strip()
99
+ if len(line) > 0 and line[0] != "#":
100
+ num_points += 1
101
+
102
+
103
+ xyzs = np.empty((num_points, 3))
104
+ rgbs = np.empty((num_points, 3))
105
+ errors = np.empty((num_points, 1))
106
+ count = 0
107
+ with open(path, "r") as fid:
108
+ while True:
109
+ line = fid.readline()
110
+ if not line:
111
+ break
112
+ line = line.strip()
113
+ if len(line) > 0 and line[0] != "#":
114
+ elems = line.split()
115
+ xyz = np.array(tuple(map(float, elems[1:4])))
116
+ rgb = np.array(tuple(map(int, elems[4:7])))
117
+ error = np.array(float(elems[7]))
118
+ xyzs[count] = xyz
119
+ rgbs[count] = rgb
120
+ errors[count] = error
121
+ count += 1
122
+
123
+ return xyzs, rgbs, errors
124
+
125
+ def read_points3D_binary(path_to_model_file):
126
+ """
127
+ see: src/base/reconstruction.cc
128
+ void Reconstruction::ReadPoints3DBinary(const std::string& path)
129
+ void Reconstruction::WritePoints3DBinary(const std::string& path)
130
+ """
131
+
132
+
133
+ with open(path_to_model_file, "rb") as fid:
134
+ num_points = read_next_bytes(fid, 8, "Q")[0]
135
+
136
+ xyzs = np.empty((num_points, 3))
137
+ rgbs = np.empty((num_points, 3))
138
+ errors = np.empty((num_points, 1))
139
+
140
+ for p_id in range(num_points):
141
+ binary_point_line_properties = read_next_bytes(
142
+ fid, num_bytes=43, format_char_sequence="QdddBBBd")
143
+ xyz = np.array(binary_point_line_properties[1:4])
144
+ rgb = np.array(binary_point_line_properties[4:7])
145
+ error = np.array(binary_point_line_properties[7])
146
+ track_length = read_next_bytes(
147
+ fid, num_bytes=8, format_char_sequence="Q")[0]
148
+ track_elems = read_next_bytes(
149
+ fid, num_bytes=8*track_length,
150
+ format_char_sequence="ii"*track_length)
151
+ xyzs[p_id] = xyz
152
+ rgbs[p_id] = rgb
153
+ errors[p_id] = error
154
+ return xyzs, rgbs, errors
155
+
156
+ def read_intrinsics_text(path):
157
+ """
158
+ Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_write_model.py
159
+ """
160
+ cameras = {}
161
+ with open(path, "r") as fid:
162
+ while True:
163
+ line = fid.readline()
164
+ if not line:
165
+ break
166
+ line = line.strip()
167
+ if len(line) > 0 and line[0] != "#":
168
+ elems = line.split()
169
+ camera_id = int(elems[0])
170
+ model = elems[1]
171
+ assert model == "PINHOLE", "While the loader support other types, the rest of the code assumes PINHOLE"
172
+ width = int(elems[2])
173
+ height = int(elems[3])
174
+ params = np.array(tuple(map(float, elems[4:])))
175
+ cameras[camera_id] = Camera(id=camera_id, model=model,
176
+ width=width, height=height,
177
+ params=params)
178
+ return cameras
179
+
180
+ def read_extrinsics_binary(path_to_model_file):
181
+ """
182
+ see: src/base/reconstruction.cc
183
+ void Reconstruction::ReadImagesBinary(const std::string& path)
184
+ void Reconstruction::WriteImagesBinary(const std::string& path)
185
+ """
186
+ images = {}
187
+ with open(path_to_model_file, "rb") as fid:
188
+ num_reg_images = read_next_bytes(fid, 8, "Q")[0]
189
+ for _ in range(num_reg_images):
190
+ binary_image_properties = read_next_bytes(
191
+ fid, num_bytes=64, format_char_sequence="idddddddi")
192
+ image_id = binary_image_properties[0]
193
+ qvec = np.array(binary_image_properties[1:5])
194
+ tvec = np.array(binary_image_properties[5:8])
195
+ camera_id = binary_image_properties[8]
196
+ image_name = ""
197
+ current_char = read_next_bytes(fid, 1, "c")[0]
198
+ while current_char != b"\x00": # look for the ASCII 0 entry
199
+ image_name += current_char.decode("utf-8")
200
+ current_char = read_next_bytes(fid, 1, "c")[0]
201
+ num_points2D = read_next_bytes(fid, num_bytes=8,
202
+ format_char_sequence="Q")[0]
203
+ x_y_id_s = read_next_bytes(fid, num_bytes=24*num_points2D,
204
+ format_char_sequence="ddq"*num_points2D)
205
+ xys = np.column_stack([tuple(map(float, x_y_id_s[0::3])),
206
+ tuple(map(float, x_y_id_s[1::3]))])
207
+ point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3])))
208
+ images[image_id] = Image(
209
+ id=image_id, qvec=qvec, tvec=tvec,
210
+ camera_id=camera_id, name=image_name,
211
+ xys=xys, point3D_ids=point3D_ids)
212
+ return images
213
+
214
+
215
+ def read_intrinsics_binary(path_to_model_file):
216
+ """
217
+ see: src/base/reconstruction.cc
218
+ void Reconstruction::WriteCamerasBinary(const std::string& path)
219
+ void Reconstruction::ReadCamerasBinary(const std::string& path)
220
+ """
221
+ cameras = {}
222
+ with open(path_to_model_file, "rb") as fid:
223
+ num_cameras = read_next_bytes(fid, 8, "Q")[0]
224
+ for _ in range(num_cameras):
225
+ camera_properties = read_next_bytes(
226
+ fid, num_bytes=24, format_char_sequence="iiQQ")
227
+ camera_id = camera_properties[0]
228
+ model_id = camera_properties[1]
229
+ model_name = CAMERA_MODEL_IDS[camera_properties[1]].model_name
230
+ width = camera_properties[2]
231
+ height = camera_properties[3]
232
+ num_params = CAMERA_MODEL_IDS[model_id].num_params
233
+ params = read_next_bytes(fid, num_bytes=8*num_params,
234
+ format_char_sequence="d"*num_params)
235
+ cameras[camera_id] = Camera(id=camera_id,
236
+ model=model_name,
237
+ width=width,
238
+ height=height,
239
+ params=np.array(params))
240
+ assert len(cameras) == num_cameras
241
+ return cameras
242
+
243
+
244
+ def read_extrinsics_text(path):
245
+ """
246
+ Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_write_model.py
247
+ """
248
+ images = {}
249
+ with open(path, "r") as fid:
250
+ while True:
251
+ line = fid.readline()
252
+ if not line:
253
+ break
254
+ line = line.strip()
255
+ if len(line) > 0 and line[0] != "#":
256
+ elems = line.split()
257
+ image_id = int(elems[0])
258
+ qvec = np.array(tuple(map(float, elems[1:5])))
259
+ tvec = np.array(tuple(map(float, elems[5:8])))
260
+ camera_id = int(elems[8])
261
+ image_name = elems[9]
262
+ elems = fid.readline().split()
263
+ xys = np.column_stack([tuple(map(float, elems[0::3])),
264
+ tuple(map(float, elems[1::3]))])
265
+ point3D_ids = np.array(tuple(map(int, elems[2::3])))
266
+ images[image_id] = Image(
267
+ id=image_id, qvec=qvec, tvec=tvec,
268
+ camera_id=camera_id, name=image_name,
269
+ xys=xys, point3D_ids=point3D_ids)
270
+ return images
271
+
272
+
273
+ def read_colmap_bin_array(path):
274
+ """
275
+ Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_dense.py
276
+
277
+ :param path: path to the colmap binary file.
278
+ :return: nd array with the floating point values in the value
279
+ """
280
+ with open(path, "rb") as fid:
281
+ width, height, channels = np.genfromtxt(fid, delimiter="&", max_rows=1,
282
+ usecols=(0, 1, 2), dtype=int)
283
+ fid.seek(0)
284
+ num_delimiter = 0
285
+ byte = fid.read(1)
286
+ while True:
287
+ if byte == b"&":
288
+ num_delimiter += 1
289
+ if num_delimiter >= 3:
290
+ break
291
+ byte = fid.read(1)
292
+ array = np.fromfile(fid, np.float32)
293
+ array = array.reshape((width, height, channels), order="F")
294
+ return np.transpose(array, (1, 0, 2)).squeeze()
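The loaders above mirror COLMAP's read_write_model.py / read_dense.py scripts: each binary image record is a fixed-size header (image id, quaternion, translation, camera id) followed by a null-terminated name and its 2D keypoints, which is exactly what the byte-level loop parses. A minimal usage sketch (illustrative only, not a file in this commit), assuming the usual sparse/0 layout and the loaders defined in scene/colmap_loader.py:

import os
import numpy as np
from scene.colmap_loader import (
    read_extrinsics_binary, read_intrinsics_binary, qvec2rotmat,
)

def load_sparse_model(model_dir):
    # Hypothetical helper: return (images, cameras) dicts keyed by COLMAP ids.
    images = read_extrinsics_binary(os.path.join(model_dir, "images.bin"))
    cameras = read_intrinsics_binary(os.path.join(model_dir, "cameras.bin"))
    return images, cameras

images, cameras = load_sparse_model("data/scene/sparse/0")   # assumed path
print(f"{len(images)} registered images, {len(cameras)} cameras")
first = next(iter(images.values()))
R = qvec2rotmat(first.qvec)            # 3x3 rotation recovered from the stored quaternion
print("det(R) =", np.linalg.det(R))    # ~1.0 for a valid rotation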
scene/dataset_readers.py ADDED
@@ -0,0 +1,382 @@
1
+ #
2
+ # Copyright (C) 2023, Inria
3
+ # GRAPHDECO research group, https://team.inria.fr/graphdeco
4
+ # All rights reserved.
5
+ #
6
+ # This software is free for non-commercial, research and evaluation use
7
+ # under the terms of the LICENSE.md file.
8
+ #
9
+ # For inquiries contact [email protected]
10
+ #
11
+
12
+ import os
13
+ import sys
14
+ from PIL import Image
15
+ from typing import NamedTuple
16
+ from scene.colmap_loader import read_extrinsics_text, read_intrinsics_text, qvec2rotmat, \
17
+ read_extrinsics_binary, read_intrinsics_binary, read_points3D_binary, read_points3D_text
18
+ from utils.graphics_utils import getWorld2View2, focal2fov, fov2focal
19
+ import numpy as np
20
+ import json
21
+ from pathlib import Path
22
+ from plyfile import PlyData, PlyElement
23
+ from utils.sh_utils import SH2RGB
24
+ from scene.gaussian_model import BasicPointCloud
25
+ # from utils.camera_utils import generate_ellipse_path_from_camera_infos
26
+
27
+ class CameraInfo(NamedTuple):
28
+ uid: int
29
+ R: np.array
30
+ T: np.array
31
+ FovY: np.array
32
+ FovX: np.array
33
+ image: np.array
34
+ image_path: str
35
+ image_name: str
36
+ width: int
37
+ height: int
38
+
39
+
40
+ class SceneInfo(NamedTuple):
41
+ point_cloud: BasicPointCloud
42
+ train_cameras: list
43
+ test_cameras: list
44
+ # render_cameras: list
45
+ nerf_normalization: dict
46
+ ply_path: str
47
+ train_poses: list
48
+ test_poses: list
49
+
50
+ def getNerfppNorm(cam_info):
51
+ def get_center_and_diag(cam_centers):
52
+ cam_centers = np.hstack(cam_centers)
53
+ avg_cam_center = np.mean(cam_centers, axis=1, keepdims=True)
54
+ center = avg_cam_center
55
+ dist = np.linalg.norm(cam_centers - center, axis=0, keepdims=True)
56
+ diagonal = np.max(dist)
57
+ return center.flatten(), diagonal
58
+
59
+ cam_centers = []
60
+
61
+ for cam in cam_info:
62
+ W2C = getWorld2View2(cam.R, cam.T)
63
+ C2W = np.linalg.inv(W2C)
64
+ cam_centers.append(C2W[:3, 3:4])
65
+
66
+ center, diagonal = get_center_and_diag(cam_centers)
67
+ radius = diagonal * 1.1
68
+
69
+ translate = -center
70
+
71
+ return {"translate": translate, "radius": radius}
72
+
73
+ def readColmapCameras(cam_extrinsics, cam_intrinsics, images_folder, eval):
74
+
75
+ cam_infos = []
76
+ poses=[]
77
+ for idx, key in enumerate(cam_extrinsics):
78
+ sys.stdout.write('\r')
79
+ # overwrite the same console line with reading progress:
80
+ sys.stdout.write("Reading camera {}/{}".format(idx+1, len(cam_extrinsics)))
81
+ sys.stdout.flush()
82
+
83
+ if eval:
84
+ extr = cam_extrinsics[key]
85
+ intr = cam_intrinsics[1]
86
+ uid = idx+1
87
+
88
+ else:
89
+ extr = cam_extrinsics[key]
90
+ intr = cam_intrinsics[extr.camera_id]
91
+ uid = intr.id
92
+
93
+ height = intr.height
94
+ width = intr.width
95
+ R = np.transpose(qvec2rotmat(extr.qvec))
96
+ T = np.array(extr.tvec)
97
+ pose = np.vstack((np.hstack((R, T.reshape(3,-1))),np.array([[0, 0, 0, 1]])))
98
+ poses.append(pose)
99
+ if intr.model=="SIMPLE_PINHOLE":
100
+ focal_length_x = intr.params[0]
101
+ FovY = focal2fov(focal_length_x, height)
102
+ FovX = focal2fov(focal_length_x, width)
103
+ elif intr.model=="PINHOLE":
104
+ focal_length_x = intr.params[0]
105
+ focal_length_y = intr.params[1]
106
+ FovY = focal2fov(focal_length_y, height)
107
+ FovX = focal2fov(focal_length_x, width)
108
+ else:
109
+ assert False, "Colmap camera model not handled: only undistorted datasets (PINHOLE or SIMPLE_PINHOLE cameras) supported!"
110
+
111
+
112
+ if eval:
113
+ tmp = os.path.dirname(os.path.dirname(os.path.join(images_folder)))
114
+ all_images_folder = os.path.join(tmp, 'images')
115
+ image_path = os.path.join(all_images_folder, os.path.basename(extr.name))
116
+ else:
117
+ image_path = os.path.join(images_folder, os.path.basename(extr.name))
118
+ image_name = os.path.basename(image_path).split(".")[0]
119
+ image = Image.open(image_path)
120
+
121
+
122
+ cam_info = CameraInfo(uid=uid, R=R, T=T, FovY=FovY, FovX=FovX, image=image,
123
+ image_path=image_path, image_name=image_name, width=width, height=height)
124
+
125
+ cam_infos.append(cam_info)
126
+ sys.stdout.write('\n')
127
+ return cam_infos, poses
128
+
129
+ # For interpolated video: used only when rendering an interpolated camera trajectory
130
+ def readColmapCamerasInterp(cam_extrinsics, cam_intrinsics, images_folder, model_path, cam_traj):
131
+
132
+ # pose_interpolated_path = model_path + 'pose/pose_interpolated.npy'
133
+ pose_interpolated_path = model_path + f'pose/pose_{cam_traj}.npy'
134
+ pose_interpolated = np.load(pose_interpolated_path)
135
+ intr = cam_intrinsics[1]
136
+
137
+ cam_infos = []
138
+ poses=[]
139
+ for idx, pose_npy in enumerate(pose_interpolated):
140
+ sys.stdout.write('\r')
141
+ sys.stdout.write("Reading camera {}/{}".format(idx+1, pose_interpolated.shape[0]))
142
+ sys.stdout.flush()
143
+
144
+ extr = pose_npy
145
+ intr = intr
146
+ height = intr.height
147
+ width = intr.width
148
+
149
+ uid = idx
150
+ R = extr[:3, :3].transpose()
151
+ T = extr[:3, 3]
152
+ pose = np.vstack((np.hstack((R, T.reshape(3,-1))),np.array([[0, 0, 0, 1]])))
153
+ # print(uid)
154
+ # print(pose.shape)
155
+ # pose = np.linalg.inv(pose)
156
+ poses.append(pose)
157
+ if intr.model=="SIMPLE_PINHOLE":
158
+ focal_length_x = intr.params[0]
159
+ FovY = focal2fov(focal_length_x, height)
160
+ FovX = focal2fov(focal_length_x, width)
161
+ elif intr.model=="PINHOLE":
162
+ focal_length_x = intr.params[0]
163
+ focal_length_y = intr.params[1]
164
+ FovY = focal2fov(focal_length_y, height)
165
+ FovX = focal2fov(focal_length_x, width)
166
+ else:
167
+ assert False, "Colmap camera model not handled: only undistorted datasets (PINHOLE or SIMPLE_PINHOLE cameras) supported!"
168
+
169
+ images_list = os.listdir(os.path.join(images_folder))
170
+ image_name_0 = images_list[0]
171
+ image_name = str(idx).zfill(4)
172
+ image = Image.open(images_folder + '/' + image_name_0)
173
+
174
+ cam_info = CameraInfo(uid=uid, R=R, T=T, FovY=FovY, FovX=FovX, image=image,
175
+ image_path=images_folder, image_name=image_name, width=width, height=height)
176
+ cam_infos.append(cam_info)
177
+
178
+ sys.stdout.write('\n')
179
+ return cam_infos, poses
180
+
181
+
182
+ def fetchPly(path):
183
+ plydata = PlyData.read(path)
184
+ vertices = plydata['vertex']
185
+ positions = np.vstack([vertices['x'], vertices['y'], vertices['z']]).T
186
+ colors = np.vstack([vertices['red'], vertices['green'], vertices['blue']]).T / 255.0
187
+ normals = np.vstack([vertices['nx'], vertices['ny'], vertices['nz']]).T
188
+ features = None
189
+ feat_keys = [key for key in vertices.data.dtype.names if key.startswith('feat_')]
190
+ if feat_keys:
191
+ features = np.vstack([vertices[key] for key in feat_keys]).T
192
+ return BasicPointCloud(points=positions, colors=colors, normals=normals, features=features)
193
+
194
+ def storePly(path, xyz, rgb):
195
+ # Define the dtype for the structured array
196
+ dtype = [('x', 'f4'), ('y', 'f4'), ('z', 'f4'),
197
+ ('nx', 'f4'), ('ny', 'f4'), ('nz', 'f4'),
198
+ ('red', 'u1'), ('green', 'u1'), ('blue', 'u1')]
199
+
200
+ normals = np.zeros_like(xyz)
201
+
202
+ elements = np.empty(xyz.shape[0], dtype=dtype)
203
+ attributes = np.concatenate((xyz, normals, rgb), axis=1)
204
+ elements[:] = list(map(tuple, attributes))
205
+
206
+ # Create the PlyData object and write to file
207
+ vertex_element = PlyElement.describe(elements, 'vertex')
208
+ ply_data = PlyData([vertex_element])
209
+ ply_data.write(path)
210
+
211
+ def readColmapSceneInfo(path, images, eval, args, opt, llffhold=2):
212
+ # try:
213
+ # cameras_extrinsic_file = os.path.join(path, "sparse/0", "images.bin")
214
+ # cameras_intrinsic_file = os.path.join(path, "sparse/0", "cameras.bin")
215
+ # cam_extrinsics = read_extrinsics_binary(cameras_extrinsic_file)
216
+ # cam_intrinsics = read_intrinsics_binary(cameras_intrinsic_file)
217
+ # except:
218
+
219
+ ##### For initializing test pose using PCD_Registration
220
+ if eval and opt.get_video==False:
221
+ print("Loading initial test pose for evaluation.")
222
+ cameras_extrinsic_file = os.path.join(path, f"test_view/sparse/0/{opt.method}", "images.txt")
223
+ else:
224
+ cameras_extrinsic_file = os.path.join(path, f"sparse/0/{opt.method}", "images.txt")
225
+
226
+ cameras_intrinsic_file = os.path.join(path, f"sparse/0/{opt.method}", "cameras.txt")
227
+ if hasattr(opt, 'feat_type') and opt.feat_type is not None:
228
+ feat_type_str = '-'.join(opt.feat_type)
229
+ if "test_view" not in cameras_extrinsic_file:
230
+ cameras_extrinsic_file = cameras_extrinsic_file.replace("images.txt", f"{feat_type_str}/images.txt")
231
+ cameras_intrinsic_file = cameras_intrinsic_file.replace("cameras.txt", f"{feat_type_str}/cameras.txt")
232
+ cam_extrinsics = read_extrinsics_text(cameras_extrinsic_file)
233
+ cam_intrinsics = read_intrinsics_text(cameras_intrinsic_file)
234
+
235
+ reading_dir = "images" if images == None else images
236
+
237
+ if opt.get_video:
238
+ cam_infos_unsorted, poses = readColmapCamerasInterp(cam_extrinsics=cam_extrinsics, cam_intrinsics=cam_intrinsics,
239
+ images_folder=os.path.join(path, reading_dir),
240
+ model_path=args.model_path, cam_traj=opt.cam_traj)
241
+ else:
242
+ cam_infos_unsorted, poses = readColmapCameras(cam_extrinsics=cam_extrinsics, cam_intrinsics=cam_intrinsics, images_folder=os.path.join(path, reading_dir), eval=eval)
243
+ sorting_indices = sorted(range(len(cam_infos_unsorted)), key=lambda x: cam_infos_unsorted[x].image_name)
244
+ cam_infos = [cam_infos_unsorted[i] for i in sorting_indices]
245
+ sorted_poses = [poses[i] for i in sorting_indices]
246
+ cam_infos = sorted(cam_infos_unsorted.copy(), key = lambda x : x.image_name)
247
+
248
+ if eval:
249
+ # train_cam_infos = [c for idx, c in enumerate(cam_infos) if (idx+1) % llffhold != 0]
250
+ # test_cam_infos = [c for idx, c in enumerate(cam_infos) if (idx+1) % llffhold == 0]
251
+ # train_poses = [c for idx, c in enumerate(sorted_poses) if (idx+1) % llffhold != 0]
252
+ # test_poses = [c for idx, c in enumerate(sorted_poses) if (idx+1) % llffhold == 0]
253
+
254
+ train_cam_infos = cam_infos
255
+ test_cam_infos = cam_infos
256
+ train_poses = sorted_poses
257
+ test_poses = sorted_poses
258
+
259
+ else:
260
+ train_cam_infos = cam_infos
261
+ test_cam_infos = []
262
+ train_poses = sorted_poses
263
+ test_poses = []
264
+
265
+ # render_cam_infos = generate_ellipse_path_from_camera_infos(cam_infos)
266
+
267
+ nerf_normalization = getNerfppNorm(train_cam_infos)
268
+
269
+ ply_path = os.path.join(path, f"sparse/0/{opt.method}/points3D.ply")
270
+ if hasattr(opt, 'feat_type') and opt.feat_type is not None:
271
+ ply_path = ply_path.replace("points3D.ply", f"{feat_type_str}/points3D.ply")
272
+ bin_path = os.path.join(path, "sparse/0/points3D.bin")
273
+ txt_path = os.path.join(path, "sparse/0/points3D.txt")
274
+ if not os.path.exists(ply_path):
275
+ print("Converting point3d.bin to .ply, will happen only the first time you open the scene.")
276
+ try:
277
+ xyz, rgb, _ = read_points3D_binary(bin_path)
278
+ except:
279
+ xyz, rgb, _ = read_points3D_text(txt_path)
280
+ storePly(ply_path, xyz, rgb)
281
+ try:
282
+ pcd = fetchPly(ply_path)
283
+ except:
284
+ pcd = None
285
+
286
+ # np.save("poses_family.npy", sorted_poses)
287
+ # breakpoint()
288
+ # np.save("3dpoints.npy", pcd.points)
289
+ # np.save("3dcolors.npy", pcd.colors)
290
+
291
+ scene_info = SceneInfo(point_cloud=pcd,
292
+ train_cameras=train_cam_infos,
293
+ test_cameras=test_cam_infos,
294
+ # render_cameras=render_cam_infos,
295
+ nerf_normalization=nerf_normalization,
296
+ ply_path=ply_path,
297
+ train_poses=train_poses,
298
+ test_poses=test_poses)
299
+ return scene_info
300
+
301
+ def readCamerasFromTransforms(path, transformsfile, white_background, extension=".png"):
302
+ cam_infos = []
303
+
304
+ with open(os.path.join(path, transformsfile)) as json_file:
305
+ contents = json.load(json_file)
306
+ fovx = contents["camera_angle_x"]
307
+
308
+ frames = contents["frames"]
309
+ for idx, frame in enumerate(frames):
310
+ cam_name = os.path.join(path, frame["file_path"] + extension)
311
+
312
+ # NeRF 'transform_matrix' is a camera-to-world transform
313
+ c2w = np.array(frame["transform_matrix"])
314
+ # change from OpenGL/Blender camera axes (Y up, Z back) to COLMAP (Y down, Z forward)
315
+ c2w[:3, 1:3] *= -1
316
+
317
+ # get the world-to-camera transform and set R, T
318
+ w2c = np.linalg.inv(c2w)
319
+ R = np.transpose(w2c[:3,:3]) # R is stored transposed due to 'glm' in CUDA code
320
+ T = w2c[:3, 3]
321
+
322
+ image_path = os.path.join(path, cam_name)
323
+ image_name = Path(cam_name).stem
324
+ image = Image.open(image_path)
325
+
326
+ im_data = np.array(image.convert("RGBA"))
327
+
328
+ bg = np.array([1,1,1]) if white_background else np.array([0, 0, 0])
329
+
330
+ norm_data = im_data / 255.0
331
+ arr = norm_data[:,:,:3] * norm_data[:, :, 3:4] + bg * (1 - norm_data[:, :, 3:4])
332
+ image = Image.fromarray(np.array(arr*255.0, dtype=np.byte), "RGB")
333
+
334
+ fovy = focal2fov(fov2focal(fovx, image.size[0]), image.size[1])
335
+ FovY = fovy
336
+ FovX = fovx
337
+
338
+ cam_infos.append(CameraInfo(uid=idx, R=R, T=T, FovY=FovY, FovX=FovX, image=image,
339
+ image_path=image_path, image_name=image_name, width=image.size[0], height=image.size[1]))
340
+
341
+ return cam_infos
342
+
343
+ def readNerfSyntheticInfo(path, white_background, eval, extension=".png"):
344
+ print("Reading Training Transforms")
345
+ train_cam_infos = readCamerasFromTransforms(path, "transforms_train.json", white_background, extension)
346
+ print("Reading Test Transforms")
347
+ test_cam_infos = readCamerasFromTransforms(path, "transforms_test.json", white_background, extension)
348
+
349
+ if not eval:
350
+ train_cam_infos.extend(test_cam_infos)
351
+ test_cam_infos = []
352
+
353
+ nerf_normalization = getNerfppNorm(train_cam_infos)
354
+
355
+ ply_path = os.path.join(path, "points3d.ply")
356
+ if not os.path.exists(ply_path):
357
+ # Since this data set has no colmap data, we start with random points
358
+ num_pts = 100_000
359
+ print(f"Generating random point cloud ({num_pts})...")
360
+
361
+ # We create random points inside the bounds of the synthetic Blender scenes
362
+ xyz = np.random.random((num_pts, 3)) * 2.6 - 1.3
363
+ shs = np.random.random((num_pts, 3)) / 255.0
364
+ pcd = BasicPointCloud(points=xyz, colors=SH2RGB(shs), normals=np.zeros((num_pts, 3)))
365
+
366
+ storePly(ply_path, xyz, SH2RGB(shs) * 255)
367
+ try:
368
+ pcd = fetchPly(ply_path)
369
+ except:
370
+ pcd = None
371
+
372
+ scene_info = SceneInfo(point_cloud=pcd,
373
+ train_cameras=train_cam_infos,
374
+ test_cameras=test_cam_infos,
375
+ nerf_normalization=nerf_normalization,
376
+ ply_path=ply_path,
+ train_poses=[], test_poses=[])  # SceneInfo defines these fields without defaults; no COLMAP poses exist for the synthetic scenes
377
+ return scene_info
378
+
379
+ sceneLoadTypeCallbacks = {
380
+ "Colmap": readColmapSceneInfo,
381
+ "Blender" : readNerfSyntheticInfo
382
+ }
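sceneLoadTypeCallbacks is a plain dispatch table. The caller (scene/__init__.py, added elsewhere in this commit and not shown here) is assumed to pick the loader from what it finds on disk; a rough sketch of that dispatch, with a hypothetical wrapper name and assuming the parsed ModelParams expose white_background:

import os
from scene.dataset_readers import sceneLoadTypeCallbacks

def load_scene_info(source_path, images, eval_mode, args, opt):
    # COLMAP-style scenes carry a sparse/ directory; Blender/NeRF-synthetic
    # scenes carry transforms_train.json.
    if os.path.exists(os.path.join(source_path, "sparse")):
        return sceneLoadTypeCallbacks["Colmap"](source_path, images, eval_mode, args, opt)
    if os.path.exists(os.path.join(source_path, "transforms_train.json")):
        return sceneLoadTypeCallbacks["Blender"](source_path, args.white_background, eval_mode)
    raise ValueError(f"Could not recognise scene type in {source_path}")

Unlike the stock 3DGS loader, readColmapSceneInfo reads text poses and intrinsics from sparse/0/<opt.method>/ and, when opt.feat_type is set, from a per-feature-type subfolder, so the dispatcher only has to hand over args and opt.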
scene/gaussian_model.py ADDED
@@ -0,0 +1,830 @@
1
+ #
2
+ # Copyright (C) 2023, Inria
3
+ # GRAPHDECO research group, https://team.inria.fr/graphdeco
4
+ # All rights reserved.
5
+ #
6
+ # This software is free for non-commercial, research and evaluation use
7
+ # under the terms of the LICENSE.md file.
8
+ #
9
+ # For inquiries contact [email protected]
10
+ #
11
+
12
+ import torch
13
+ # from lietorch import SO3, SE3, Sim3, LieGroupParameter
14
+ import numpy as np
15
+ from utils.general_utils import inverse_sigmoid, get_expon_lr_func, build_rotation
16
+ from torch import nn
17
+ import os
18
+ from utils.system_utils import mkdir_p
19
+ from plyfile import PlyData, PlyElement
20
+ from utils.sh_utils import RGB2SH
21
+ from simple_knn._C import distCUDA2
22
+ from utils.graphics_utils import BasicPointCloud
23
+ from utils.general_utils import strip_symmetric, build_scaling_rotation
24
+ from scipy.spatial.transform import Rotation as R
25
+ from utils.pose_utils import rotation2quad, get_tensor_from_camera
26
+ from utils.graphics_utils import getWorld2View2
27
+
28
+ import torch.nn.functional as F
29
+
30
+ def quaternion_to_rotation_matrix(quaternion):
31
+ """
32
+ Convert a quaternion to a rotation matrix.
33
+
34
+ Parameters:
35
+ - quaternion: A tensor of shape (..., 4) representing quaternions.
36
+
37
+ Returns:
38
+ - A tensor of shape (..., 3, 3) representing rotation matrices.
39
+ """
40
+ # Ensure quaternion is of float type for computation
41
+ quaternion = quaternion.float()
42
+
43
+ # Normalize the quaternion to unit length
44
+ quaternion = quaternion / quaternion.norm(p=2, dim=-1, keepdim=True)
45
+
46
+ # Extract components
47
+ w, x, y, z = quaternion[..., 0], quaternion[..., 1], quaternion[..., 2], quaternion[..., 3]
48
+
49
+ # Compute rotation matrix components
50
+ xx, yy, zz = x * x, y * y, z * z
51
+ xy, xz, yz = x * y, x * z, y * z
52
+ xw, yw, zw = x * w, y * w, z * w
53
+
54
+ # Assemble the rotation matrix
55
+ R = torch.stack([
56
+ torch.stack([1 - 2 * (yy + zz), 2 * (xy - zw), 2 * (xz + yw)], dim=-1),
57
+ torch.stack([ 2 * (xy + zw), 1 - 2 * (xx + zz), 2 * (yz - xw)], dim=-1),
58
+ torch.stack([ 2 * (xz - yw), 2 * (yz + xw), 1 - 2 * (xx + yy)], dim=-1)
59
+ ], dim=-2)
60
+
61
+ return R
62
+
63
+
64
+ class GaussianModel:
65
+
66
+ def setup_functions(self):
67
+ def build_covariance_from_scaling_rotation(scaling, scaling_modifier, rotation):
68
+ L = build_scaling_rotation(scaling_modifier * scaling, rotation)
69
+ actual_covariance = L @ L.transpose(1, 2)
70
+ symm = strip_symmetric(actual_covariance)
71
+ return symm
72
+
73
+ self.scaling_activation = torch.exp
74
+ self.scaling_inverse_activation = torch.log
75
+
76
+ self.covariance_activation = build_covariance_from_scaling_rotation
77
+
78
+ self.opacity_activation = torch.sigmoid
79
+ self.inverse_opacity_activation = inverse_sigmoid
80
+
81
+ self.rotation_activation = torch.nn.functional.normalize
82
+
83
+
84
+ def __init__(self, sh_degree : int):
85
+ # self.active_sh_degree = 0
86
+ self.active_sh_degree = sh_degree
87
+ self.max_sh_degree = sh_degree
88
+ self._xyz = torch.empty(0)
89
+ self._features_dc = torch.empty(0)
90
+ self._features_rest = torch.empty(0)
91
+ self._scaling = torch.empty(0)
92
+ self._rotation = torch.empty(0)
93
+ self._opacity = torch.empty(0)
94
+ self.max_radii2D = torch.empty(0)
95
+ self.xyz_gradient_accum = torch.empty(0)
96
+ self.denom = torch.empty(0)
97
+ self.optimizer = None
98
+ self.percent_dense = 0
99
+ self.spatial_lr_scale = 0
100
+ self.param_init = {}
101
+ self.setup_functions()
102
+
103
+ def capture(self):
104
+ return (
105
+ self.active_sh_degree,
106
+ self._xyz,
107
+ self._features_dc,
108
+ self._features_rest,
109
+ self._scaling,
110
+ self._rotation,
111
+ self._opacity,
112
+ self.max_radii2D,
113
+ self.xyz_gradient_accum,
114
+ self.denom,
115
+ self.optimizer.state_dict(),
116
+ self.spatial_lr_scale,
117
+ self.P,
118
+ )
119
+
120
+ def restore(self, model_args, training_args):
121
+ (self.active_sh_degree,
122
+ self._xyz,
123
+ self._features_dc,
124
+ self._features_rest,
125
+ self._scaling,
126
+ self._rotation,
127
+ self._opacity,
128
+ self.max_radii2D,
129
+ xyz_gradient_accum,
130
+ denom,
131
+ opt_dict,
132
+ self.spatial_lr_scale,
133
+ self.P) = model_args
134
+ self.training_setup(training_args)
135
+ self.xyz_gradient_accum = xyz_gradient_accum
136
+ self.denom = denom
137
+ self.optimizer.load_state_dict(opt_dict)
138
+
139
+ @property
140
+ def get_scaling(self):
141
+ return self.scaling_activation(self._scaling)
142
+
143
+ @property
144
+ def get_rotation(self):
145
+ return self.rotation_activation(self._rotation)
146
+
147
+ @property
148
+ def get_xyz(self):
149
+ return self._xyz
150
+
151
+ def compute_relative_world_to_camera(self, R1, t1, R2, t2):
152
+ # Create a row of zeros with a one at the end, for homogeneous coordinates
153
+ zero_row = np.array([[0, 0, 0, 1]], dtype=np.float32)
154
+
155
+ # Compute the inverse of the first extrinsic matrix
156
+ E1_inv = np.hstack([R1.T, -R1.T @ t1.reshape(-1, 1)]) # Transpose and reshape for correct dimensions
157
+ E1_inv = np.vstack([E1_inv, zero_row]) # Append the zero_row to make it a 4x4 matrix
158
+
159
+ # Compute the second extrinsic matrix
160
+ E2 = np.hstack([R2, -R2 @ t2.reshape(-1, 1)]) # No need to transpose R2
161
+ E2 = np.vstack([E2, zero_row]) # Append the zero_row to make it a 4x4 matrix
162
+
163
+ # Compute the relative transformation
164
+ E_rel = E2 @ E1_inv
165
+
166
+ return E_rel
167
+
168
+ def init_RT_seq(self, cam_list):
169
+ poses =[]
170
+ for cam in cam_list[1.0]:
171
+ p = get_tensor_from_camera(cam.world_view_transform.transpose(0, 1)) # R T -> quat t
172
+ poses.append(p)
173
+ poses = torch.stack(poses)
174
+ self.P = poses.cuda().requires_grad_(True)
175
+ # poses_ = torch.randn(poses.detach().clone().shape, device='cuda')
176
+ # self.P = poses_.cuda().requires_grad_(True)
177
+ self.param_init['pose'] = poses.detach().clone()
178
+
179
+ def get_RT(self, idx):
180
+ pose = self.P[idx]
181
+ return pose
182
+
183
+ def get_RT_test(self, idx):
184
+ pose = self.test_P[idx]
185
+ return pose
186
+
187
+ @property
188
+ def get_features(self):
189
+ features_dc = self._features_dc
190
+ features_rest = self._features_rest
191
+ return torch.cat((features_dc, features_rest), dim=1)
192
+
193
+ @property
194
+ def get_opacity(self):
195
+ return self.opacity_activation(self._opacity)
196
+
197
+ def get_covariance(self, scaling_modifier = 1):
198
+ return self.covariance_activation(self.get_scaling, scaling_modifier, self._rotation)
199
+
200
+ def oneupSHdegree(self):
201
+ if self.active_sh_degree < self.max_sh_degree:
202
+ self.active_sh_degree += 1
203
+
204
+ def create_from_pcd(self, pcd : BasicPointCloud, spatial_lr_scale : float):
205
+ self.spatial_lr_scale = spatial_lr_scale
206
+ fused_point_cloud = torch.tensor(np.asarray(pcd.points)).float().cuda()
207
+ fused_color = RGB2SH(torch.tensor(np.asarray(pcd.colors)).float().cuda())
208
+ features = torch.zeros((fused_color.shape[0], 3, (self.max_sh_degree + 1) ** 2)).float().cuda()
209
+ features[:, :3, 0 ] = fused_color
210
+ features[:, 3:, 1:] = 0.0
211
+
212
+ print("Number of points at initialisation : ", fused_point_cloud.shape[0])
213
+
214
+ dist2 = torch.clamp_min(distCUDA2(torch.from_numpy(np.asarray(pcd.points)).float().cuda()), 0.0000001)
215
+ scales = torch.log(torch.sqrt(dist2))[...,None].repeat(1, 3)
216
+ rots = torch.zeros((fused_point_cloud.shape[0], 4), device="cuda")
217
+ rots[:, 0] = 1
218
+
219
+ opacities = inverse_sigmoid(0.1 * torch.ones((fused_point_cloud.shape[0], 1), dtype=torch.float, device="cuda"))
220
+
221
+ self._xyz = nn.Parameter(fused_point_cloud.requires_grad_(True))
222
+ self._features_dc = nn.Parameter(features[:,:,0:1].transpose(1, 2).contiguous().requires_grad_(True))
223
+ self._features_rest = nn.Parameter(features[:,:,1:].transpose(1, 2).contiguous().requires_grad_(True))
224
+ self._scaling = nn.Parameter(scales.requires_grad_(True))
225
+ self._rotation = nn.Parameter(rots.requires_grad_(True))
226
+ self._opacity = nn.Parameter(opacities.requires_grad_(True))
227
+ self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda")
228
+
229
+ self.param_init.update({
230
+ 'xyz': fused_point_cloud.detach().clone(),
231
+ 'f_dc': features[:,:,0:1].transpose(1, 2).contiguous().detach().clone(),
232
+ 'f_rest': features[:,:,1:].transpose(1, 2).contiguous().detach().clone(),
233
+ 'opacity': opacities.detach().clone(),
234
+ 'scaling': scales.detach().clone(),
235
+ 'rotation': rots.detach().clone(),
236
+ })
237
+ def training_setup(self, training_args):
238
+ self.percent_dense = training_args.percent_dense
239
+ self.xyz_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")
240
+ self.denom = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")
241
+
242
+ l = [
243
+ {'params': [self._xyz], 'lr': training_args.position_lr_init * self.spatial_lr_scale, "name": "xyz"},
244
+ {'params': [self._features_dc], 'lr': training_args.feature_lr, "name": "f_dc"},
245
+ {'params': [self._features_rest], 'lr': training_args.feature_lr / 20.0, "name": "f_rest"},
246
+ {'params': [self._opacity], 'lr': training_args.opacity_lr, "name": "opacity"},
247
+ {'params': [self._scaling], 'lr': training_args.scaling_lr, "name": "scaling"},
248
+ {'params': [self._rotation], 'lr': training_args.rotation_lr, "name": "rotation"},
249
+ ]
250
+
251
+ l_cam = [{'params': [self.P],'lr': training_args.rotation_lr*0.1, "name": "pose"},]
252
+ # l_cam = [{'params': [self.P],'lr': training_args.rotation_lr, "name": "pose"},]
253
+
254
+ l += l_cam
255
+
256
+ self.optimizer = torch.optim.Adam(l, lr=0.0, eps=1e-15)
257
+ self.xyz_scheduler_args = get_expon_lr_func(lr_init=training_args.position_lr_init*self.spatial_lr_scale,
258
+ lr_final=training_args.position_lr_final*self.spatial_lr_scale,
259
+ lr_delay_mult=training_args.position_lr_delay_mult,
260
+ max_steps=training_args.position_lr_max_steps)
261
+ self.cam_scheduler_args = get_expon_lr_func(
262
+ # lr_init=0,
263
+ # lr_final=0,
264
+ lr_init=training_args.rotation_lr*0.1,
265
+ lr_final=training_args.rotation_lr*0.001,
266
+ # lr_init=training_args.position_lr_init*self.spatial_lr_scale*10,
267
+ # lr_final=training_args.position_lr_final*self.spatial_lr_scale*10,
268
+ lr_delay_mult=training_args.position_lr_delay_mult,
269
+ max_steps=1000)
270
+
271
+ def update_learning_rate(self, iteration):
272
+ ''' Learning rate scheduling per step '''
273
+ for param_group in self.optimizer.param_groups:
274
+ if param_group["name"] == "pose":
275
+ lr = self.cam_scheduler_args(iteration)
276
+ # print("pose learning rate", iteration, lr)
277
+ param_group['lr'] = lr
278
+ if param_group["name"] == "xyz":
279
+ lr = self.xyz_scheduler_args(iteration)
280
+ param_group['lr'] = lr
281
+ # return lr
282
+
283
+ def construct_list_of_attributes(self):
284
+ l = ['x', 'y', 'z', 'nx', 'ny', 'nz']
285
+ # All channels except the 3 DC
286
+ for i in range(self._features_dc.shape[1]*self._features_dc.shape[2]):
287
+ l.append('f_dc_{}'.format(i))
288
+ for i in range(self._features_rest.shape[1]*self._features_rest.shape[2]):
289
+ l.append('f_rest_{}'.format(i))
290
+ l.append('opacity')
291
+ for i in range(self._scaling.shape[1]):
292
+ l.append('scale_{}'.format(i))
293
+ for i in range(self._rotation.shape[1]):
294
+ l.append('rot_{}'.format(i))
295
+ return l
296
+
297
+ def save_ply(self, path):
298
+ mkdir_p(os.path.dirname(path))
299
+
300
+ xyz = self._xyz.detach().cpu().numpy()
301
+ normals = np.zeros_like(xyz)
302
+ f_dc = self._features_dc.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy()
303
+ f_rest = self._features_rest.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy()
304
+ opacities = self._opacity.detach().cpu().numpy()
305
+ scale = self._scaling.detach().cpu().numpy()
306
+ rotation = self._rotation.detach().cpu().numpy()
307
+
308
+ dtype_full = [(attribute, 'f4') for attribute in self.construct_list_of_attributes()]
309
+
310
+ elements = np.empty(xyz.shape[0], dtype=dtype_full)
311
+ attributes = np.concatenate((xyz, normals, f_dc, f_rest, opacities, scale, rotation), axis=1)
312
+ elements[:] = list(map(tuple, attributes))
313
+ el = PlyElement.describe(elements, 'vertex')
314
+ PlyData([el]).write(path)
315
+
316
+ def reset_opacity(self):
317
+ opacities_new = inverse_sigmoid(torch.min(self.get_opacity, torch.ones_like(self.get_opacity)*0.01))
318
+ optimizable_tensors = self.replace_tensor_to_optimizer(opacities_new, "opacity")
319
+ self._opacity = optimizable_tensors["opacity"]
320
+
321
+ def load_ply(self, path):
322
+ plydata = PlyData.read(path)
323
+
324
+ xyz = np.stack((np.asarray(plydata.elements[0]["x"]),
325
+ np.asarray(plydata.elements[0]["y"]),
326
+ np.asarray(plydata.elements[0]["z"])), axis=1)
327
+ opacities = np.asarray(plydata.elements[0]["opacity"])[..., np.newaxis]
328
+
329
+ features_dc = np.zeros((xyz.shape[0], 3, 1))
330
+ features_dc[:, 0, 0] = np.asarray(plydata.elements[0]["f_dc_0"])
331
+ features_dc[:, 1, 0] = np.asarray(plydata.elements[0]["f_dc_1"])
332
+ features_dc[:, 2, 0] = np.asarray(plydata.elements[0]["f_dc_2"])
333
+
334
+ extra_f_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("f_rest_")]
335
+ extra_f_names = sorted(extra_f_names, key = lambda x: int(x.split('_')[-1]))
336
+ assert len(extra_f_names)==3*(self.max_sh_degree + 1) ** 2 - 3
337
+ features_extra = np.zeros((xyz.shape[0], len(extra_f_names)))
338
+ for idx, attr_name in enumerate(extra_f_names):
339
+ features_extra[:, idx] = np.asarray(plydata.elements[0][attr_name])
340
+ # Reshape (P,F*SH_coeffs) to (P, F, SH_coeffs except DC)
341
+ features_extra = features_extra.reshape((features_extra.shape[0], 3, (self.max_sh_degree + 1) ** 2 - 1))
342
+
343
+ scale_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("scale_")]
344
+ scale_names = sorted(scale_names, key = lambda x: int(x.split('_')[-1]))
345
+ scales = np.zeros((xyz.shape[0], len(scale_names)))
346
+ for idx, attr_name in enumerate(scale_names):
347
+ scales[:, idx] = np.asarray(plydata.elements[0][attr_name])
348
+
349
+ rot_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("rot")]
350
+ rot_names = sorted(rot_names, key = lambda x: int(x.split('_')[-1]))
351
+ rots = np.zeros((xyz.shape[0], len(rot_names)))
352
+ for idx, attr_name in enumerate(rot_names):
353
+ rots[:, idx] = np.asarray(plydata.elements[0][attr_name])
354
+
355
+ self._xyz = nn.Parameter(torch.tensor(xyz, dtype=torch.float, device="cuda").requires_grad_(True))
356
+ self._features_dc = nn.Parameter(torch.tensor(features_dc, dtype=torch.float, device="cuda").transpose(1, 2).contiguous().requires_grad_(True))
357
+ self._features_rest = nn.Parameter(torch.tensor(features_extra, dtype=torch.float, device="cuda").transpose(1, 2).contiguous().requires_grad_(True))
358
+ self._opacity = nn.Parameter(torch.tensor(opacities, dtype=torch.float, device="cuda").requires_grad_(True))
359
+ self._scaling = nn.Parameter(torch.tensor(scales, dtype=torch.float, device="cuda").requires_grad_(True))
360
+ self._rotation = nn.Parameter(torch.tensor(rots, dtype=torch.float, device="cuda").requires_grad_(True))
361
+
362
+ self.active_sh_degree = self.max_sh_degree
363
+
364
+ def replace_tensor_to_optimizer(self, tensor, name):
365
+ optimizable_tensors = {}
366
+ for group in self.optimizer.param_groups:
367
+ if group["name"] == name:
368
+ # breakpoint()
369
+ stored_state = self.optimizer.state.get(group['params'][0], None)
370
+ stored_state["exp_avg"] = torch.zeros_like(tensor)
371
+ stored_state["exp_avg_sq"] = torch.zeros_like(tensor)
372
+
373
+ del self.optimizer.state[group['params'][0]]
374
+ group["params"][0] = nn.Parameter(tensor.requires_grad_(True))
375
+ self.optimizer.state[group['params'][0]] = stored_state
376
+
377
+ optimizable_tensors[group["name"]] = group["params"][0]
378
+ return optimizable_tensors
379
+
380
+ def _prune_optimizer(self, mask):
381
+ optimizable_tensors = {}
382
+ for group in self.optimizer.param_groups:
383
+ stored_state = self.optimizer.state.get(group['params'][0], None)
384
+ if stored_state is not None:
385
+ stored_state["exp_avg"] = stored_state["exp_avg"][mask]
386
+ stored_state["exp_avg_sq"] = stored_state["exp_avg_sq"][mask]
387
+
388
+ del self.optimizer.state[group['params'][0]]
389
+ group["params"][0] = nn.Parameter((group["params"][0][mask].requires_grad_(True)))
390
+ self.optimizer.state[group['params'][0]] = stored_state
391
+
392
+ optimizable_tensors[group["name"]] = group["params"][0]
393
+ else:
394
+ group["params"][0] = nn.Parameter(group["params"][0][mask].requires_grad_(True))
395
+ optimizable_tensors[group["name"]] = group["params"][0]
396
+ return optimizable_tensors
397
+
398
+ def prune_points(self, mask):
399
+ valid_points_mask = ~mask
400
+ optimizable_tensors = self._prune_optimizer(valid_points_mask)
401
+
402
+ self._xyz = optimizable_tensors["xyz"]
403
+ self._features_dc = optimizable_tensors["f_dc"]
404
+ self._features_rest = optimizable_tensors["f_rest"]
405
+ self._opacity = optimizable_tensors["opacity"]
406
+ self._scaling = optimizable_tensors["scaling"]
407
+ self._rotation = optimizable_tensors["rotation"]
408
+
409
+ self.xyz_gradient_accum = self.xyz_gradient_accum[valid_points_mask]
410
+
411
+ self.denom = self.denom[valid_points_mask]
412
+ self.max_radii2D = self.max_radii2D[valid_points_mask]
413
+
414
+ def cat_tensors_to_optimizer(self, tensors_dict):
415
+ optimizable_tensors = {}
416
+ for group in self.optimizer.param_groups:
417
+ assert len(group["params"]) == 1
418
+ extension_tensor = tensors_dict[group["name"]]
419
+ stored_state = self.optimizer.state.get(group['params'][0], None)
420
+ if stored_state is not None:
421
+
422
+ stored_state["exp_avg"] = torch.cat((stored_state["exp_avg"], torch.zeros_like(extension_tensor)), dim=0)
423
+ stored_state["exp_avg_sq"] = torch.cat((stored_state["exp_avg_sq"], torch.zeros_like(extension_tensor)), dim=0)
424
+
425
+ del self.optimizer.state[group['params'][0]]
426
+ group["params"][0] = nn.Parameter(torch.cat((group["params"][0], extension_tensor), dim=0).requires_grad_(True))
427
+ self.optimizer.state[group['params'][0]] = stored_state
428
+
429
+ optimizable_tensors[group["name"]] = group["params"][0]
430
+ else:
431
+ group["params"][0] = nn.Parameter(torch.cat((group["params"][0], extension_tensor), dim=0).requires_grad_(True))
432
+ optimizable_tensors[group["name"]] = group["params"][0]
433
+
434
+ return optimizable_tensors
435
+
436
+ def densification_postfix(self, new_xyz, new_features_dc, new_features_rest, new_opacities, new_scaling, new_rotation):
437
+ d = {"xyz": new_xyz,
438
+ "f_dc": new_features_dc,
439
+ "f_rest": new_features_rest,
440
+ "opacity": new_opacities,
441
+ "scaling" : new_scaling,
442
+ "rotation" : new_rotation}
443
+
444
+ optimizable_tensors = self.cat_tensors_to_optimizer(d)
445
+ self._xyz = optimizable_tensors["xyz"]
446
+ self._features_dc = optimizable_tensors["f_dc"]
447
+ self._features_rest = optimizable_tensors["f_rest"]
448
+ self._opacity = optimizable_tensors["opacity"]
449
+ self._scaling = optimizable_tensors["scaling"]
450
+ self._rotation = optimizable_tensors["rotation"]
451
+
452
+ self.xyz_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")
453
+ self.denom = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")
454
+ self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda")
455
+
456
+ def densify_and_split(self, grads, grad_threshold, scene_extent, N=2):
457
+ n_init_points = self.get_xyz.shape[0]
458
+ # Extract points that satisfy the gradient condition
459
+ padded_grad = torch.zeros((n_init_points), device="cuda")
460
+ padded_grad[:grads.shape[0]] = grads.squeeze()
461
+ selected_pts_mask = torch.where(padded_grad >= grad_threshold, True, False)
462
+ selected_pts_mask = torch.logical_and(selected_pts_mask,
463
+ torch.max(self.get_scaling, dim=1).values > self.percent_dense*scene_extent)
464
+
465
+ stds = self.get_scaling[selected_pts_mask].repeat(N,1)
466
+ means =torch.zeros((stds.size(0), 3),device="cuda")
467
+ samples = torch.normal(mean=means, std=stds)
468
+ rots = build_rotation(self._rotation[selected_pts_mask]).repeat(N,1,1)
469
+ new_xyz = torch.bmm(rots, samples.unsqueeze(-1)).squeeze(-1) + self.get_xyz[selected_pts_mask].repeat(N, 1)
470
+ new_scaling = self.scaling_inverse_activation(self.get_scaling[selected_pts_mask].repeat(N,1) / (0.8*N))
471
+ new_rotation = self._rotation[selected_pts_mask].repeat(N,1)
472
+ new_features_dc = self._features_dc[selected_pts_mask].repeat(N,1,1)
473
+ new_features_rest = self._features_rest[selected_pts_mask].repeat(N,1,1)
474
+ new_opacity = self._opacity[selected_pts_mask].repeat(N,1)
475
+
476
+ self.densification_postfix(new_xyz, new_features_dc, new_features_rest, new_opacity, new_scaling, new_rotation)
477
+
478
+ prune_filter = torch.cat((selected_pts_mask, torch.zeros(N * selected_pts_mask.sum(), device="cuda", dtype=bool)))
479
+ self.prune_points(prune_filter)
480
+
481
+ def densify_and_clone(self, grads, grad_threshold, scene_extent):
482
+ # Extract points that satisfy the gradient condition
483
+ selected_pts_mask = torch.where(torch.norm(grads, dim=-1) >= grad_threshold, True, False)
484
+ selected_pts_mask = torch.logical_and(selected_pts_mask,
485
+ torch.max(self.get_scaling, dim=1).values <= self.percent_dense*scene_extent)
486
+
487
+ new_xyz = self._xyz[selected_pts_mask]
488
+ new_features_dc = self._features_dc[selected_pts_mask]
489
+ new_features_rest = self._features_rest[selected_pts_mask]
490
+ new_opacities = self._opacity[selected_pts_mask]
491
+ new_scaling = self._scaling[selected_pts_mask]
492
+ new_rotation = self._rotation[selected_pts_mask]
493
+
494
+ self.densification_postfix(new_xyz, new_features_dc, new_features_rest, new_opacities, new_scaling, new_rotation)
495
+
496
+ def densify_and_prune(self, max_grad, min_opacity, extent, max_screen_size):
497
+ grads = self.xyz_gradient_accum / self.denom
498
+ grads[grads.isnan()] = 0.0
499
+
500
+ # self.densify_and_clone(grads, max_grad, extent)
501
+ # self.densify_and_split(grads, max_grad, extent)
502
+
503
+ prune_mask = (self.get_opacity < min_opacity).squeeze()
504
+ if max_screen_size:
505
+ big_points_vs = self.max_radii2D > max_screen_size
506
+ big_points_ws = self.get_scaling.max(dim=1).values > 0.1 * extent
507
+ prune_mask = torch.logical_or(torch.logical_or(prune_mask, big_points_vs), big_points_ws)
508
+ self.prune_points(prune_mask)
509
+
510
+ torch.cuda.empty_cache()
511
+
512
+ def add_densification_stats(self, viewspace_point_tensor, update_filter):
513
+ self.xyz_gradient_accum[update_filter] += torch.norm(viewspace_point_tensor.grad[update_filter,:2], dim=-1, keepdim=True)
514
+ self.denom[update_filter] += 1
515
+
516
+
517
+ class Feat2GaussianModel(GaussianModel):
518
+
519
+ def __init__(self, sh_degree : int, feat_dim : int, gs_params_group : dict, noise_std=0):
520
+ super().__init__(sh_degree)
521
+ self.noise_std = noise_std
522
+ self.pc_feat = torch.empty(0)
523
+ self.param_init = {}
524
+ self.feat_dim = feat_dim
525
+ self.gs_params_group = gs_params_group
526
+ self.active_sh_degree = sh_degree
527
+ self.sh_coeffs = ((sh_degree + 1) ** 2) * 3-3
528
+ net_width = feat_dim
529
+ out_dim = {'xyz': 3, 'scaling': 3, 'rotation': 4, 'opacity': 1, 'f_dc': 3, 'f_rest': self.sh_coeffs}
530
+ for key in gs_params_group.get('head', []):
531
+ setattr(self, f'head_{key}', conditionalWarp(layers=[feat_dim, net_width, out_dim[key]], skip=[]).cuda())
532
+
533
+ self.param_key = {
534
+ 'xyz': '_xyz',
535
+ 'scaling': '_scaling',
536
+ 'rotation': '_rotation',
537
+ 'opacity': '_opacity',
538
+ 'f_dc': '_features_dc',
539
+ 'f_rest': '_features_rest',
540
+ 'pc_feat': 'pc_feat',
541
+ }
542
+
543
+ # ## FOR DEBUGGING
544
+ # self.head_xyz = conditionalWarp(layers=[self.feat_dim, net_width, 3], skip=[]).cuda()
545
+ # self.head_scaling = conditionalWarp(layers=[self.feat_dim, net_width, 3], skip=[]).cuda()
546
+ # self.head_rotation = conditionalWarp(layers=[self.feat_dim, net_width, 4], skip=[]).cuda()
547
+ # self.head_opacity = conditionalWarp(layers=[self.feat_dim, net_width, 1], skip=[]).cuda()
548
+ # self.head_f_dc = conditionalWarp(layers=[feat_dim, net_width, 3], skip=[]).cuda()
549
+ # self.head_f_rest = conditionalWarp(layers=[feat_dim, net_width, self.sh_coeffs], skip=[]).cuda()
550
+
551
+ def capture(self):
552
+ head_state_dicts = {f'head_{key}': getattr(self, f'head_{key}').state_dict() for key in self.gs_params_group.get('head', [])}
553
+ return (
554
+ self.active_sh_degree,
555
+ self._xyz,
556
+ self._features_dc,
557
+ self._features_rest,
558
+ self._scaling,
559
+ self._rotation,
560
+ self._opacity,
561
+ self.max_radii2D,
562
+ self.xyz_gradient_accum,
563
+ self.denom,
564
+ self.optimizer.state_dict(),
565
+ self.spatial_lr_scale,
566
+ self.P,
567
+ head_state_dicts
568
+ )
569
+
570
+ def restore(self, model_args, training_args):
571
+ (self.active_sh_degree,
572
+ self._xyz,
573
+ self._features_dc,
574
+ self._features_rest,
575
+ self._scaling,
576
+ self._rotation,
577
+ self._opacity,
578
+ self.max_radii2D,
579
+ xyz_gradient_accum,
580
+ denom,
581
+ opt_dict,
582
+ self.spatial_lr_scale,
583
+ self.P,
584
+ head_state_dicts
585
+ ) = model_args
586
+
587
+ self.training_setup(training_args)
588
+ self.xyz_gradient_accum = xyz_gradient_accum
589
+ self.denom = denom
590
+ self.optimizer.load_state_dict(opt_dict)
591
+
592
+ for key, state_dict in head_state_dicts.items():
593
+ getattr(self, key).load_state_dict(state_dict)
594
+
595
+ def inference(self):
596
+ feat_in = self.pc_feat
597
+ for key in self.gs_params_group.get('head', []):
598
+
599
+ if key == 'f_dc':
600
+ self._features_dc = getattr(self, f'head_{key}')(feat_in, self.param_init[key].view(-1, 3)).reshape(-1, 1, 3)
601
+ elif key == 'f_rest':
602
+ self._features_rest = getattr(self, f'head_{key}')(feat_in.detach(), self.param_init[key].view(-1, self.sh_coeffs)).reshape(-1, self.sh_coeffs // 3, 3)
603
+ else:
604
+ setattr(self, f'_{key}', getattr(self, f'head_{key}')(feat_in, self.param_init[key]))
605
+
606
+ # if key == 'f_dc':
607
+ # self._features_dc = getattr(self, f'head_{key}')(feat_in, self.param_init[key].view(-1, 3)).reshape(-1, 1, 3)
608
+ # self._features_dc += self.param_init[key].view(-1, 1, 3).mean(dim=0, keepdim=True)
609
+ # elif key == 'f_rest':
610
+ # self._features_rest = getattr(self, f'head_{key}')(feat_in.detach(), self.param_init[key].view(-1, self.sh_coeffs)).reshape(-1, self.sh_coeffs // 3, 3)
611
+ # self._features_rest += self.param_init[key].view(-1, self.sh_coeffs // 3, 3).mean(dim=0, keepdim=True)
612
+ # else:
613
+ # pred = getattr(self, f'head_{key}')(feat_in, self.param_init[key])
614
+ # setattr(self, f'_{key}', pred + self.param_init[key].mean(dim=0, keepdim=True))
615
+
616
+ # ## FOR DEBUGGING
617
+ # self._xyz = self.head_xyz(pred, self.param_init['xyz'])
618
+ # self._opacity = self.head_opacity(pred, self.param_init['opacity'])
619
+ # self._scaling = self.head_scaling(pred, self.param_init['scaling'])
620
+ # self._rotation = self.head_rotation(pred, self.param_init['rotation'])
621
+ # self._features_dc = self.head_f_dc(pred, self.param_init['f_dc'].view(-1,3)).reshape(-1, 1, 3)
622
+ # self._features_rest = self.head_f_rest(pred, self.param_init['f_rest'].view(-1,self.sh_coeffs)).reshape(-1, self.sh_coeffs//3, 3)
623
+
624
+ def create_from_pcd(self, pcd : BasicPointCloud, spatial_lr_scale : float):
625
+ self.spatial_lr_scale = spatial_lr_scale
626
+ fused_point_cloud = torch.tensor(np.asarray(pcd.points)).float().cuda()
627
+ fused_point_feat = torch.tensor(np.asarray(pcd.features)).float().cuda() # get features from .PLY file
628
+ assert fused_point_feat.shape[-1] == self.feat_dim, f"Expected feature dimension {self.feat_dim}, but got {fused_point_feat.shape[-1]}"
629
+ fused_color = RGB2SH(torch.tensor(np.asarray(pcd.colors)).float().cuda())
630
+ features = torch.zeros((fused_color.shape[0], 3, (self.max_sh_degree + 1) ** 2)).float().cuda()
631
+ features[:, :3, 0 ] = fused_color
632
+ features[:, 3:, 1:] = 0.0
633
+
634
+ print("Number of points at initialisation : ", fused_point_cloud.shape[0])
635
+
636
+ dist2 = torch.clamp_min(distCUDA2(torch.from_numpy(np.asarray(pcd.points)).float().cuda()), 0.0000001)
637
+ scales = torch.log(torch.sqrt(dist2))[...,None].repeat(1, 3)
638
+ rots = torch.zeros((fused_point_cloud.shape[0], 4), device="cuda")
639
+ rots[:, 0] = 1
640
+ opacities = inverse_sigmoid(0.1 * torch.ones((fused_point_cloud.shape[0], 1), dtype=torch.float, device="cuda"))
641
+ self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda")
642
+
643
+ self.pc_feat = fused_point_feat#.requires_grad_(True)
644
+
645
+ # fused_point_feat = torch.randn_like(fused_point_feat)
646
+ # self.pc_feat = fused_point_feat.requires_grad_(True)
647
+
648
+ self.gt_xyz = fused_point_cloud.clone()
649
+ if self.noise_std != 0:
650
+ self.noise_std /= 1000.0
651
+ torch.manual_seed(0)
652
+ torch.cuda.manual_seed(0)
653
+ noise = torch.randn_like(fused_point_cloud) * self.noise_std
654
+ fused_point_cloud += noise
655
+ # fused_point_cloud = noise + fused_point_cloud.mean(dim=0, keepdim=True)
656
+ # fused_point_cloud = torch.zeros_like(fused_point_cloud) + fused_point_cloud.mean(dim=0, keepdim=True)
657
+
658
+ param_init = {
659
+ 'xyz': fused_point_cloud,
660
+ 'scaling': scales,
661
+ 'rotation': rots,
662
+ 'opacity': opacities,
663
+ 'f_dc': features[:, :, 0:1].transpose(1, 2).contiguous(),
664
+ 'f_rest': features[:, :, 1:].transpose(1, 2).contiguous(),
665
+ 'pc_feat': fused_point_feat,
666
+ }
667
+
668
+ for key in self.gs_params_group.get('opt', []):
669
+ setattr(self, self.param_key[key], nn.Parameter(param_init[key].requires_grad_(True)))
670
+
671
+ self.param_init.update({key: value.detach().clone() for key, value in param_init.items()})
672
+
673
+ # ## FOR DEBUGGING
674
+ # self._xyz = nn.Parameter(fused_point_cloud.requires_grad_(True))
675
+ # self._scaling = nn.Parameter(scales.requires_grad_(True))
676
+ # self._rotation = nn.Parameter(rots.requires_grad_(True))
677
+ # self._opacity = nn.Parameter(opacities.requires_grad_(True))
678
+ # self._features_dc = nn.Parameter(features[:,:,0:1].transpose(1, 2).contiguous().requires_grad_(True))
679
+ # self._features_rest = nn.Parameter(features[:,:,1:].transpose(1, 2).contiguous().requires_grad_(True))
680
+
681
+ # self.param_init.update({
682
+ # 'xyz': fused_point_cloud.detach().clone(),
683
+ # 'f_dc': features[:,:,0:1].transpose(1, 2).contiguous().detach().clone(),
684
+ # 'f_rest': features[:,:,1:].transpose(1, 2).contiguous().detach().clone(),
685
+ # 'opacity': opacities.detach().clone(),
686
+ # 'scaling': scales.detach().clone(),
687
+ # 'rotation': rots.detach().clone(),
688
+ # 'pc_feat':fused_point_feat.detach().clone(),
689
+ # })
690
+
691
+ def training_setup(self, training_args):
692
+ self.percent_dense = training_args.percent_dense
693
+ self.xyz_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")
694
+ self.denom = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")
695
+
696
+ self.param_lr = {
697
+ "xyz": training_args.position_lr_init * self.spatial_lr_scale,
698
+ "f_dc": training_args.feature_lr,
699
+ "f_rest": training_args.feature_sh_lr,
700
+ "opacity": training_args.opacity_lr,
701
+ "scaling": training_args.scaling_lr,
702
+ "rotation": training_args.rotation_lr
703
+ }
704
+
705
+ warm_start_lr = 0.01
706
+ l = []
707
+ for key in self.gs_params_group.get('head', []):
708
+ l.append({
709
+ 'params': getattr(self, f'head_{key}').parameters(),
710
+ 'lr': warm_start_lr,
711
+ 'name': key
712
+ })
713
+
714
+ for key in self.gs_params_group.get('opt', []):
715
+ l.append({
716
+ 'params': [getattr(self, self.param_key[key])],
717
+ 'lr': warm_start_lr,
718
+ 'name': key
719
+ })
720
+
721
+ # ## FOR DEBUGGING
722
+ # l += [
723
+ # {'params': self.head_f_dc.parameters(), 'lr': warm_start_lr, "name": "warm_start_f_dc"},
724
+ # {'params': self.head_f_rest.parameters(), 'lr': warm_start_lr, "name": "warm_start_f_rest"},
725
+ # ]
726
+
727
+ # l = [
728
+ # {'params': self.head_xyz.parameters(), 'lr': warm_start_lr, "name": "xyz"},
729
+ # # {'params': [self._xyz], 'lr': warm_start_lr, "name": "xyz"},
730
+ # {'params': self.head_scaling.parameters(), 'lr': warm_start_lr, "name": "scaling"},
731
+ # # {'params': [self._scaling], 'lr': warm_start_lr, "name": "scaling"},
732
+ # {'params': self.head_rotation.parameters(), 'lr': warm_start_lr, "name": "rotation"},
733
+ # # {'params': [self._rotation], 'lr': warm_start_lr, "name": "rotation"},
734
+ # {'params': self.head_opacity.parameters(), 'lr': warm_start_lr, "name": "opacity"},
735
+ # # {'params': [self._opacity], 'lr': warm_start_lr, "name": "opacity"},
736
+ # # {'params': self.head_f_dc.parameters(), 'lr': warm_start_lr, "name": "f_dc"},
737
+ # {'params': [self._features_dc], 'lr': warm_start_lr, "name": "f_dc"},
738
+ # # {'params': self.head_f_rest.parameters(), 'lr': warm_start_lr, "name": "f_rest"},
739
+ # {'params': [self._features_rest], 'lr': warm_start_lr, "name": "f_rest"},
740
+ # # {'params': [self.pc_feat], 'lr': warm_start_lr, "name": "feat"},
741
+ # ]
742
+
743
+ l_cam = [{'params': [self.P],'lr': training_args.pose_lr_init, "name": "pose"},]
744
+
745
+ l += l_cam
746
+
747
+ self.optimizer = torch.optim.Adam(l, lr=0.0, eps=1e-15)
748
+ self.xyz_scheduler_args = get_expon_lr_func(lr_init=training_args.position_lr_init * self.spatial_lr_scale,
749
+ lr_final=training_args.position_lr_final * self.spatial_lr_scale,
750
+ lr_delay_mult=training_args.position_lr_delay_mult,
751
+ max_steps=training_args.position_lr_max_steps)
752
+ self.cam_scheduler_args = get_expon_lr_func(lr_init=training_args.pose_lr_init,
753
+ lr_final=training_args.pose_lr_final,
754
+ lr_delay_mult=training_args.position_lr_delay_mult,
755
+ max_steps=1000)
756
+
757
+ self.warm_start_scheduler_args = get_expon_lr_func(lr_init=warm_start_lr,
758
+ lr_final=warm_start_lr*0.01,
759
+ max_steps=1000)
760
+
761
+ def setup_rendering_learning_rate(self, ):
762
+ ''' Setup learning rate scheduling'''
763
+ for param_group in self.optimizer.param_groups:
764
+ if param_group["name"] in self.param_lr:
765
+ param_group['lr'] = self.param_lr[param_group["name"]]
766
+ # elif param_group["name"] == "feat":
767
+ # param_group['lr'] = 1e-6
768
+
769
+ def update_warm_start_learning_rate(self, iteration):
770
+ ''' Warm start learning rate scheduling per step '''
771
+ for param_group in self.optimizer.param_groups:
772
+ lr = self.warm_start_scheduler_args(iteration)
773
+ param_group['lr'] = lr
774
+
775
+ def update_learning_rate(self, iteration):
776
+ ''' Learning rate scheduling per step '''
777
+ for param_group in self.optimizer.param_groups:
778
+ if param_group["name"] == "pose":
779
+ lr = self.cam_scheduler_args(iteration)
780
+ param_group['lr'] = lr
781
+ if param_group["name"] == "xyz":
782
+ lr = self.xyz_scheduler_args(iteration)
783
+ param_group['lr'] = lr
784
+ # return lr
785
+
786
+ class conditionalWarp(torch.nn.Module):
787
+ def __init__(self, layers, skip, skip_dim=None, res=[], freq=None, zero_init=False):
788
+ super().__init__()
789
+ self.skip = skip
790
+ self.res = res
791
+ self.freq = freq
792
+ self.mlp_warp = torch.nn.ModuleList()
793
+ L = self.get_layer_dims(layers)
794
+ for li,(k_in,k_out) in enumerate(L):
795
+ if li in self.skip: k_in += layers[-1] if skip_dim is None else skip_dim
796
+ linear = torch.nn.Linear(k_in,k_out)
797
+
798
+ # Init network output as 0
799
+ if zero_init:
800
+ if li == (len(L) - 1):
801
+ torch.nn.init.constant_(linear.weight, 0)
802
+ torch.nn.init.constant_(linear.bias, 0)
803
+
804
+ self.mlp_warp.append(linear)
805
+
806
+ def get_layer_dims(self, layers):
807
+ # return a list of tuples (k_in,k_out)
808
+ return list(zip(layers[:-1],layers[1:]))
809
+
810
+ def positional_encoding(self, input): # [B,...,N]
811
+ shape = input.shape
812
+ freq = 2**torch.arange(self.freq, dtype=torch.float32,device=input.device)*np.pi # [L]
813
+ spectrum = input[...,None]*freq # [B,...,N,L]
814
+ sin,cos = spectrum.sin(),spectrum.cos() # [B,...,N,L]
815
+ input_enc = torch.stack([sin,cos],dim=-2) # [B,...,N,2,L]
816
+ input_enc = input_enc.view(*shape[:-1],-1) # [B,...,2NL]
817
+ return input_enc
818
+
819
+ def forward(self, feat_in, color):
820
+ if self.freq != None:
821
+ feat_in = torch.cat([feat_in, self.positional_encoding(feat_in)], dim=-1)
822
+ feat = feat_in
823
+ for li,layer in enumerate(self.mlp_warp):
824
+ if li in self.skip: feat = torch.cat([feat, color],dim=-1)
825
+ if li in self.res: feat = feat + feat_in
826
+ feat = layer(feat)
827
+ if li!=len(self.mlp_warp)-1:
828
+ feat = nn.functional.relu(feat)
829
+ warp = feat
830
+ return warp
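Each head_* built in Feat2GaussianModel is a conditionalWarp MLP (net_width equal to feat_dim) that maps the N x feat_dim point-feature matrix to one Gaussian attribute; the attribute's initial value is passed as the second argument but is only concatenated into the MLP when skip layers are configured, and the heads above use skip=[]. A minimal shape-check sketch (illustrative only, not a file in this commit), assuming a CUDA build of the repo's dependencies and 1024-dim features as in the 'mast3r' entry of feat_default_dim:

import torch
from scene.gaussian_model import conditionalWarp

N, feat_dim = 4096, 1024                 # hypothetical point count, MASt3R feature width
feats = torch.randn(N, feat_dim).cuda()  # per-point features, normally read from the .ply
xyz_init = torch.randn(N, 3).cuda()      # initial positions from the point cloud

head_xyz = conditionalWarp(layers=[feat_dim, feat_dim, 3], skip=[]).cuda()
xyz_pred = head_xyz(feats, xyz_init)     # two Linear layers with a ReLU in between
print(xyz_pred.shape)                    # torch.Size([4096, 3])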
train_feat2gs.py ADDED
@@ -0,0 +1,243 @@
1
+ #
2
+ # Copyright (C) 2023, Inria
3
+ # GRAPHDECO research group, https://team.inria.fr/graphdeco
4
+ # All rights reserved.
5
+ #
6
+ # This software is free for non-commercial, research and evaluation use
7
+ # under the terms of the LICENSE.md file.
8
+ #
9
+ # For inquiries contact [email protected]
10
+ #
11
+
12
+ import os
13
+ import numpy as np
14
+ import torch
15
+ from random import randint
16
+ from utils.loss_utils import l1_loss, ssim
17
+ from gaussian_renderer import render_gsplat
18
+ import sys
19
+ from scene import Scene, Feat2GaussianModel
20
+ from argparse import ArgumentParser
21
+ from arguments import ModelParams, PipelineParams, OptimizationParams
22
+ from utils.pose_utils import get_camera_from_tensor
23
+ from tqdm import tqdm
24
+
25
+ from time import perf_counter
26
+
27
+ def save_pose(path, quat_pose, train_cams, llffhold=2):
28
+ output_poses=[]
29
+ index_colmap = [cam.colmap_id for cam in train_cams]
30
+ for quat_t in quat_pose:
31
+ w2c = get_camera_from_tensor(quat_t)
32
+ output_poses.append(w2c)
33
+ colmap_poses = []
34
+ for i in range(len(index_colmap)):
35
+ ind = index_colmap.index(i+1)
36
+ bb=output_poses[ind]
37
+ bb = bb#.inverse()
38
+ colmap_poses.append(bb)
39
+ colmap_poses = torch.stack(colmap_poses).detach().cpu().numpy()
40
+ np.save(path, colmap_poses)
41
+
42
+
43
+ def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoint_iterations, checkpoint, debug_from, args):
44
+ first_iter = 0
45
+ # tb_writer = prepare_output_and_logger(dataset, opt.iterations)
46
+ feat_type = '-'.join(args.feat_type)
47
+ feat_dim = args.feat_dim if feat_type not in ['iuv', 'iuvrgb'] else dataset.feat_default_dim[feat_type]
48
+ gs_params_group = dataset.gs_params_group[args.model]
49
+ gaussians = Feat2GaussianModel(dataset.sh_degree, feat_dim, gs_params_group)
50
+ scene = Scene(dataset, gaussians, opt=args, shuffle=True)
51
+ gaussians.training_setup(opt)
52
+ # if checkpoint:
53
+ # (model_params, first_iter) = torch.load(checkpoint)
54
+ # gaussians.restore(model_params, opt)
55
+ train_cams_init = scene.getTrainCameras().copy()
56
+ os.makedirs(scene.model_path + 'pose', exist_ok=True)
57
+ # save_pose(scene.model_path + 'pose' + "/pose_org.npy", gaussians.P, train_cams_init)
58
+ bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0]
59
+ background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
60
+
61
+ iter_start = torch.cuda.Event(enable_timing = True)
62
+ iter_end = torch.cuda.Event(enable_timing = True)
63
+
64
+ viewpoint_stack = None
65
+ ema_loss_for_log = 0.0
66
+ progress_bar = tqdm(range(first_iter, opt.iterations), desc="Training progress")
67
+ first_iter += 1
68
+
69
+ warm_iter = 1000
70
+
71
+ start = perf_counter()
72
+ for iteration in range(first_iter, opt.iterations + 1):
73
+ # if network_gui.conn == None:
74
+ # network_gui.try_connect()
75
+ # while network_gui.conn != None:
76
+ # try:
77
+ # net_image_bytes = None
78
+ # custom_cam, do_training, pipe.convert_SHs_python, pipe.compute_cov3D_python, keep_alive, scaling_modifer = network_gui.receive()
79
+ # if custom_cam != None:
80
+ # net_image = render(custom_cam, gaussians, pipe, background, scaling_modifer)["render"]
81
+ # net_image_bytes = memoryview((torch.clamp(net_image, min=0, max=1.0) * 255).byte().permute(1, 2, 0).contiguous().cpu().numpy())
82
+ # network_gui.send(net_image_bytes, dataset.source_path)
83
+ # if do_training and ((iteration < int(opt.iterations)) or not keep_alive):
84
+ # break
85
+ # except Exception as e:
86
+ # network_gui.conn = None
87
+
88
+ iter_start.record()
89
+
90
+ if iteration > warm_iter:
91
+ if iteration == warm_iter+1:
92
+ gaussians.pc_feat.requires_grad_(False)
93
+ gaussians.setup_rendering_learning_rate()
94
+ gaussians.update_learning_rate(iteration - warm_iter)
95
+ else:
96
+ gaussians.update_warm_start_learning_rate(iteration)
97
+
98
+ if args.optim_pose==False:
99
+ gaussians.P.requires_grad_(False)
100
+
101
+ # (DISABLED) Every 1000 its we increase the levels of SH up to a maximum degree
102
+ # if iteration % 1000 == 0:
103
+ # gaussians.oneupSHdegree()
104
+
105
+ # Pick a random Camera
106
+ if not viewpoint_stack:
107
+ viewpoint_stack = scene.getTrainCameras().copy()
108
+ viewpoint_cam = viewpoint_stack.pop(randint(0, len(viewpoint_stack)-1))
109
+ pose = gaussians.get_RT(viewpoint_cam.uid)
110
+
111
+ # Render
112
+ if (iteration - 1) == debug_from:
113
+ pipe.debug = True
114
+
115
+ bg = torch.rand((3), device="cuda") if opt.random_background else background
116
+
117
+ gaussians.inference()
118
+
119
+ pretrained_loss_dict = {
120
+ 'xyz': l1_loss(gaussians._xyz, gaussians.param_init['xyz']),
121
+ # 'f_dc': l1_loss(gaussians._features_dc, gaussians.param_init['f_dc']),
122
+ # 'f_rest': l1_loss(gaussians._features_rest, gaussians.param_init['f_rest']),
123
+ 'opacity': l1_loss(gaussians._opacity, gaussians.param_init['opacity']),
124
+ 'scaling': l1_loss(gaussians._scaling, gaussians.param_init['scaling']),
125
+ 'rotation': l1_loss(gaussians._rotation, gaussians.param_init['rotation']),
126
+ # 'pose': l1_loss(gaussians.P, gaussians.param_init['pose']),
127
+ # 'focal': l1_loss(gaussians._focal_params, gaussians.param_init['focal']),
128
+ # 'pc_feat':l1_loss(gaussians.pc_feat, gaussians.param_init['pc_feat']),
129
+ }
130
+
131
+ if iteration <= warm_iter:
132
+ loss = sum(loss for key, loss in pretrained_loss_dict.items() if key in gs_params_group['head'])
133
+ Ll1 = torch.tensor(0)
134
+
135
+ if iteration > warm_iter:
136
+ render_pkg = render_gsplat(viewpoint_cam, gaussians, pipe, bg, camera_pose=pose)
137
+ image, viewspace_point_tensor, visibility_filter, radii = render_pkg["render"], render_pkg["viewspace_points"], render_pkg["visibility_filter"], render_pkg["radii"]
138
+
139
+ # Loss
140
+ gt_image = viewpoint_cam.original_image.cuda()
141
+ Ll1 = l1_loss(image, gt_image)
142
+ loss = (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * (1.0 - ssim(image, gt_image))
143
+
144
+ # if feat_type in ['iuv', 'iuvrgb']:
145
+ # # Add scaling regularization for 'iuv' and 'iuvrgb' features
146
+ # # Prevents their gaussians scale from becoming too large to cause CUDA out of memory
147
+ # loss += l1_loss(gaussians._scaling, gaussians.param_init['scaling']) * 0.1
148
+
149
+ loss.backward()
150
+ iter_end.record()
151
+
152
+ with torch.no_grad():
153
+
154
+ # Progress bar
155
+ ema_loss_for_log = 0.4 * loss.item() + 0.6 * ema_loss_for_log
156
+ if iteration % 10 == 0:
157
+ progress_bar.set_postfix({"Loss": f"{ema_loss_for_log:.{7}f}"})
158
+ progress_bar.update(10)
159
+ if iteration == opt.iterations:
160
+ progress_bar.close()
161
+
162
+ # Log and save
163
+ # training_report(tb_writer, iteration, Ll1, loss, l1_loss, iter_start.elapsed_time(iter_end), testing_iterations, scene, render_gsplat, (pipe, background), pretrained_loss_dict)
164
+ if (iteration in saving_iterations):
165
+ print("\n[ITER {}] Saving Gaussians".format(iteration))
166
+ scene.save(iteration)
167
+ save_pose(scene.model_path + 'pose' + f"/pose_{iteration}.npy", gaussians.P, train_cams_init)
168
+
169
+ # (DISABLED) Densification
170
+ # if iteration < opt.densify_until_iter:
171
+ # Keep track of max radii in image-space for pruning
172
+ # gaussians.max_radii2D[visibility_filter] = torch.max(gaussians.max_radii2D[visibility_filter], radii[visibility_filter])
173
+ # gaussians.add_densification_stats(viewspace_point_tensor, visibility_filter)
174
+
175
+ # if iteration > opt.densify_from_iter and iteration % opt.densification_interval == 0:
176
+ # size_threshold = 20 if iteration > opt.opacity_reset_interval else None
177
+ # gaussians.densify_and_prune(opt.densify_grad_threshold, 0.005, scene.cameras_extent, size_threshold)
178
+
179
+ # if iteration % opt.opacity_reset_interval == 0 or (dataset.white_background and iteration == opt.densify_from_iter):
180
+ # gaussians.reset_opacity()
181
+
182
+ # Optimizer step
183
+ if iteration < opt.iterations:
184
+ gaussians.optimizer.step()
185
+ gaussians.optimizer.zero_grad(set_to_none = True)
186
+
187
+ # if (iteration in checkpoint_iterations):
188
+ # print("\n[ITER {}] Saving Checkpoint".format(iteration))
189
+ # torch.save((gaussians.capture(), iteration), scene.model_path + "/chkpnt" + str(iteration) + ".pth")
190
+
191
+ end = perf_counter()
192
+ train_time = end - start
193
+
194
+ # Log & save operations above are commented out when measuring pure training time.
195
+ # train_time = np.array(train_time)
196
+ # print("total_test_time_epoch: ", 1)
197
+ # print("train_time_mean: ", train_time.mean())
198
+ # print("train_time_median: ", np.median(train_time))
199
+
200
+
201
+ if __name__ == "__main__":
202
+ # Set up command line argument parser
203
+ parser = ArgumentParser(description="Training script parameters")
204
+ lp = ModelParams(parser)
205
+ op = OptimizationParams(parser)
206
+ pp = PipelineParams(parser)
207
+ parser.add_argument('--ip', type=str, default="127.0.0.1")
208
+ parser.add_argument('--port', type=int, default=6009)
209
+ parser.add_argument('--debug_from', type=int, default=-1)
210
+ parser.add_argument('--detect_anomaly', action='store_true', default=False)
211
+ parser.add_argument("--test_iterations", nargs="+", type=int,
212
+ default=[500, 800, 1000, 1500, 2000, 3000, 4000, 5000, 6000, 7_000, \
213
+ 8_000, 9_000, 10_000, 11_000, 12_000, 13_000, 14_000, 30_000])
214
+ parser.add_argument("--save_iterations", nargs="+", type=int, default=[])
215
+ parser.add_argument("--quiet", action="store_true")
216
+ parser.add_argument("--checkpoint_iterations", nargs="+", type=int, default=[])
217
+ parser.add_argument("--start_checkpoint", type=str, default = None)
218
+ parser.add_argument("--scene", type=str, default=None)
219
+ parser.add_argument("--n_views", type=int, default=None)
220
+ parser.add_argument("--get_video", action="store_true")
221
+ parser.add_argument("--optim_pose", action="store_true")
222
+ parser.add_argument("--feat_type", type=str, nargs='*', default=None, help="Feature type(s). Multiple types can be specified for combination.")
223
+ parser.add_argument("--method", type=str, default='dust3r', help="Method of Initialization, e.g., 'dust3r' or 'mast3r'")
224
+ parser.add_argument("--feat_dim", type=int, default=None, help="Feature dimension after PCA. If None, PCA is not applied.")
225
+ parser.add_argument("--model", type=str, default='G', help="Feat2GS model, 'G'='geometry'/'T'='texture'/'A'='all'")
226
+
227
+ args = parser.parse_args(sys.argv[1:])
228
+ args.save_iterations.append(args.iterations)
229
+
230
+ os.makedirs(args.model_path, exist_ok=True)
231
+
232
+ print("Optimizing " + args.model_path)
233
+
234
+ # Initialize system state (RNG)
235
+ # safe_state(args.quiet)
236
+
237
+ # Start GUI server, configure and run training
238
+ # network_gui.init(args.ip, args.port)
239
+ torch.autograd.set_detect_anomaly(args.detect_anomaly)
240
+ training(lp.extract(args), op.extract(args), pp.extract(args), args.test_iterations, args.save_iterations, args.checkpoint_iterations, args.start_checkpoint, args.debug_from, args)
241
+
242
+ # All done
243
+ print("\nTraining complete.")
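Note on the schedule and outputs: for the first `warm_iter = 1000` iterations the loss is only the L1 distance between the predicted Gaussian parameters and their initialization (`pretrained_loss_dict`, restricted to the active parameter group); the photometric L1 + D-SSIM loss takes over afterwards. At every iteration listed in `--save_iterations`, `save_pose` writes the current training-view camera matrices to `<model_path>pose/pose_<iter>.npy`. A minimal sketch for reading such a file back (the path is hypothetical):

```python
import numpy as np

# save_pose() stacks one camera matrix per training view, ordered by COLMAP id,
# as produced by get_camera_from_tensor.
poses = np.load("output/scene/pose/pose_1000.npy")
print(poses.shape)  # one matrix per training view
```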
utils/camera_traj_config.py ADDED
@@ -0,0 +1,655 @@
1
+ trajectory_configs = {
2
+ 'Infer': {
3
+ 'cy': {
4
+ 'up': [-1, 1],
5
+ 'arc': {
6
+ 'degree': 180.0, # Default arc degree
7
+ },
8
+ 'spiral': {
9
+ 'zrate': 0.5, # Spiral rise rate
10
+ 'rots': 1, # Number of rotations
11
+ },
12
+ 'lemniscate': {
13
+ 'degree': 45.0, # Lemniscate curve angle
14
+ },
15
+ 'wander': {
16
+ 'max_disp': 48.0, # Maximum displacement
17
+ }
18
+ },
19
+ 'paper4': {
20
+ 'up': [-1, 1],
21
+ 'arc': {
22
+ 'degree': 200.0, # Default arc degree
23
+ },
24
+ 'spiral': {
25
+ 'zrate': 0.5, # Spiral rise rate
26
+ 'rots': 1, # Number of rotations
27
+ },
28
+ 'lemniscate': {
29
+ 'degree': 80.0, # Lemniscate curve angle
30
+ },
31
+ 'wander': {
32
+ 'max_disp': 100.0, # Maximum displacement
33
+ }
34
+ },
35
+ 'cy3': {
36
+ 'up': [-1, 1],
37
+ 'arc': {
38
+ 'degree': 180.0, # Default arc degree
39
+ },
40
+ 'spiral': {
41
+ 'zrate': 0.5, # Spiral rise rate
42
+ 'rots': 1, # Number of rotations
43
+ },
44
+ 'lemniscate': {
45
+ 'degree': 45.0, # Lemniscate curve angle
46
+ },
47
+ 'wander': {
48
+ 'max_disp': 48.0, # Maximum displacement
49
+ }
50
+ },
51
+ 'cy4': {
52
+ 'up': [-1, 1],
53
+ 'arc': {
54
+ 'degree': 180.0, # Default arc degree
55
+ },
56
+ 'spiral': {
57
+ 'zrate': 0.5, # Spiral rise rate
58
+ 'rots': 1, # Number of rotations
59
+ },
60
+ 'lemniscate': {
61
+ 'degree': 45.0, # Lemniscate curve angle
62
+ },
63
+ 'wander': {
64
+ 'max_disp': 48.0, # Maximum displacement
65
+ }
66
+ },
67
+ 'coffee': {
68
+ 'up': [-1, 0],
69
+ 'arc': {
70
+ 'degree': 180.0, # Default arc degree
71
+ },
72
+ 'spiral': {
73
+ 'zrate': 0.5, # Spiral rise rate
74
+ 'rots': 1, # Number of rotations
75
+ },
76
+ 'lemniscate': {
77
+ 'degree': 80.0, # Lemniscate curve angle
78
+ },
79
+ 'wander': {
80
+ 'max_disp': 48.0, # Maximum displacement
81
+ }
82
+ },
83
+
84
+ 'plant': {
85
+ 'up': [-1, 1],
86
+ 'arc': {
87
+ 'degree': 180.0, # Default arc degree
88
+ },
89
+ 'spiral': {
90
+ 'zrate': 0.5, # Spiral rise rate
91
+ 'rots': 1, # Number of rotations
92
+ },
93
+ 'lemniscate': {
94
+ 'degree': 80.0, # Lemniscate curve angle
95
+ },
96
+ 'wander': {
97
+ 'max_disp': 48.0, # Maximum displacement
98
+ }
99
+ },
100
+
101
+ 'desk': {
102
+ 'up': [-1, 1],
103
+ 'arc': {
104
+ 'degree': 180.0, # Default arc degree
105
+ },
106
+ },
107
+
108
+ 'bread': {
109
+ 'up': [-1, 0],
110
+ 'arc': {
111
+ 'degree': 180.0, # Default arc degree
112
+ },
113
+ 'spiral': {
114
+ 'zrate': 0.5, # Spiral rise rate
115
+ 'rots': 1, # Number of rotations
116
+ },
117
+ 'lemniscate': {
118
+ 'degree': 80.0, # Lemniscate curve angle
119
+ },
120
+ 'wander': {
121
+ 'max_disp': 48.0, # Maximum displacement
122
+ }
123
+ },
124
+
125
+ 'brunch': {
126
+ 'up': [1, 1],
127
+ 'arc': {
128
+ 'degree': 180.0, # Default arc degree
129
+ },
130
+ 'spiral': {
131
+ 'zrate': 0.5, # Spiral rise rate
132
+ 'rots': 1, # Number of rotations
133
+ },
134
+ 'lemniscate': {
135
+ 'degree': 80.0, # Lemniscate curve angle
136
+ },
137
+ 'wander': {
138
+ 'max_disp': 48.0, # Maximum displacement
139
+ }
140
+ },
141
+
142
+ 'stuff': {
143
+ 'up': [-1, 0],
144
+ 'arc': {
145
+ 'degree': 180.0, # Default arc degree
146
+ },
147
+ 'spiral': {
148
+ 'zrate': 0.5, # Spiral rise rate
149
+ 'rots': 1, # Number of rotations
150
+ },
151
+ 'lemniscate': {
152
+ 'degree': 80.0, # Lemniscate curve angle
153
+ },
154
+ 'wander': {
155
+ 'max_disp': 48.0, # Maximum displacement
156
+ }
157
+ },
158
+
159
+ 'xbox': {
160
+ 'up': [-1, 1],
161
+ 'arc': {
162
+ 'degree': 180.0, # Default arc degree
163
+ },
164
+ 'spiral': {
165
+ 'zrate': 0.5, # Spiral rise rate
166
+ 'rots': 1, # Number of rotations
167
+ },
168
+ 'lemniscate': {
169
+ 'degree': 80.0, # Lemniscate curve angle
170
+ },
171
+ 'wander': {
172
+ 'max_disp': 48.0, # Maximum displacement
173
+ }
174
+ },
175
+
176
+ 'plushies': {
177
+ 'up': [-1, 1],
178
+ 'arc': {
179
+ 'degree': 120.0, # Default arc degree
180
+ },
181
+ 'spiral': {
182
+ 'zrate': 0.5, # Spiral rise rate
183
+ 'rots': 1, # Number of rotations
184
+ },
185
+ 'lemniscate': {
186
+ 'degree': 60.0, # Lemniscate curve angle
187
+ },
188
+ 'wander': {
189
+ 'max_disp': 48.0, # Maximum displacement
190
+ }
191
+ },
192
+
193
+ 'erhai': {
194
+ 'up': [-1, 1],
195
+ 'arc': {
196
+ 'degree': 180.0, # Default arc degree
197
+ },
198
+ 'lemniscate': {
199
+ 'degree': 80.0, # Lemniscate curve angle
200
+ },
201
+ 'wander': {
202
+ 'max_disp': 200.0, # Maximum displacement
203
+ }
204
+ },
205
+ 'cy_crop1': {
206
+ 'up': [-1, 0],
207
+ 'lemniscate': {
208
+ 'degree': 80.0, # Lemniscate curve angle
209
+ },
210
+ 'wander': {
211
+ 'max_disp': 60.0, # Maximum displacement
212
+ }
213
+ },
214
+
215
+ 'cy_crop': {
216
+ 'up': [-1, 1],
217
+ 'lemniscate': {
218
+ 'degree': 60.0, # Lemniscate curve angle
219
+ },
220
+ 'wander': {
221
+ 'max_disp': 48.0, # Maximum displacement
222
+ }
223
+ },
224
+
225
+ 'paper': {
226
+ 'up': [-1, 1],
227
+ 'arc': {
228
+ 'degree': 240.0, # Default arc degree
229
+ },
230
+ 'spiral': {
231
+ 'zrate': 0.5, # Spiral rise rate
232
+ 'rots': 1, # Number of rotations
233
+ },
234
+ 'lemniscate': {
235
+ 'degree': 60.0, # Lemniscate curve angle
236
+ },
237
+ 'wander': {
238
+ 'max_disp': 48.0, # Maximum displacement
239
+ }
240
+ },
241
+
242
+ 'house': {
243
+ 'up': [-1, 1],
244
+ 'arc': {
245
+ 'degree': 240.0, # Default arc degree
246
+ },
247
+ 'spiral': {
248
+ 'zrate': 0.5, # Spiral rise rate
249
+ 'rots': 1, # Number of rotations
250
+ },
251
+ 'lemniscate': {
252
+ 'degree': 60.0, # Lemniscate curve angle
253
+ },
254
+ },
255
+
256
+ 'home': {
257
+ 'up': [1, 0],
258
+ 'arc': {
259
+ 'degree': 90.0, # Default arc degree
260
+ },
261
+ 'spiral': {
262
+ 'zrate': 0.5, # Spiral rise rate
263
+ 'rots': 1, # Number of rotations
264
+ },
265
+ 'lemniscate': {
266
+ 'degree': 80.0, # Lemniscate curve angle
267
+ }
268
+ },
269
+
270
+ 'paper3': {
271
+ 'up': [1, 0],
272
+ 'arc': {
273
+ 'degree': 180.0, # Default arc degree
274
+ },
275
+ 'spiral': {
276
+ 'zrate': 0.5, # Spiral rise rate
277
+ 'rots': 1, # Number of rotations
278
+ },
279
+ 'lemniscate': {
280
+ 'degree': 80.0, # Lemniscate curve angle
281
+ },
282
+ 'wander': {
283
+ 'max_disp': 48.0, # Maximum displacement
284
+ }
285
+ },
286
+
287
+ 'castle': {
288
+ 'up': [-1, 1],
289
+ 'spiral': {
290
+ 'zrate': 2, # Spiral rise rate
291
+ 'rots': 2, # Number of rotations
292
+ },
293
+ },
294
+
295
+ 'hogwarts': {
296
+ 'up': [-1, 1],
297
+ 'spiral': {
298
+ 'zrate': 0.5, # Spiral rise rate
299
+ 'rots': 1, # Number of rotations
300
+ },
301
+ 'wander': {
302
+ 'max_disp': 100.0, # Maximum displacement
303
+ }
304
+ },
305
+ },
306
+ 'Tanks': {
307
+ 'Auditorium': {
308
+ 'up': [-1, 1],
309
+ 'arc': {
310
+ 'degree': 180.0, # Default arc degree
311
+ },
312
+ 'spiral': {
313
+ 'zrate': 0.5, # Spiral rise rate
314
+ 'rots': 1, # Number of rotations
315
+ },
316
+ 'lemniscate': {
317
+ 'degree': 30.0, # Lemniscate curve angle
318
+ },
319
+ 'wander': {
320
+ 'max_disp': 80.0, # Maximum displacement
321
+ }
322
+ },
323
+ 'Caterpillar': {
324
+ 'up': [-1, 1],
325
+ 'arc': {
326
+ 'degree': 240.0, # Default arc degree
327
+ },
328
+ 'spiral': {
329
+ 'zrate': 0.5, # Spiral rise rate
330
+ 'rots': 1, # Number of rotations
331
+ },
332
+ 'lemniscate': {
333
+ 'degree': 60.0, # Lemniscate curve angle
334
+ },
335
+ 'wander': {
336
+ 'max_disp': 48.0, # Maximum displacement
337
+ }
338
+ },
339
+ 'Family': {
340
+ 'up': [-1, 1],
341
+ 'arc': {
342
+ 'degree': 180.0, # Default arc degree
343
+ },
344
+ 'spiral': {
345
+ 'zrate': 0.5, # Spiral rise rate
346
+ 'rots': 1, # Number of rotations
347
+ },
348
+ 'lemniscate': {
349
+ 'degree': 60.0, # Lemniscate curve angle
350
+ },
351
+ 'wander': {
352
+ 'max_disp': 48.0, # Maximum displacement
353
+ }
354
+ },
355
+ 'Ignatius': {
356
+ 'up': [-1, 1],
357
+ 'arc': {
358
+ 'degree': 330.0, # Default arc degree
359
+ },
360
+ 'spiral': {
361
+ 'zrate': 0.5, # Spiral rise rate
362
+ 'rots': 1, # Number of rotations
363
+ },
364
+ 'lemniscate': {
365
+ 'degree': 80.0, # Lemniscate curve angle
366
+ },
367
+ 'wander': {
368
+ 'max_disp': 48.0, # Maximum displacement
369
+ }
370
+ },
371
+ 'Train': {
372
+ 'up': [-1, 1],
373
+ 'arc': {
374
+ 'degree': 180.0, # Default arc degree
375
+ },
376
+ 'spiral': {
377
+ 'zrate': 0.5, # Spiral rise rate
378
+ 'rots': 1, # Number of rotations
379
+ },
380
+ 'lemniscate': {
381
+ 'degree': 60.0, # Lemniscate curve angle
382
+ },
383
+ 'wander': {
384
+ 'max_disp': 48.0, # Maximum displacement
385
+ }
386
+ },
387
+
388
+ },
389
+ 'DL3DV': {
390
+ 'Center': {
391
+ 'up': [-1, 1],
392
+ 'arc': {
393
+ 'degree': 180.0, # Default arc degree
394
+ },
395
+ 'spiral': {
396
+ 'zrate': 0.5, # Spiral rise rate
397
+ 'rots': 1, # Number of rotations
398
+ },
399
+ 'lemniscate': {
400
+ 'degree': 60.0, # Lemniscate curve angle
401
+ },
402
+ 'wander': {
403
+ 'max_disp': 48.0, # Maximum displacement
404
+ }
405
+ },
406
+ 'Electrical': {
407
+ 'up': [-1, 1],
408
+ 'arc': {
409
+ 'degree': 180.0, # Default arc degree
410
+ },
411
+ 'spiral': {
412
+ 'zrate': 0.5, # Spiral rise rate
413
+ 'rots': 1, # Number of rotations
414
+ },
415
+ 'lemniscate': {
416
+ 'degree': 60.0, # Lemniscate curve angle
417
+ },
418
+ 'wander': {
419
+ 'max_disp': 48.0, # Maximum displacement
420
+ }
421
+ },
422
+ 'Museum': {
423
+ 'up': [-1, 1],
424
+ 'arc': {
425
+ 'degree': 180.0, # Default arc degree
426
+ },
427
+ 'spiral': {
428
+ 'zrate': 0.5, # Spiral rise rate
429
+ 'rots': 1, # Number of rotations
430
+ },
431
+ 'lemniscate': {
432
+ 'degree': 60.0, # Lemniscate curve angle
433
+ },
434
+ 'wander': {
435
+ 'max_disp': 48.0, # Maximum displacement
436
+ }
437
+ },
438
+ 'Supermarket2': {
439
+ 'up': [-1, 1],
440
+ 'arc': {
441
+ 'degree': 180.0, # Default arc degree
442
+ },
443
+ 'spiral': {
444
+ 'zrate': 0.5, # Spiral rise rate
445
+ 'rots': 1, # Number of rotations
446
+ },
447
+ 'lemniscate': {
448
+ 'degree': 60.0, # Lemniscate curve angle
449
+ },
450
+ 'wander': {
451
+ 'max_disp': 48.0, # Maximum displacement
452
+ }
453
+ },
454
+ 'Temple': {
455
+ 'up': [-1, 1],
456
+ 'arc': {
457
+ 'degree': 180.0, # Default arc degree
458
+ },
459
+ 'spiral': {
460
+ 'zrate': 0.5, # Spiral rise rate
461
+ 'rots': 1, # Number of rotations
462
+ },
463
+ 'lemniscate': {
464
+ 'degree': 60.0, # Lemniscate curve angle
465
+ },
466
+ 'wander': {
467
+ 'max_disp': 48.0, # Maximum displacement
468
+ }
469
+ },
470
+
471
+ },
472
+ 'MipNeRF360': {
473
+ 'garden': {
474
+ 'up': [-1, 1],
475
+ 'arc': {
476
+ 'degree': 270.0, # Default arc degree
477
+ },
478
+ 'spiral': {
479
+ 'zrate': 0.5, # Spiral rise rate
480
+ 'rots': 1, # Number of rotations
481
+ },
482
+ 'lemniscate': {
483
+ 'degree': 80.0, # Lemniscate curve angle
484
+ },
485
+ 'wander': {
486
+ 'max_disp': 48.0, # Maximum displacement
487
+ }
488
+ },
489
+ 'kitchen': {
490
+ 'up': [-1, 1],
491
+ 'arc': {
492
+ 'degree': 180.0, # Default arc degree
493
+ },
494
+ 'spiral': {
495
+ 'zrate': 0.5, # Spiral rise rate
496
+ 'rots': 1, # Number of rotations
497
+ },
498
+ 'lemniscate': {
499
+ 'degree': 80.0, # Lemniscate curve angle
500
+ },
501
+ 'wander': {
502
+ 'max_disp': 48.0, # Maximum displacement
503
+ }
504
+ },
505
+ 'room': {
506
+ 'up': [-1, 1],
507
+ 'arc': {
508
+ 'degree': 180.0, # Default arc degree
509
+ },
510
+ 'spiral': {
511
+ 'zrate': 0.5, # Spiral rise rate
512
+ 'rots': 1, # Number of rotations
513
+ },
514
+ 'lemniscate': {
515
+ 'degree': 60.0, # Lemniscate curve angle
516
+ },
517
+ 'wander': {
518
+ 'max_disp': 48.0, # Maximum displacement
519
+ }
520
+ },
521
+ },
522
+ 'MVimgNet': {
523
+ 'bench': {
524
+ 'up': [-1, 1],
525
+ 'arc': {
526
+ 'degree': 180.0, # Default arc degree
527
+ },
528
+ 'spiral': {
529
+ 'zrate': 0.5, # Spiral rise rate
530
+ 'rots': 1, # Number of rotations
531
+ },
532
+ 'lemniscate': {
533
+ 'degree': 80.0, # Lemniscate curve angle
534
+ },
535
+ 'wander': {
536
+ 'max_disp': 48.0, # Maximum displacement
537
+ }
538
+ },
539
+ 'car': {
540
+ 'up': [-1, 1],
541
+ 'arc': {
542
+ 'degree': 180.0, # Default arc degree
543
+ },
544
+ 'spiral': {
545
+ 'zrate': 0.5, # Spiral rise rate
546
+ 'rots': 1, # Number of rotations
547
+ },
548
+ 'lemniscate': {
549
+ 'degree': 60.0, # Lemniscate curve angle
550
+ },
551
+ 'wander': {
552
+ 'max_disp': 48.0, # Maximum displacement
553
+ }
554
+ },
555
+ 'suv': {
556
+ 'up': [-1, 1],
557
+ 'arc': {
558
+ 'degree': 180.0, # Default arc degree
559
+ },
560
+ 'spiral': {
561
+ 'zrate': 0.5, # Spiral rise rate
562
+ 'rots': 1, # Number of rotations
563
+ },
564
+ 'lemniscate': {
565
+ 'degree': 60.0, # Lemniscate curve angle
566
+ },
567
+ 'wander': {
568
+ 'max_disp': 48.0, # Maximum displacement
569
+ }
570
+ },
571
+
572
+ },
573
+ 'LLFF': {
574
+ 'fortress': {
575
+ 'up': [-1, 1],
576
+ 'arc': {
577
+ 'degree': 180.0, # Default arc degree
578
+ },
579
+ 'spiral': {
580
+ 'zrate': 0.5, # Spiral rise rate
581
+ 'rots': 1, # Number of rotations
582
+ },
583
+ 'lemniscate': {
584
+ 'degree': 30.0, # Lemniscate curve angle
585
+ },
586
+ 'wander': {
587
+ 'max_disp': 48.0, # Maximum displacement
588
+ }
589
+ },
590
+ 'horns': {
591
+ 'up': [-1, 1],
592
+ 'arc': {
593
+ 'degree': 180.0, # Default arc degree
594
+ },
595
+ 'spiral': {
596
+ 'zrate': 0.5, # Spiral rise rate
597
+ 'rots': 1, # Number of rotations
598
+ },
599
+ 'lemniscate': {
600
+ 'degree': 60.0, # Lemniscate curve angle
601
+ },
602
+ 'wander': {
603
+ 'max_disp': 48.0, # Maximum displacement
604
+ }
605
+ },
606
+ 'orchids': {
607
+ 'up': [-1, 1],
608
+ 'arc': {
609
+ 'degree': 180.0, # Default arc degree
610
+ },
611
+ 'spiral': {
612
+ 'zrate': 0.5, # Spiral rise rate
613
+ 'rots': 1, # Number of rotations
614
+ },
615
+ 'lemniscate': {
616
+ 'degree': 30.0, # Lemniscate curve angle
617
+ },
618
+ 'wander': {
619
+ 'max_disp': 48.0, # Maximum displacement
620
+ }
621
+ },
622
+ 'room': {
623
+ 'up': [-1, 1],
624
+ 'arc': {
625
+ 'degree': 180.0, # Default arc degree
626
+ },
627
+ 'spiral': {
628
+ 'zrate': 0.5, # Spiral rise rate
629
+ 'rots': 1, # Number of rotations
630
+ },
631
+ 'lemniscate': {
632
+ 'degree': 30.0, # Lemniscate curve angle
633
+ },
634
+ 'wander': {
635
+ 'max_disp': 48.0, # Maximum displacement
636
+ }
637
+ },
638
+ 'trex': {
639
+ 'up': [-1, 1],
640
+ 'arc': {
641
+ 'degree': 180.0, # Default arc degree
642
+ },
643
+ 'spiral': {
644
+ 'zrate': 1, # Spiral rise rate
645
+ 'rots': 1, # Number of rotations
646
+ },
647
+ 'lemniscate': {
648
+ 'degree': 30.0, # Lemniscate curve angle
649
+ },
650
+ 'wander': {
651
+ 'max_disp': 48.0, # Maximum displacement
652
+ }
653
+ },
654
+ },
655
+ }
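The dictionary is keyed by dataset and then by scene; each scene entry picks the `up` setting and optional per-trajectory parameters (`arc`, `spiral`, `lemniscate`, `wander`). A minimal lookup sketch (the fallback values are illustrative assumptions, not defaults defined in this file):

```python
from utils.camera_traj_config import trajectory_configs

scene_cfg = trajectory_configs['Tanks']['Family']
up = scene_cfg['up']                                          # e.g. [-1, 1]
arc_degree = scene_cfg.get('arc', {}).get('degree', 180.0)    # assumed fallback
spiral_cfg = scene_cfg.get('spiral', {})                      # {'zrate': 0.5, 'rots': 1}
print(up, arc_degree, spiral_cfg)
```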
utils/camera_utils.py ADDED
@@ -0,0 +1,481 @@
1
+ #
2
+ # Copyright (C) 2023, Inria
3
+ # GRAPHDECO research group, https://team.inria.fr/graphdeco
4
+ # All rights reserved.
5
+ #
6
+ # This software is free for non-commercial, research and evaluation use
7
+ # under the terms of the LICENSE.md file.
8
+ #
9
+ # For inquiries contact [email protected]
10
+ #
11
+
12
+ from scene.cameras import Camera
13
+ import numpy as np
14
+ from utils.general_utils import PILtoTorch
15
+ from utils.graphics_utils import fov2focal, getWorld2View2
16
+ import scipy
17
+ import matplotlib.pyplot as plt
18
+ from scipy.special import softmax
19
+ from typing import NamedTuple, List
20
+
21
+ WARNED = False
22
+
23
+ class CameraInfo(NamedTuple):
24
+ uid: int
25
+ R: np.array
26
+ T: np.array
27
+ FovY: np.array
28
+ FovX: np.array
29
+ image: np.array
30
+ image_path: str
31
+ image_name: str
32
+ width: int
33
+ height: int
34
+
35
+
36
+ def loadCam(args, id, cam_info, resolution_scale):
37
+ orig_w, orig_h = cam_info.image.size
38
+
39
+ if args.resolution in [1, 2, 4, 8]:
40
+ resolution = round(orig_w/(resolution_scale * args.resolution)), round(orig_h/(resolution_scale * args.resolution))
41
+ else: # should be a type that converts to float
42
+ if args.resolution == -1:
43
+ if orig_w > 1600:
44
+ global WARNED
45
+ if not WARNED:
46
+ print("[ INFO ] Encountered quite large input images (>1.6K pixels width), rescaling to 1.6K.\n "
47
+ "If this is not desired, please explicitly specify '--resolution/-r' as 1")
48
+ WARNED = True
49
+ global_down = orig_w / 1600
50
+ else:
51
+ global_down = 1
52
+ else:
53
+ global_down = orig_w / args.resolution
54
+
55
+ scale = float(global_down) * float(resolution_scale)
56
+ resolution = (int(orig_w / scale), int(orig_h / scale))
57
+
58
+ resized_image_rgb = PILtoTorch(cam_info.image, resolution)
59
+
60
+ gt_image = resized_image_rgb[:3, ...]
61
+ loaded_mask = None
62
+
63
+ if resized_image_rgb.shape[1] == 4:
64
+ loaded_mask = resized_image_rgb[3:4, ...]
65
+
66
+ return Camera(colmap_id=cam_info.uid, R=cam_info.R, T=cam_info.T,
67
+ FoVx=cam_info.FovX, FoVy=cam_info.FovY,
68
+ image=gt_image, gt_alpha_mask=loaded_mask,
69
+ image_name=cam_info.image_name, uid=id, data_device=args.data_device)
70
+
71
+
72
+ def cameraList_from_camInfos(cam_infos, resolution_scale, args):
73
+ camera_list = []
74
+
75
+ for id, c in enumerate(cam_infos):
76
+ camera_list.append(loadCam(args, id, c, resolution_scale))
77
+
78
+ return camera_list
79
+
80
+
81
+ def camera_to_JSON(id, camera : Camera):
82
+ Rt = np.zeros((4, 4))
83
+ Rt[:3, :3] = camera.R.transpose()
84
+ Rt[:3, 3] = camera.T
85
+ Rt[3, 3] = 1.0
86
+
87
+ W2C = np.linalg.inv(Rt)
88
+ pos = W2C[:3, 3]
89
+ rot = W2C[:3, :3]
90
+ serializable_array_2d = [x.tolist() for x in rot]
91
+ camera_entry = {
92
+ 'id' : id,
93
+ 'img_name' : camera.image_name,
94
+ 'width' : camera.width,
95
+ 'height' : camera.height,
96
+ 'position': pos.tolist(),
97
+ 'rotation': serializable_array_2d,
98
+ 'fy' : fov2focal(camera.FovY, camera.height),
99
+ 'fx' : fov2focal(camera.FovX, camera.width)
100
+ }
101
+ return camera_entry
102
+
103
+
104
+ def transform_poses_pca(poses):
105
+ """Transforms poses so principal components lie on XYZ axes.
106
+
107
+ Args:
108
+ poses: a (N, 3, 4) array containing the cameras' camera to world transforms.
109
+
110
+ Returns:
111
+ A tuple (poses, transform, scale_factor): the transformed poses, the applied
112
+ camera_to_world transform, and the normalization scale factor.
113
+ """
114
+ t = poses[:, :3, 3]
115
+ t_mean = t.mean(axis=0)
116
+ t = t - t_mean
117
+
118
+ eigval, eigvec = np.linalg.eig(t.T @ t)
119
+ # Sort eigenvectors in order of largest to smallest eigenvalue.
120
+ inds = np.argsort(eigval)[::-1]
121
+ eigvec = eigvec[:, inds]
122
+ rot = eigvec.T
123
+ if np.linalg.det(rot) < 0:
124
+ rot = np.diag(np.array([1, 1, -1])) @ rot
125
+
126
+ transform = np.concatenate([rot, rot @ -t_mean[:, None]], -1)
127
+ poses_recentered = unpad_poses(transform @ pad_poses(poses))
128
+ transform = np.concatenate([transform, np.eye(4)[3:]], axis=0)
129
+
130
+ # Flip coordinate system if z component of y-axis is negative
131
+ if poses_recentered.mean(axis=0)[2, 1] < 0:
132
+ poses_recentered = np.diag(np.array([1, -1, -1])) @ poses_recentered
133
+ transform = np.diag(np.array([1, -1, -1, 1])) @ transform
134
+
135
+ # Just make sure it sits in the [-1, 1]^3 cube
136
+ scale_factor = 1. / np.max(np.abs(poses_recentered[:, :3, 3]))
137
+ poses_recentered[:, :3, 3] *= scale_factor
138
+ # transform = np.diag(np.array([scale_factor] * 3 + [1])) @ transform
139
+
140
+ return poses_recentered, transform, scale_factor
141
+
142
+ def generate_interpolated_path(poses, n_interp, spline_degree=5,
143
+ smoothness=.03, rot_weight=.1):
144
+ """Creates a smooth spline path between input keyframe camera poses.
145
+
146
+ Spline is calculated with poses in format (position, lookat-point, up-point).
147
+
148
+ Args:
149
+ poses: (n, 3, 4) array of input pose keyframes.
150
+ n_interp: returned path will have n_interp * (n - 1) total poses.
151
+ spline_degree: polynomial degree of B-spline.
152
+ smoothness: parameter for spline smoothing, 0 forces exact interpolation.
153
+ rot_weight: relative weighting of rotation/translation in spline solve.
154
+
155
+ Returns:
156
+ Array of new camera poses with shape (n_interp * (n - 1), 3, 4).
157
+ """
158
+
159
+ def poses_to_points(poses, dist):
160
+ """Converts from pose matrices to (position, lookat, up) format."""
161
+ pos = poses[:, :3, -1]
162
+ lookat = poses[:, :3, -1] - dist * poses[:, :3, 2]
163
+ up = poses[:, :3, -1] + dist * poses[:, :3, 1]
164
+ return np.stack([pos, lookat, up], 1)
165
+
166
+ def points_to_poses(points):
167
+ """Converts from (position, lookat, up) format to pose matrices."""
168
+ return np.array([viewmatrix(p - l, u - p, p) for p, l, u in points])
169
+
170
+ def interp(points, n, k, s):
171
+ """Runs multidimensional B-spline interpolation on the input points."""
172
+ sh = points.shape
173
+ pts = np.reshape(points, (sh[0], -1))
174
+ k = min(k, sh[0] - 1)
175
+ tck, _ = scipy.interpolate.splprep(pts.T, k=k, s=s)
176
+ u = np.linspace(0, 1, n, endpoint=False)
177
+ new_points = np.array(scipy.interpolate.splev(u, tck))
178
+ new_points = np.reshape(new_points.T, (n, sh[1], sh[2]))
179
+ return new_points
180
+
181
+ ### Additional operation
182
+ # inter_poses = []
183
+ # for pose in poses:
184
+ # tmp_pose = np.eye(4)
185
+ # tmp_pose[:3] = np.concatenate([pose.R.T, pose.T[:, None]], 1)
186
+ # tmp_pose = np.linalg.inv(tmp_pose)
187
+ # tmp_pose[:, 1:3] *= -1
188
+ # inter_poses.append(tmp_pose)
189
+ # inter_poses = np.stack(inter_poses, 0)
190
+ # poses, transform = transform_poses_pca(inter_poses)
191
+
192
+ points = poses_to_points(poses, dist=rot_weight)
193
+ new_points = interp(points,
194
+ n_interp * (points.shape[0] - 1),
195
+ k=spline_degree,
196
+ s=smoothness)
197
+ return points_to_poses(new_points)
198
+
199
+ def viewmatrix(lookdir, up, position):
200
+ """Construct lookat view matrix."""
201
+ vec2 = normalize(lookdir)
202
+ vec0 = normalize(np.cross(up, vec2))
203
+ vec1 = normalize(np.cross(vec2, vec0))
204
+ m = np.stack([vec0, vec1, vec2, position], axis=1)
205
+ return m
206
+
207
+ def normalize(x):
208
+ """Normalization helper function."""
209
+ return x / np.linalg.norm(x)
210
+
211
+ def pad_poses(p):
212
+ """Pad [..., 3, 4] pose matrices with a homogeneous bottom row [0,0,0,1]."""
213
+ bottom = np.broadcast_to([0, 0, 0, 1.], p[..., :1, :4].shape)
214
+ return np.concatenate([p[..., :3, :4], bottom], axis=-2)
215
+
216
+
217
+ def unpad_poses(p):
218
+ """Remove the homogeneous bottom row from [..., 4, 4] pose matrices."""
219
+ return p[..., :3, :4]
220
+
221
+ def invert_transform_poses_pca(poses_recentered, transform, scale_factor):
222
+ poses_recentered[:, :3, 3] /= scale_factor
223
+ transform_inv = np.linalg.inv(transform)
224
+ poses_original = unpad_poses(transform_inv @ pad_poses(poses_recentered))
225
+ return poses_original
226
+
227
+ def visualizer(camera_poses, colors, save_path="/mnt/data/1.png"):
228
+ fig = plt.figure()
229
+ ax = fig.add_subplot(111, projection="3d")
230
+
231
+ for pose, color in zip(camera_poses, colors):
232
+ rotation = pose[:3, :3]
233
+ translation = pose[:3, 3] # Corrected to use 3D translation component
234
+ camera_positions = np.einsum(
235
+ "...ij,...j->...i", np.linalg.inv(rotation), -translation
236
+ )
237
+
238
+ ax.scatter(
239
+ camera_positions[0],
240
+ camera_positions[1],
241
+ camera_positions[2],
242
+ c=color,
243
+ marker="o",
244
+ )
245
+
246
+ ax.set_xlabel("X")
247
+ ax.set_ylabel("Y")
248
+ ax.set_zlabel("Z")
249
+ ax.set_title("Camera Poses")
250
+
251
+ plt.savefig(save_path)
252
+ plt.close()
253
+
254
+ return save_path
255
+
256
+
257
+ def focus_point_fn(poses: np.ndarray) -> np.ndarray:
258
+ """Calculate nearest point to all focal axes in poses."""
259
+ directions, origins = poses[:, :3, 2:3], poses[:, :3, 3:4]
260
+ m = np.eye(3) - directions * np.transpose(directions, [0, 2, 1])
261
+ mt_m = np.transpose(m, [0, 2, 1]) @ m
262
+ focus_pt = np.linalg.inv(mt_m.mean(0)) @ (mt_m @ origins).mean(0)[:, 0]
263
+ return focus_pt
264
+
265
+ def interp(x, xp, fp):
266
+ # Flatten the input arrays
267
+ x_flat = x.reshape(-1, x.shape[-1])
268
+ xp_flat = xp.reshape(-1, xp.shape[-1])
269
+ fp_flat = fp.reshape(-1, fp.shape[-1])
270
+
271
+ # Perform interpolation for each set of flattened arrays
272
+ ret_flat = np.array([np.interp(xf, xpf, fpf) for xf, xpf, fpf in zip(x_flat, xp_flat, fp_flat)])
273
+
274
+ # Reshape the result to match the input shape
275
+ ret = ret_flat.reshape(x.shape)
276
+ return ret
277
+
278
+ def sorted_interp(x, xp, fp):
279
+ # Identify the location in `xp` that corresponds to each `x`.
280
+ # The final `True` index in `mask` is the start of the matching interval.
281
+ mask = x[..., None, :] >= xp[..., :, None]
282
+
283
+ def find_interval(x):
284
+ # Grab the value where `mask` switches from True to False, and vice versa.
285
+ # This approach takes advantage of the fact that `x` is sorted.
286
+ x0 = np.max(np.where(mask, x[..., None], x[..., :1, None]), -2)
287
+ x1 = np.min(np.where(~mask, x[..., None], x[..., -1:, None]), -2)
288
+ return x0, x1
289
+
290
+ fp0, fp1 = find_interval(fp)
291
+ xp0, xp1 = find_interval(xp)
292
+ with np.errstate(divide='ignore', invalid='ignore'):
293
+ offset = np.clip(np.nan_to_num((x - xp0) / (xp1 - xp0), nan=0.0), 0, 1)
294
+ ret = fp0 + offset * (fp1 - fp0)
295
+ return ret
296
+
297
+ def integrate_weights(w):
298
+ """Compute the cumulative sum of w, assuming all weight vectors sum to 1.
299
+
300
+ The output's size on the last dimension is one greater than that of the input,
301
+ because we're computing the integral corresponding to the endpoints of a step
302
+ function, not the integral of the interior/bin values.
303
+
304
+ Args:
305
+ w: Tensor, which will be integrated along the last axis. This is assumed to
306
+ sum to 1 along the last axis, and this function will (silently) break if
307
+ that is not the case.
308
+
309
+ Returns:
310
+ cw0: Tensor, the integral of w, where cw0[..., 0] = 0 and cw0[..., -1] = 1
311
+ """
312
+ cw = np.minimum(1, np.cumsum(w[..., :-1], axis=-1))
313
+ shape = cw.shape[:-1] + (1,)
314
+ # Ensure that the CDF starts with exactly 0 and ends with exactly 1.
315
+ cw0 = np.concatenate([np.zeros(shape), cw, np.ones(shape)], axis=-1)
316
+ return cw0
317
+
318
+ def invert_cdf(u, t, w_logits, use_gpu_resampling=False):
319
+ """Invert the CDF defined by (t, w) at the points specified by u in [0, 1)."""
320
+ # Compute the PDF and CDF for each weight vector.
321
+ w = softmax(w_logits, axis=-1)
322
+ cw = integrate_weights(w)
323
+
324
+ # Interpolate into the inverse CDF.
325
+ interp_fn = interp if use_gpu_resampling else sorted_interp # Assuming these are defined using NumPy
326
+ t_new = interp_fn(u, cw, t)
327
+ return t_new
328
+
329
+ def sample(rng,
330
+ t,
331
+ w_logits,
332
+ num_samples,
333
+ single_jitter=False,
334
+ deterministic_center=False,
335
+ use_gpu_resampling=False):
336
+ """Piecewise-Constant PDF sampling from a step function.
337
+
338
+ Args:
339
+ rng: random number generator (or None for `linspace` sampling).
340
+ t: [..., num_bins + 1], bin endpoint coordinates (must be sorted)
341
+ w_logits: [..., num_bins], logits corresponding to bin weights
342
+ num_samples: int, the number of samples.
343
+ single_jitter: bool, if True, jitter every sample along each ray by the same
344
+ amount in the inverse CDF. Otherwise, jitter each sample independently.
345
+ deterministic_center: bool, if False, when `rng` is None return samples that
346
+ linspace the entire PDF. If True, skip the front and back of the linspace
347
+ so that the centers of each PDF interval are returned.
348
+ use_gpu_resampling: bool, If True this resamples the rays based on a
349
+ "gather" instruction, which is fast on GPUs but slow on TPUs. If False,
350
+ this resamples the rays based on brute-force searches, which is fast on
351
+ TPUs, but slow on GPUs.
352
+
353
+ Returns:
354
+ t_samples: np.ndarray(float32), [batch_size, num_samples].
355
+ """
356
+ eps = np.finfo(np.float32).eps
357
+
358
+ # Draw uniform samples.
359
+ if rng is None:
360
+ # Match the behavior of jax.random.uniform() by spanning [0, 1-eps].
361
+ if deterministic_center:
362
+ pad = 1 / (2 * num_samples)
363
+ u = np.linspace(pad, 1. - pad - eps, num_samples)
364
+ else:
365
+ u = np.linspace(0, 1. - eps, num_samples)
366
+ u = np.broadcast_to(u, t.shape[:-1] + (num_samples,))
367
+ else:
368
+ # `u` is in [0, 1) --- it can be zero, but it can never be 1.
369
+ u_max = eps + (1 - eps) / num_samples
370
+ max_jitter = (1 - u_max) / (num_samples - 1) - eps
371
+ d = 1 if single_jitter else num_samples
372
+ u = (
373
+ np.linspace(0, 1 - u_max, num_samples) +
374
+ rng.uniform(size=t.shape[:-1] + (d,), high=max_jitter))
375
+
376
+ return invert_cdf(u, t, w_logits, use_gpu_resampling=use_gpu_resampling)
377
+
378
+
379
+ def generate_ellipse_path_from_poses(poses: np.ndarray,
380
+ n_frames: int = 120,
381
+ const_speed: bool = True,
382
+ z_variation: float = 0.,
383
+ z_phase: float = 0.) -> np.ndarray:
384
+ """Generate an elliptical render path based on the given poses."""
385
+ # Calculate the focal point for the path (cameras point toward this).
386
+ center = focus_point_fn(poses)
387
+ # Path height sits at z=0 (in middle of zero-mean capture pattern).
388
+ offset = np.array([center[0], center[1], 0])
389
+
390
+ # Calculate scaling for ellipse axes based on input camera positions.
391
+ sc = np.percentile(np.abs(poses[:, :3, 3] - offset), 100, axis=0)
392
+ # Use ellipse that is symmetric about the focal point in xy.
393
+ low = -sc + offset
394
+ high = sc + offset
395
+ # Optional height variation need not be symmetric
396
+ z_low = np.percentile((poses[:, :3, 3]), 0, axis=0)
397
+ z_high = np.percentile((poses[:, :3, 3]), 100, axis=0)
398
+
399
+ def get_positions(theta):
400
+ # Interpolate between bounds with trig functions to get ellipse in x-y.
401
+ # Optionally also interpolate in z to change camera height along path.
402
+ return np.stack([
403
+ low[0] + (high - low)[0] * (np.cos(theta) * .5 + .5),
404
+ low[1] + (high - low)[1] * (np.sin(theta) * .5 + .5),
405
+ z_variation * (z_low[2] + (z_high - z_low)[2] *
406
+ (np.cos(theta + 2 * np.pi * z_phase) * .5 + .5)),
407
+ ], -1)
408
+
409
+ theta = np.linspace(0, 2. * np.pi, n_frames + 1, endpoint=True)
410
+ positions = get_positions(theta)
411
+
412
+ if const_speed:
413
+ # Resample theta angles so that the velocity is closer to constant.
414
+ lengths = np.linalg.norm(positions[1:] - positions[:-1], axis=-1)
415
+ theta = sample(None, theta, np.log(lengths), n_frames + 1)
416
+ positions = get_positions(theta)
417
+
418
+ # Throw away duplicated last position.
419
+ positions = positions[:-1]
420
+
421
+ # Set path's up vector to axis closest to average of input pose up vectors.
422
+ avg_up = poses[:, :3, 1].mean(0)
423
+ avg_up = avg_up / np.linalg.norm(avg_up)
424
+ ind_up = np.argmax(np.abs(avg_up))
425
+ up = np.eye(3)[ind_up] * np.sign(avg_up[ind_up])
426
+
427
+ return np.stack([viewmatrix(p - center, up, p) for p in positions])
428
+
429
+ def generate_ellipse_path_from_camera_infos(
430
+ cam_infos,
431
+ n_frames,
432
+ const_speed=False,
433
+ z_variation=0.,
434
+ z_phase=0.
435
+ ):
436
+ print(f'Generating ellipse path from {len(cam_infos)} camera infos ...')
437
+ poses = np.array([np.linalg.inv(getWorld2View2(cam_info.R, cam_info.T))[:3, :4] for cam_info in cam_infos])
438
+ poses[:, :, 1:3] *= -1
439
+ poses, transform, scale_factor = transform_poses_pca(poses)
440
+ render_poses = generate_ellipse_path_from_poses(poses, n_frames, const_speed, z_variation, z_phase)
441
+ render_poses = invert_transform_poses_pca(render_poses, transform, scale_factor)
442
+ render_poses[:, :, 1:3] *= -1
443
+ ret_cam_infos = []
444
+ for uid, pose in enumerate(render_poses):
445
+ R = pose[:3, :3]
446
+ c2w = np.eye(4)
447
+ c2w[:3, :4] = pose
448
+ T = np.linalg.inv(c2w)[:3, 3]
449
+ cam_info = CameraInfo(
450
+ uid = uid,
451
+ R = R,
452
+ T = T,
453
+ FovY = cam_infos[0].FovY,
454
+ FovX = cam_infos[0].FovX,
455
+ # image = np.zeros_like(cam_infos[0].image),
456
+ image = cam_infos[0].image,
457
+ image_path = '',
458
+ image_name = f'{uid:05d}.png',
459
+ width = cam_infos[0].width,
460
+ height = cam_infos[0].height
461
+ )
462
+ ret_cam_infos.append(cam_info)
463
+ return ret_cam_infos
464
+
465
+ def generate_ellipse_path(
466
+ org_pose,
467
+ n_interp,
468
+ const_speed=False,
469
+ z_variation=0.,
470
+ z_phase=0.
471
+ ):
472
+ print(f'Generating ellipse path from {len(org_pose)} camera infos ...')
473
+ poses = np.array([np.linalg.inv(p)[:3, :4] for p in org_pose]) # w2c >>> c2w
474
+ poses[:, :, 1:3] *= -1
475
+ poses, transform, scale_factor = transform_poses_pca(poses)
476
+ render_poses = generate_ellipse_path_from_poses(poses, n_interp, const_speed, z_variation, z_phase)
477
+ render_poses = invert_transform_poses_pca(render_poses, transform, scale_factor)
478
+ render_poses[:, :, 1:3] *= -1 # c2w
479
+ return render_poses
480
+
481
+
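As its docstring above states, `generate_interpolated_path` fits a B-spline through keyframe camera-to-world poses and returns `n_interp * (n - 1)` interpolated poses of shape (3, 4). A minimal sketch with synthetic keyframes (the `look_at` helper and the ring of cameras are illustrative assumptions, not part of this module):

```python
import numpy as np
from utils.camera_utils import generate_interpolated_path

def look_at(position, target=np.zeros(3), world_up=np.array([0., 0., 1.])):
    # Illustrative camera-to-world builder: columns are the x, y, z axes and the position.
    z = target - position; z = z / np.linalg.norm(z)
    x = np.cross(world_up, z); x = x / np.linalg.norm(x)
    y = np.cross(z, x)
    return np.concatenate([np.stack([x, y, z], axis=1), position[:, None]], axis=1)

angles = np.linspace(0, np.pi, 5)
keyframes = np.stack([look_at(np.array([3 * np.cos(a), 3 * np.sin(a), 1.0])) for a in angles])
path = generate_interpolated_path(keyframes, n_interp=20, spline_degree=3)
print(path.shape)  # (20 * (len(keyframes) - 1), 3, 4)
```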
utils/dust3r_utils.py ADDED
@@ -0,0 +1,432 @@
1
+ import os
2
+ import torch
3
+ import cv2
4
+ import numpy as np
5
+ import PIL.Image
6
+ from PIL.ImageOps import exif_transpose
7
+ from plyfile import PlyData, PlyElement
8
+ import torchvision.transforms as tvf
9
+ import roma
10
+ import dust3r.cloud_opt.init_im_poses as init_fun
11
+ from dust3r.cloud_opt.base_opt import global_alignment_loop
12
+ from dust3r.utils.geometry import geotrf, inv, depthmap_to_absolute_camera_coordinates
13
+ from dust3r.cloud_opt.commons import edge_str
14
+ from dust3r.utils.image import _resize_pil_image, imread_cv2
15
+ import dust3r.datasets.utils.cropping as cropping
16
+ import torch.nn.functional as F
17
+
18
+ def get_known_poses(scene):
19
+ if scene.has_im_poses:
20
+ known_poses_msk = torch.tensor([not (p.requires_grad) for p in scene.im_poses])
21
+ known_poses = scene.get_im_poses()
22
+ return known_poses_msk.sum(), known_poses_msk, known_poses
23
+ else:
24
+ return 0, None, None
25
+
26
+ def init_from_pts3d(scene, pts3d, im_focals, im_poses):
27
+ # init poses
28
+ nkp, known_poses_msk, known_poses = get_known_poses(scene)
29
+ if nkp == 1:
30
+ raise NotImplementedError("Would be simpler to just align everything afterwards on the single known pose")
31
+ elif nkp > 1:
32
+ # global rigid SE3 alignment
33
+ s, R, T = init_fun.align_multiple_poses(im_poses[known_poses_msk], known_poses[known_poses_msk])
34
+ trf = init_fun.sRT_to_4x4(s, R, T, device=known_poses.device)
35
+
36
+ # rotate everything
37
+ im_poses = trf @ im_poses
38
+ im_poses[:, :3, :3] /= s # undo scaling on the rotation part
39
+ for img_pts3d in pts3d:
40
+ img_pts3d[:] = geotrf(trf, img_pts3d)
41
+
42
+ # set all pairwise poses
43
+ for e, (i, j) in enumerate(scene.edges):
44
+ i_j = edge_str(i, j)
45
+ # compute transform that goes from cam to world
46
+ s, R, T = init_fun.rigid_points_registration(scene.pred_i[i_j], pts3d[i], conf=scene.conf_i[i_j])
47
+ scene._set_pose(scene.pw_poses, e, R, T, scale=s)
48
+
49
+ # take into account the scale normalization
50
+ s_factor = scene.get_pw_norm_scale_factor()
51
+ im_poses[:, :3, 3] *= s_factor # apply downscaling factor
52
+ for img_pts3d in pts3d:
53
+ img_pts3d *= s_factor
54
+
55
+ # init all image poses
56
+ if scene.has_im_poses:
57
+ for i in range(scene.n_imgs):
58
+ cam2world = im_poses[i]
59
+ depth = geotrf(inv(cam2world), pts3d[i])[..., 2]
60
+ scene._set_depthmap(i, depth)
61
+ scene._set_pose(scene.im_poses, i, cam2world)
62
+ if im_focals[i] is not None:
63
+ scene._set_focal(i, im_focals[i])
64
+
65
+ if scene.verbose:
66
+ print(' init loss =', float(scene()))
67
+
68
+ @torch.no_grad()
69
+ def init_minimum_spanning_tree(scene, focal_avg=False, known_focal=None, **kw):
70
+ """ Init all camera poses (image-wise and pairwise poses) given
71
+ an initial set of pairwise estimations.
72
+ """
73
+ device = scene.device
74
+ pts3d, _, im_focals, im_poses = init_fun.minimum_spanning_tree(scene.imshapes, scene.edges,
75
+ scene.pred_i, scene.pred_j, scene.conf_i, scene.conf_j, scene.im_conf, scene.min_conf_thr,
76
+ device, has_im_poses=scene.has_im_poses, verbose=scene.verbose,
77
+ **kw)
78
+
79
+ if known_focal is not None:
80
+ repeat_focal = np.repeat(known_focal, len(im_focals))
81
+ for i in range(len(im_focals)):
82
+ im_focals[i] = known_focal
83
+ scene.preset_focal(known_focals=repeat_focal)
84
+ elif focal_avg:
85
+ im_focals_avg = np.array(im_focals).mean()
86
+ for i in range(len(im_focals)):
87
+ im_focals[i] = im_focals_avg
88
+ repeat_focal = np.array(im_focals)#.cpu().numpy()
89
+ scene.preset_focal(known_focals=repeat_focal)
90
+
91
+ return init_from_pts3d(scene, pts3d, im_focals, im_poses)
92
+
93
+ @torch.cuda.amp.autocast(enabled=False)
94
+ def compute_global_alignment(scene, init=None, niter_PnP=10, focal_avg=False, known_focal=None, **kw):
95
+ if init is None:
96
+ pass
97
+ elif init == 'msp' or init == 'mst':
98
+ init_minimum_spanning_tree(scene, niter_PnP=niter_PnP, focal_avg=focal_avg, known_focal=known_focal)
99
+ elif init == 'known_poses':
100
+ init_fun.init_from_known_poses(scene, min_conf_thr=scene.min_conf_thr,
101
+ niter_PnP=niter_PnP)
102
+ else:
103
+ raise ValueError(f'bad value for {init=}')
104
+
105
+ return global_alignment_loop(scene, **kw)
106
+
107
+
108
+
109
+ def load_images(folder_or_list, size, square_ok=False):
110
+ """ open and convert all images in a list or folder to proper input format for DUSt3R
111
+ """
112
+ if isinstance(folder_or_list, str):
113
+ print(f'>> Loading images from {folder_or_list}')
114
+ root, folder_content = folder_or_list, sorted(os.listdir(folder_or_list))
115
+
116
+ elif isinstance(folder_or_list, list):
117
+ print(f'>> Loading a list of {len(folder_or_list)} images')
118
+ root, folder_content = '', folder_or_list
119
+
120
+ else:
121
+ raise ValueError(f'bad {folder_or_list=} ({type(folder_or_list)})')
122
+
123
+ imgs = []
124
+ for path in folder_content:
125
+ if not path.endswith(('.jpg', '.jpeg', '.png', '.JPG')):
126
+ continue
127
+ img = exif_transpose(PIL.Image.open(os.path.join(root, path))).convert('RGB')
128
+ W1, H1 = img.size
129
+ if size == 224:
130
+ # resize short side to 224 (then crop)
131
+ img = _resize_pil_image(img, round(size * max(W1/H1, H1/W1)))
132
+ else:
133
+ # resize long side to 512
134
+ img = _resize_pil_image(img, size)
135
+ W, H = img.size
136
+ W2 = W//16*16
137
+ H2 = H//16*16
138
+ img = np.array(img)
139
+ img = cv2.resize(img, (W2,H2), interpolation=cv2.INTER_LINEAR)
140
+ img = PIL.Image.fromarray(img)
141
+
142
+ print(f' - adding {path} with resolution {W1}x{H1} --> {W2}x{H2}')
143
+ ImgNorm = tvf.Compose([tvf.ToTensor(), tvf.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
144
+ imgs.append(dict(img=ImgNorm(img)[None], true_shape=np.int32(
145
+ [img.size[::-1]]), idx=len(imgs), instance=str(len(imgs))))
146
+
147
+ assert imgs, 'no images found at '+root
148
+ print(f' (Found {len(imgs)} images)')
149
+ return imgs, (W1,H1)
150
+
151
+
152
+ def load_cam_mvsnet(file, interval_scale=1):
153
+ """ read camera txt file """
154
+ cam = np.zeros((2, 4, 4))
155
+ words = file.read().split()
156
+ # read extrinsic
157
+ for i in range(0, 4):
158
+ for j in range(0, 4):
159
+ extrinsic_index = 4 * i + j + 1
160
+ cam[0][i][j] = words[extrinsic_index]
161
+
162
+ # read intrinsic
163
+ for i in range(0, 3):
164
+ for j in range(0, 3):
165
+ intrinsic_index = 3 * i + j + 18
166
+ cam[1][i][j] = words[intrinsic_index]
167
+
168
+ if len(words) == 29:
169
+ cam[1][3][0] = words[27]
170
+ cam[1][3][1] = float(words[28]) * interval_scale
171
+ cam[1][3][2] = 192
172
+ cam[1][3][3] = cam[1][3][0] + cam[1][3][1] * cam[1][3][2]
173
+ elif len(words) == 30:
174
+ cam[1][3][0] = words[27]
175
+ cam[1][3][1] = float(words[28]) * interval_scale
176
+ cam[1][3][2] = words[29]
177
+ cam[1][3][3] = cam[1][3][0] + cam[1][3][1] * cam[1][3][2]
178
+ elif len(words) == 31:
179
+ cam[1][3][0] = words[27]
180
+ cam[1][3][1] = float(words[28]) * interval_scale
181
+ cam[1][3][2] = words[29]
182
+ cam[1][3][3] = words[30]
183
+ else:
184
+ cam[1][3][0] = 0
185
+ cam[1][3][1] = 0
186
+ cam[1][3][2] = 0
187
+ cam[1][3][3] = 0
188
+
189
+
190
+ extrinsic = cam[0].astype(np.float32)
191
+ intrinsic = cam[1].astype(np.float32)
192
+
193
+ return intrinsic, extrinsic
194
+
195
+
196
+ def _crop_resize_if_necessary(image, depthmap, intrinsics, resolution, rng=None, info=None):
197
+ """ This function:
198
+ - first downsizes the image with LANCZOS interpolation,
199
+ which is better than bilinear interpolation in terms of aliasing
200
+ """
201
+ if not isinstance(image, PIL.Image.Image):
202
+ image = PIL.Image.fromarray(image)
203
+
204
+ # downscale with lanczos interpolation so that image.size == resolution
205
+ # cropping centered on the principal point
206
+ W, H = image.size
207
+ cx, cy = intrinsics[:2, 2].round().astype(int)
208
+
209
+ # calculate min distance to margin
210
+ min_margin_x = min(cx, W-cx)
211
+ min_margin_y = min(cy, H-cy)
212
+ assert min_margin_x > W/5, f'Bad principal point in view={info}'
213
+ assert min_margin_y > H/5, f'Bad principal point in view={info}'
214
+
215
+ ## Center crop
216
+ # Crop on the principal point, make it always centered
217
+ # the new window will be a rectangle of size (2*min_margin_x, 2*min_margin_y) centered on (cx,cy)
218
+ l, t = cx - min_margin_x, cy - min_margin_y
219
+ r, b = cx + min_margin_x, cy + min_margin_y
220
+ crop_bbox = (l, t, r, b)
221
+ image, depthmap, intrinsics = cropping.crop_image_depthmap(image, depthmap, intrinsics, crop_bbox)
222
+
223
+ # transpose the resolution if necessary
224
+ W, H = image.size # new size
225
+ assert resolution[0] >= resolution[1]
226
+ if H > 1.1*W:
227
+ # image is portrait mode
228
+ resolution = resolution[::-1]
229
+ elif 0.9 < H/W < 1.1 and resolution[0] != resolution[1]:
230
+ # image is square, so we choose (portrait, landscape) randomly
231
+ if rng.integers(2):
232
+ resolution = resolution[::-1]
233
+
234
+ # high-quality Lanczos down-scaling
235
+ target_resolution = np.array(resolution)
236
+
237
+ ## Rescale with max factor, so one of width or height might be larger than target_resolution
238
+ image, depthmap, intrinsics = cropping.rescale_image_depthmap(image, depthmap, intrinsics, target_resolution)
239
+
240
+ # actual cropping (if necessary) with bilinear interpolation
241
+ intrinsics2 = cropping.camera_matrix_of_crop(intrinsics, image.size, resolution, offset_factor=0.5)
242
+ crop_bbox = cropping.bbox_from_intrinsics_in_out(intrinsics, intrinsics2, resolution)
243
+ image, depthmap, intrinsics2 = cropping.crop_image_depthmap(image, depthmap, intrinsics, crop_bbox)
244
+
245
+ return image, depthmap, intrinsics2
246
+
247
+
248
+ def load_images_dtu(folder_or_list, size, scene_folder):
249
+ """
250
+ Preprocessing DTU requires depth maps, camera parameters and masks.
251
+ We follow Splatt3R to compute valid_mask.
252
+ """
253
+ if isinstance(folder_or_list, str):
254
+ print(f'>> Loading images from {folder_or_list}')
255
+ root, folder_content = folder_or_list, sorted(os.listdir(folder_or_list))
256
+
257
+ elif isinstance(folder_or_list, list):
258
+ print(f'>> Loading a list of {len(folder_or_list)} images')
259
+ root = os.path.dirname(folder_or_list[0]) if folder_or_list else ''
260
+ folder_content = [os.path.basename(p) for p in folder_or_list]
261
+
262
+ else:
263
+ raise ValueError(f'bad {folder_or_list=} ({type(folder_or_list)})')
264
+
265
+ depth_root = os.path.join(scene_folder, 'depths')
266
+ mask_root = os.path.join(scene_folder, 'binary_masks')
267
+ cam_root = os.path.join(scene_folder, 'cams')
268
+
269
+ imgs = []
270
+ for path in folder_content:
271
+ if not path.endswith(('.jpg', '.jpeg', '.png', '.JPG')):
272
+ continue
273
+
274
+ impath = os.path.join(root, path)
275
+ depthpath = os.path.join(depth_root, path.replace('.jpg', '.npy'))
276
+ campath = os.path.join(cam_root, path.replace('.jpg', '_cam.txt'))
277
+ maskpath = os.path.join(mask_root, path.replace('.jpg', '.png'))
278
+
279
+ rgb_image = imread_cv2(impath)
280
+ H1, W1 = rgb_image.shape[:2]
281
+ depthmap = np.load(depthpath)
282
+ depthmap = np.nan_to_num(depthmap.astype(np.float32), nan=0.0)
283
+
284
+ mask = imread_cv2(maskpath, cv2.IMREAD_UNCHANGED)/255.0
285
+ mask = mask.astype(np.float32)
286
+
287
+ mask[mask>0.5] = 1.0
288
+ mask[mask<0.5] = 0.0
289
+
290
+ mask = cv2.resize(mask, (depthmap.shape[1], depthmap.shape[0]), interpolation=cv2.INTER_NEAREST)
291
+ kernel = np.ones((10, 10), np.uint8) # Define the erosion kernel
292
+ mask = cv2.erode(mask, kernel, iterations=1)
293
+ depthmap = depthmap * mask
294
+
295
+ cur_intrinsics, camera_pose = load_cam_mvsnet(open(campath, 'r'))
296
+ intrinsics = cur_intrinsics[:3, :3]
297
+ camera_pose = np.linalg.inv(camera_pose)
298
+
299
+ new_size = tuple(int(round(x*size/max(W1, H1))) for x in (W1, H1))
300
+ W, H = new_size
301
+ W2 = W//16*16
302
+ H2 = H//16*16
303
+
304
+ rgb_image, depthmap, intrinsics = _crop_resize_if_necessary(
305
+ rgb_image, depthmap, intrinsics, (W2, H2), info=impath)
306
+
307
+ print(f' - adding {path} with resolution {W1}x{H1} --> {W2}x{H2}')
308
+ ImgNorm = tvf.Compose([tvf.ToTensor(), tvf.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
309
+
310
+ img = dict(
311
+ img=ImgNorm(rgb_image)[None],
312
+ true_shape=np.int32([rgb_image.size[::-1]]),
313
+ idx=len(imgs),
314
+ instance=str(len(imgs)),
315
+ depthmap=depthmap,
316
+ camera_pose=camera_pose,
317
+ camera_intrinsics=intrinsics
318
+ )
319
+
320
+ pts3d, valid_mask = depthmap_to_absolute_camera_coordinates(**img)
321
+ img['pts3d'] = pts3d
322
+ img['valid_mask'] = valid_mask & np.isfinite(pts3d).all(axis=-1)
323
+
324
+ imgs.append(img)
325
+
326
+
327
+ assert imgs, 'no images found at ' + root
328
+ print(f' (Found {len(imgs)} images)')
329
+ return imgs, (W1,H1)
330
+
331
+
332
+ def storePly(path, xyz, rgb, feat=None):
333
+ # Define the dtype for the structured array
334
+ dtype = [('x', 'f4'), ('y', 'f4'), ('z', 'f4'),
335
+ ('nx', 'f4'), ('ny', 'f4'), ('nz', 'f4'),
336
+ ('red', 'u1'), ('green', 'u1'), ('blue', 'u1')]
337
+
338
+ if feat is not None:
339
+ for i in range(feat.shape[1]):
340
+ dtype.append((f'feat_{i}', 'f4'))
341
+
342
+ normals = np.zeros_like(xyz)
343
+
344
+ elements = np.empty(xyz.shape[0], dtype=dtype)
345
+ attributes = np.concatenate((xyz, normals, rgb), axis=1)
346
+
347
+ if feat is not None:
348
+ attributes = np.concatenate((attributes, feat), axis=1)
349
+
350
+ elements[:] = list(map(tuple, attributes))
351
+
352
+ # Create the PlyData object and write to file
353
+ vertex_element = PlyElement.describe(elements, 'vertex')
354
+ ply_data = PlyData([vertex_element])
355
+ ply_data.write(path)
356
+
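A minimal, illustrative sketch of calling storePly on synthetic data (assumes the plyfile package; the file name and feature dimension are placeholders):

```python
# Illustrative usage of storePly with a synthetic point cloud.
import numpy as np
from plyfile import PlyData  # for reading the result back

xyz = np.random.rand(100, 3).astype(np.float32)              # 100 random points
rgb = np.random.randint(0, 256, (100, 3)).astype(np.uint8)   # per-point colors
feat = np.random.rand(100, 16).astype(np.float32)            # optional per-point features

storePly("points3d.ply", xyz, rgb, feat)

# Confirm the vertex layout: x/y/z, normals, colors, feat_0 ... feat_15.
ply = PlyData.read("points3d.ply")
print(ply["vertex"].properties)
```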
357
+ def R_to_quaternion(R):
358
+ """
359
+ Convert a rotation matrix to a quaternion.
360
+
361
+ Parameters:
362
+ - R: A 3x3 numpy array representing a rotation matrix.
363
+
364
+ Returns:
365
+ - A numpy array representing the quaternion [w, x, y, z].
366
+ """
367
+ m00, m01, m02 = R[0, 0], R[0, 1], R[0, 2]
368
+ m10, m11, m12 = R[1, 0], R[1, 1], R[1, 2]
369
+ m20, m21, m22 = R[2, 0], R[2, 1], R[2, 2]
370
+ trace = m00 + m11 + m22
371
+
372
+ if trace > 0:
373
+ s = 0.5 / np.sqrt(trace + 1.0)
374
+ w = 0.25 / s
375
+ x = (m21 - m12) * s
376
+ y = (m02 - m20) * s
377
+ z = (m10 - m01) * s
378
+ elif (m00 > m11) and (m00 > m22):
379
+ s = np.sqrt(1.0 + m00 - m11 - m22) * 2
380
+ w = (m21 - m12) / s
381
+ x = 0.25 * s
382
+ y = (m01 + m10) / s
383
+ z = (m02 + m20) / s
384
+ elif m11 > m22:
385
+ s = np.sqrt(1.0 + m11 - m00 - m22) * 2
386
+ w = (m02 - m20) / s
387
+ x = (m01 + m10) / s
388
+ y = 0.25 * s
389
+ z = (m12 + m21) / s
390
+ else:
391
+ s = np.sqrt(1.0 + m22 - m00 - m11) * 2
392
+ w = (m10 - m01) / s
393
+ x = (m02 + m20) / s
394
+ y = (m12 + m21) / s
395
+ z = 0.25 * s
396
+
397
+ return np.array([w, x, y, z])
398
+
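A quick sanity check of the conversion (illustrative only): a 90-degree rotation about the z-axis should map to roughly [0.707, 0, 0, 0.707] in [w, x, y, z] order.

```python
# Sanity check for R_to_quaternion: 90-degree rotation about the z-axis.
import numpy as np

R_z90 = np.array([[0.0, -1.0, 0.0],
                  [1.0,  0.0, 0.0],
                  [0.0,  0.0, 1.0]])
q = R_to_quaternion(R_z90)
print(q)  # ~[0.7071, 0, 0, 0.7071] -> cos(45 deg), sin(45 deg) about z
```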
399
+ def save_colmap_cameras(ori_size, intrinsics, camera_file):
400
+ with open(camera_file, 'w') as f:
401
+ for i, K in enumerate(intrinsics, 1): # Starting index at 1
402
+ width, height = ori_size
403
+ scale_factor_x = width/2 / K[0, 2]
404
+ scale_factor_y = height/2 / K[1, 2]
405
+ # assert scale_factor_x==scale_factor_y, "scale factor is not same for x and y"
406
+ # print(f'scale factor is not same for x {scale_factor_x} and y {scale_factor_y}')
407
+ f.write(f"{i} PINHOLE {width} {height} {K[0, 0]*scale_factor_x} {K[1, 1]*scale_factor_x} {width/2} {height/2}\n") # scale focal
408
+ # f.write(f"{i} PINHOLE {width} {height} {K[0, 0]} {K[1, 1]} {K[0, 2]} {K[1, 2]}\n")
409
+
410
+ def save_colmap_images(poses, images_file, train_img_list):
411
+ with open(images_file, 'w') as f:
412
+ for i, pose in enumerate(poses, 1): # Starting index at 1
413
+ # breakpoint()
414
+ pose = np.linalg.inv(pose)
415
+ R = pose[:3, :3]
416
+ t = pose[:3, 3]
417
+ q = R_to_quaternion(R) # Convert rotation matrix to quaternion
418
+ f.write(f"{i} {q[0]} {q[1]} {q[2]} {q[3]} {t[0]} {t[1]} {t[2]} {i} {os.path.basename(train_img_list[i-1])}\n")
419
+ f.write(f"\n")
420
+
421
+
422
+ def round_python3(number):
423
+ rounded = round(number)
424
+ if abs(number - rounded) == 0.5:
425
+ return 2.0 * round(number / 2.0)
426
+ return rounded
427
+
428
+
429
+ def rigid_points_registration(pts1, pts2, conf=None):
430
+ R, T, s = roma.rigid_points_registration(
431
+ pts1.reshape(-1, 3), pts2.reshape(-1, 3), weights=conf, compute_scaling=True)
432
+ return s, R, T # return the scale together with the un-scaled (R, T)
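A small illustrative check of this wrapper (assuming the roma package, which the code above already depends on): build a point set by applying a known similarity transform and recover it.

```python
# Illustrative check for rigid_points_registration.
import torch
import roma

pts1 = torch.randn(1000, 3)
R_gt = roma.random_rotmat()               # random ground-truth rotation
t_gt = torch.tensor([0.5, -1.0, 2.0])
s_gt = 1.7
pts2 = s_gt * pts1 @ R_gt.T + t_gt        # apply scale, rotation, translation

s, R, T = rigid_points_registration(pts1, pts2)
print(float(s), s_gt)                     # recovered scale vs. ground truth (~1.7)
print(torch.dist(R, R_gt))                # should be close to 0
```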
utils/feat_utils.py ADDED
@@ -0,0 +1,827 @@
1
+ import os
2
+ import torch
3
+ import torchvision.transforms as tvf
4
+ import torch.nn.functional as F
5
+ import numpy as np
6
+
7
+ from dust3r.utils.device import to_numpy
8
+
9
+ from dust3r.inference import inference
10
+ from dust3r.model import AsymmetricCroCo3DStereo
11
+ from dust3r.cloud_opt import global_aligner, GlobalAlignerMode
12
+ from utils.dust3r_utils import compute_global_alignment
13
+
14
+ from mast3r.model import AsymmetricMASt3R
15
+ from mast3r.cloud_opt.sparse_ga import sparse_global_alignment
16
+ from mast3r.cloud_opt.tsdf_optimizer import TSDFPostProcess
17
+
18
+ from hydra.utils import instantiate
19
+ from omegaconf import OmegaConf
20
+
21
+
22
+ class TorchPCA(object):
23
+
24
+ def __init__(self, n_components):
25
+ self.n_components = n_components
26
+
27
+ def fit(self, X):
28
+ self.mean_ = X.mean(dim=0)
29
+ unbiased = X - self.mean_.unsqueeze(0)
30
+ U, S, V = torch.pca_lowrank(unbiased, q=self.n_components, center=False, niter=50)
31
+ self.components_ = V.T
32
+ self.singular_values_ = S
33
+ return self
34
+
35
+ def transform(self, X):
36
+ t0 = X - self.mean_.unsqueeze(0)
37
+ projected = t0 @ self.components_.T
38
+ return projected
39
+
40
+ def pca(stacked_feat, dim):
41
+ flattened_feats = []
42
+ for feat in stacked_feat:
43
+ H, W, C = feat.shape
44
+ feat = feat.reshape(H * W, C).detach()
45
+ flattened_feats.append(feat)
46
+ x = torch.cat(flattened_feats, dim=0)
47
+ fit_pca = TorchPCA(n_components=dim).fit(x)
48
+
49
+ projected_feats = []
50
+ for feat in stacked_feat:
51
+ H, W, C = feat.shape
52
+ feat = feat.reshape(H * W, C).detach()
53
+ x_red = fit_pca.transform(feat)
54
+ projected_feats.append(x_red.reshape(H, W, dim))
55
+ projected_feats = torch.stack(projected_feats)
56
+ return projected_feats
57
+
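A brief sketch of the PCA reduction above on synthetic per-pixel features (shapes are illustrative):

```python
# Illustrative use of pca(): reduce a stack of [H, W, C] feature maps to 9 channels.
import torch

stacked_feat = torch.randn(4, 32, 48, 768)   # e.g. 4 views, 32x48 patches, 768-dim features
reduced = pca(stacked_feat, dim=9)
print(reduced.shape)                          # torch.Size([4, 32, 48, 9])
```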
58
+
59
+ def upsampler(feature, upsampled_height, upsampled_width, max_chunk=None):
60
+ """
61
+ Upsample the feature tensor to the specified height and width.
62
+
63
+ Args:
64
+ - feature (torch.Tensor): The input tensor with size [B, H, W, C].
65
+ - upsampled_height (int): The target height after upsampling.
66
+ - upsampled_width (int): The target width after upsampling.
67
+
68
+ Returns:
69
+ - upsampled_feature (torch.Tensor): The upsampled tensor with size [B, upsampled_height, upsampled_width, C].
70
+ """
71
+ # Permute the tensor to [B, C, H, W] for interpolation
72
+ feature = feature.permute(0, 3, 1, 2)
73
+
74
+ # Perform the upsampling
75
+ if max_chunk:
76
+ upsampled_chunks = []
77
+
78
+ for i in range(0, len(feature), max_chunk):
79
+ chunk = feature[i:i+max_chunk]
80
+
81
+ upsampled_chunk = F.interpolate(chunk, size=(upsampled_height, upsampled_width), mode='bilinear', align_corners=False)
82
+ upsampled_chunks.append(upsampled_chunk)
83
+
84
+ upsampled_feature = torch.cat(upsampled_chunks, dim=0)
85
+ else:
86
+ upsampled_feature = F.interpolate(feature, size=(upsampled_height, upsampled_width), mode='bilinear', align_corners=False)
87
+
88
+ # Permute back to [B, H, W, C]
89
+ upsampled_feature = upsampled_feature.permute(0, 2, 3, 1)
90
+
91
+ return upsampled_feature
92
+
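And a matching sketch for upsampler(), which brings patch-level features back to image resolution (target size is illustrative):

```python
# Illustrative use of upsampler(): bilinearly resize [B, H, W, C] features to image size.
import torch

feats = torch.randn(4, 32, 48, 9)        # patch-level features
full = upsampler(feats, 512, 768)        # target image height and width
print(full.shape)                         # torch.Size([4, 512, 768, 9])
```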
93
+ def visualizer(features, images, save_dir, dim=9, feat_type=None, file_name=None):
94
+ """
95
+ Visualize features and corresponding images, and save the result.
96
+
97
+ Args:
98
+ features (torch.Tensor): Feature tensor with shape [B, H, W, C].
99
+ images (list): List of dictionaries containing images with keys 'img'. Each image tensor has shape [1, 3, H, W]
100
+ and values in the range [-1, 1].
101
+ save_dir (str): Directory to save the resulting visualization.
102
+ feat_type (list): List of feature types.
103
+ file_name (str): Name of the file to save.
104
+ """
105
+ import matplotlib
106
+ matplotlib.use('Agg')
107
+ from matplotlib import pyplot as plt
108
+ import torchvision.utils as vutils
109
+
110
+ assert features.dim() == 4, "Input tensor must have 4 dimensions (B, H, W, C)"
111
+
112
+ B, H, W, C = features.size()
113
+
114
+ features = features[..., dim-9:]
115
+ # Normalize the 3-dimensional feature to range [0, 1]
116
+ features_min = features.min(dim=0, keepdim=True).values.min(dim=1, keepdim=True).values.min(dim=2, keepdim=True).values
117
+ features_max = features.max(dim=0, keepdim=True).values.max(dim=1, keepdim=True).values.max(dim=2, keepdim=True).values
118
+ features = (features - features_min) / (features_max - features_min)
119
+
120
+ ##### Save individual feature maps
121
+ # # Create subdirectory for feature visualizations
122
+ # feat_dir = os.path.join(save_dir, 'feature_maps')
123
+ # if feat_type:
124
+ # feat_dir = os.path.join(feat_dir, '-'.join(feat_type))
125
+ # os.makedirs(feat_dir, exist_ok=True)
126
+
127
+ # for i in range(B):
128
+ # # Extract and save the feature map (channels 3-6)
129
+ # feat_map = features[i, :, :, 3:6].permute(2, 0, 1) # [3, H, W]
130
+ # save_path = os.path.join(feat_dir, f'{i}_feat.png')
131
+ # vutils.save_image(feat_map, save_path, normalize=False)
132
+
133
+ # return feat_dir
134
+
135
+ ##### Save feature maps in a single image
136
+ # Set the size of the plot
137
+ fig, axes = plt.subplots(B, 4, figsize=(W*4*0.01, H*B*0.01))
138
+
139
+ for i in range(B):
140
+ # Get the original image
141
+ image_tensor = images[i]['img']
142
+ assert image_tensor.dim() == 4 and image_tensor.size(0) == 1 and image_tensor.size(1) == 3, "Image tensor must have shape [1, 3, H, W]"
143
+ image = image_tensor.squeeze(0).permute(1, 2, 0).numpy() # Convert to (H, W, 3)
144
+
145
+ # Scale image values from [-1, 1] to [0, 1]
146
+ image = (image + 1) / 2
147
+
148
+ ax = axes[i, 0] if B > 1 else axes[0]
149
+ ax.imshow(image)
150
+ ax.axis('off')
151
+
152
+ # Visualize each 3-dimensional feature
153
+ for j in range(3):
154
+ ax = axes[i, j+1] if B > 1 else axes[j+1]
155
+ if j * 3 < min(C, dim): # Check if the feature channels are available
156
+ feature_to_plot = features[i, :, :, j*3:(j+1)*3].cpu().numpy()
157
+ ax.imshow(feature_to_plot)
158
+ else: # Plot white image if features are not available
159
+ ax.imshow(torch.ones(H, W, 3).numpy())
160
+ ax.axis('off')
161
+
162
+ # Reduce margins and spaces between images
163
+ plt.subplots_adjust(wspace=0.005, hspace=0.005, left=0.01, right=0.99, top=0.99, bottom=0.01)
164
+
165
+ # Save the entire plot
166
+ if file_name is None:
167
+ file_name = f'feat_dim{dim-9}-{dim}'
168
+ if feat_type:
169
+ feat_type_str = '-'.join(feat_type)
170
+ file_name = file_name + f'_{feat_type_str}'
171
+ save_path = os.path.join(save_dir, file_name + '.png')
172
+ plt.savefig(save_path, bbox_inches='tight', pad_inches=0)
173
+ plt.close()
174
+
175
+ return save_path
176
+
177
+
178
+ #### Uncomment this block to visualize feature maps as in Feat2GS's teaser
179
+ # import matplotlib.colors as mcolors
180
+ # from PIL import Image
181
+
182
+ # morandi_colors = [
183
+ # '#8AA2A9', '#C98474', '#F2D0A9', '#8D9F87', '#A7A7A7', '#D98E73', '#B24C33', '#5E7460', '#4A6B8A', '#B2CBC2',
184
+ # '#BBC990', '#6B859E', '#B45342', '#4E0000', '#3D0000', '#2C0000', '#1B0000', '#0A0000', '#DCAC99', '#6F936B',
185
+ # '#EBA062', '#FED273', '#9A8EB4', '#706052', '#E9E5E5', '#C4D8D2', '#F2CBBD', '#F6F9F1', '#C5CABC', '#A3968B',
186
+ # '#5C6974', '#BE7B6E', '#C67752', '#C18830', '#8C956C', '#CAC691', '#819992', '#4D797F', '#95AEB2', '#B6C4CF',
187
+ # '#84291C', '#B9551F', '#A96400', '#374B6C', '#C8B493', '#677D5D', '#9882A2', '#2D5F53', '#D2A0AC', '#658D9A',
188
+ # '#9A7265', '#EFE1D2', '#DDD8D1', '#D2C6BC', '#E3C9BC', '#B8AB9F', '#D8BEA4', '#E0D4C5', '#B8B8B6', '#D0CAC3',
189
+ # '#9AA8B5', '#BBC9B9', '#E3E8D8', '#ADB3A4', '#C5C9BB', '#A3968B', '#C2A995', '#EDE1D1', '#EDE8E1', '#EDEBE1',
190
+ # '#CFCFCC', '#AABAC6', '#DCDEE0', '#EAE5E7', '#B7AB9F', '#F7EFE3', '#DED8CF', '#ABCA99', '#C5CD8F', '#959491',
191
+ # '#FFE481', '#C18E99', '#B07C86', '#9F6A73', '#8E5860', '#DEAD44', '#CD9B31', '#BC891E', '#AB770B', '#9A6500',
192
+ # '#778144', '#666F31', '#555D1E', '#444B0B', '#333900', '#67587B', '#564668', '#684563', '#573350', '#684550',
193
+ # '#57333D', '#46212A', '#350F17', '#240004',
194
+ # ]
195
+
196
+ # def rgb_to_hsv(rgb):
197
+ # rgb = rgb.clamp(0, 1)
198
+
199
+ # cmax, cmax_idx = rgb.max(dim=-1)
200
+ # cmin = rgb.min(dim=-1).values
201
+
202
+ # diff = cmax - cmin
203
+
204
+ # h = torch.zeros_like(cmax)
205
+ # h[cmax_idx == 0] = (((rgb[..., 1] - rgb[..., 2]) / diff) % 6)[cmax_idx == 0]
206
+ # h[cmax_idx == 1] = (((rgb[..., 2] - rgb[..., 0]) / diff) + 2)[cmax_idx == 1]
207
+ # h[cmax_idx == 2] = (((rgb[..., 0] - rgb[..., 1]) / diff) + 4)[cmax_idx == 2]
208
+ # h[diff == 0] = 0 # If cmax == cmin
209
+ # h = h / 6
210
+
211
+ # s = torch.zeros_like(cmax)
212
+ # s[cmax != 0] = (diff / cmax)[cmax != 0]
213
+
214
+ # v = cmax
215
+
216
+ # return torch.stack([h, s, v], dim=-1)
217
+
218
+ # def hsv_to_rgb(hsv):
219
+ # h, s, v = hsv[..., 0], hsv[..., 1], hsv[..., 2]
220
+
221
+ # c = v * s
222
+ # x = c * (1 - torch.abs((h * 6) % 2 - 1))
223
+ # m = v - c
224
+
225
+ # rgb = torch.zeros_like(hsv)
226
+ # mask = (h < 1/6)
227
+ # rgb[mask] = torch.stack([c[mask], x[mask], torch.zeros_like(x[mask])], dim=-1)
228
+ # mask = (1/6 <= h) & (h < 2/6)
229
+ # rgb[mask] = torch.stack([x[mask], c[mask], torch.zeros_like(x[mask])], dim=-1)
230
+ # mask = (2/6 <= h) & (h < 3/6)
231
+ # rgb[mask] = torch.stack([torch.zeros_like(x[mask]), c[mask], x[mask]], dim=-1)
232
+ # mask = (3/6 <= h) & (h < 4/6)
233
+ # rgb[mask] = torch.stack([torch.zeros_like(x[mask]), x[mask], c[mask]], dim=-1)
234
+ # mask = (4/6 <= h) & (h < 5/6)
235
+ # rgb[mask] = torch.stack([x[mask], torch.zeros_like(x[mask]), c[mask]], dim=-1)
236
+ # mask = (5/6 <= h)
237
+ # rgb[mask] = torch.stack([c[mask], torch.zeros_like(x[mask]), x[mask]], dim=-1)
238
+
239
+ # return rgb + m.unsqueeze(-1)
240
+
241
+ # def interpolate_colors(colors, n_colors):
242
+ # # Convert colors to RGB tensor
243
+ # rgb_colors = torch.tensor([mcolors.to_rgb(color) for color in colors])
244
+
245
+ # # Convert RGB to HSV
246
+ # hsv_colors = rgb_to_hsv(rgb_colors)
247
+
248
+ # # Sort by hue
249
+ # sorted_indices = torch.argsort(hsv_colors[:, 0])
250
+ # sorted_hsv_colors = hsv_colors[sorted_indices]
251
+
252
+ # # Create interpolation indices
253
+ # indices = torch.linspace(0, len(sorted_hsv_colors) - 1, n_colors)
254
+
255
+ # # Perform interpolation
256
+ # interpolated_hsv = torch.stack([
257
+ # torch.lerp(sorted_hsv_colors[int(i)],
258
+ # sorted_hsv_colors[min(int(i) + 1, len(sorted_hsv_colors) - 1)],
259
+ # i - int(i))
260
+ # for i in indices
261
+ # ])
262
+
263
+ # # Convert interpolated result back to RGB
264
+ # interpolated_rgb = hsv_to_rgb(interpolated_hsv)
265
+
266
+ # return interpolated_rgb
267
+
268
+
269
+ # def project_to_morandi(features, morandi_colors):
270
+ # features_flat = features.reshape(-1, 3)
271
+ # distances = torch.cdist(features_flat, morandi_colors)
272
+
273
+ # # Get the indices of the closest colors
274
+ # closest_color_indices = torch.argmin(distances, dim=1)
275
+
276
+ # # Use the closest Morandi colors directly
277
+ # features_morandi = morandi_colors[closest_color_indices]
278
+
279
+ # features_morandi = features_morandi.reshape(features.shape)
280
+ # return features_morandi
281
+
282
+
283
+ # def smooth_color_transform(features, morandi_colors, smoothness=0.1):
284
+ # features_flat = features.reshape(-1, 3)
285
+ # distances = torch.cdist(features_flat, morandi_colors)
286
+
287
+ # # Calculate weights
288
+ # weights = torch.exp(-distances / smoothness)
289
+ # weights = weights / weights.sum(dim=1, keepdim=True)
290
+
291
+ # # Weighted average
292
+ # features_morandi = torch.matmul(weights, morandi_colors)
293
+
294
+ # features_morandi = features_morandi.reshape(features.shape)
295
+ # return features_morandi
296
+
297
+ # def histogram_matching(source, template):
298
+ # """
299
+ # Match the histogram of the source tensor to that of the template tensor.
300
+
301
+ # :param source: Source tensor with shape [B, H, W, 3]
302
+ # :param template: Template tensor with shape [N, 3], where N is the number of colors
303
+ # :return: Source tensor after histogram matching
304
+ # """
305
+ # def match_cumulative_cdf(source, template):
306
+ # src_values, src_indices, src_counts = torch.unique(source, return_inverse=True, return_counts=True)
307
+ # tmpl_values, tmpl_counts = torch.unique(template, return_counts=True)
308
+
309
+ # src_quantiles = torch.cumsum(src_counts.float(), 0) / source.numel()
310
+ # tmpl_quantiles = torch.cumsum(tmpl_counts.float(), 0) / template.numel()
311
+
312
+ # idx = torch.searchsorted(tmpl_quantiles, src_quantiles)
313
+ # idx = torch.clamp(idx, 1, len(tmpl_quantiles)-1)
314
+
315
+ # slope = (tmpl_values[idx] - tmpl_values[idx-1]) / (tmpl_quantiles[idx] - tmpl_quantiles[idx-1])
316
+ # interp_a_values = torch.lerp(tmpl_values[idx-1], tmpl_values[idx],
317
+ # (src_quantiles - tmpl_quantiles[idx-1]) * slope)
318
+
319
+ # return interp_a_values[src_indices].reshape(source.shape)
320
+
321
+ # matched = torch.stack([match_cumulative_cdf(source[..., i], template[:, i]) for i in range(3)], dim=-1)
322
+ # return matched
323
+
324
+ # def process_features(features):
325
+ # device = features.device
326
+
327
+ # n_colors = 1024
328
+ # morandi_colors_tensor = interpolate_colors(morandi_colors, n_colors).to(device)
329
+ # # morandi_colors_tensor = torch.tensor([mcolors.to_rgb(color) for color in morandi_colors]).to(device)
330
+
331
+ # # features_morandi = project_to_morandi(features, morandi_colors_tensor)
332
+ # # features_morandi = histogram_matching(features, morandi_colors_tensor)
333
+ # features_morandi = smooth_color_transform(features, morandi_colors_tensor, smoothness=0.05)
334
+
335
+ # return features_morandi.cpu().numpy()
336
+
337
+ # def visualizer(features, images, save_dir, dim=9, feat_type=None, file_name=None):
338
+ # import matplotlib
339
+ # matplotlib.use('Agg')
340
+
341
+ # import matplotlib.pyplot as plt
342
+ # import numpy as np
343
+ # import os
344
+
345
+ # assert features.dim() == 4, "Input tensor must have 4 dimensions (B, H, W, C)"
346
+
347
+ # B, H, W, C = features.size()
348
+
349
+ # # Ensure features have at least 3 channels for RGB visualization
350
+ # assert C >= 3, "Features must have at least 3 channels for RGB visualization"
351
+ # features = features[..., :3]
352
+
353
+ # # Normalize features to [0, 1] range
354
+ # features_min = features.min(dim=0, keepdim=True).values.min(dim=1, keepdim=True).values.min(dim=2, keepdim=True).values
355
+ # features_max = features.max(dim=0, keepdim=True).values.max(dim=1, keepdim=True).values.max(dim=2, keepdim=True).values
356
+ # features = (features - features_min) / (features_max - features_min)
357
+
358
+ # features_processed = process_features(features)
359
+
360
+ # # Create the directory structure
361
+ # vis_dir = os.path.join(save_dir, 'vis')
362
+
363
+ # if feat_type:
364
+ # feat_type_str = '-'.join(feat_type)
365
+ # vis_dir = os.path.join(vis_dir, feat_type_str)
366
+ # os.makedirs(vis_dir, exist_ok=True)
367
+
368
+ # # Save individual images for each feature map
369
+ # for i in range(B):
370
+ # if file_name is None:
371
+ # file_name = 'feat_morandi'
372
+ # save_path = os.path.join(vis_dir, f'{file_name}_{i}.png')
373
+
374
+ # # Convert to uint8 and save directly
375
+ # img = Image.fromarray((features_processed[i] * 255).astype(np.uint8))
376
+ # img.save(save_path)
377
+
378
+ # print(f"Feature maps have been saved in the directory: {vis_dir}")
379
+ # return vis_dir
380
+
381
+ def mv_visualizer(features, images, save_dir, dim=9, feat_type=None, file_name=None):
382
+ """
383
+ Visualize features and corresponding images, and save the result. (For MASt3R decoder or head features)
384
+ """
385
+ import matplotlib
386
+ matplotlib.use('Agg')
387
+ from matplotlib import pyplot as plt
388
+ import os
389
+
390
+ B, H, W, _ = features.size()
391
+ features = features[..., dim-9:]
392
+
393
+ # Normalize the 3-dimensional feature to range [0, 1]
394
+ features_min = features.min(dim=0, keepdim=True).values.min(dim=1, keepdim=True).values.min(dim=2, keepdim=True).values
395
+ features_max = features.max(dim=0, keepdim=True).values.max(dim=1, keepdim=True).values.max(dim=2, keepdim=True).values
396
+ features = (features - features_min) / (features_max - features_min)
397
+
398
+ rows = (B + 1) // 2 # Adjust rows for odd B
399
+ fig, axes = plt.subplots(rows, 8, figsize=(W*8*0.01, H*rows*0.01))
400
+
401
+ for i in range(B//2):
402
+ # Odd row: image and features
403
+ image = (images[2*i]['img'].squeeze(0).permute(1, 2, 0).numpy() + 1) / 2
404
+ axes[i, 0].imshow(image)
405
+ axes[i, 0].axis('off')
406
+ for j in range(3):
407
+ axes[i, j+1].imshow(features[2*i, :, :, j*3:(j+1)*3].cpu().numpy())
408
+ axes[i, j+1].axis('off')
409
+
410
+ # Even row: image and features
411
+ if 2*i + 1 < B:
412
+ image = (images[2*i + 1]['img'].squeeze(0).permute(1, 2, 0).numpy() + 1) / 2
413
+ axes[i, 4].imshow(image)
414
+ axes[i, 4].axis('off')
415
+ for j in range(3):
416
+ axes[i, j+5].imshow(features[2*i + 1, :, :, j*3:(j+1)*3].cpu().numpy())
417
+ axes[i, j+5].axis('off')
418
+
419
+ # Handle last row if B is odd
420
+ if B % 2 != 0:
421
+ image = (images[-1]['img'].squeeze(0).permute(1, 2, 0).numpy() + 1) / 2
422
+ axes[-1, 0].imshow(image)
423
+ axes[-1, 0].axis('off')
424
+ for j in range(3):
425
+ axes[-1, j+1].imshow(features[-1, :, :, j*3:(j+1)*3].cpu().numpy())
426
+ axes[-1, j+1].axis('off')
427
+
428
+ # Hide unused columns in last row
429
+ for j in range(4, 8):
430
+ axes[-1, j].axis('off')
431
+
432
+ plt.subplots_adjust(wspace=0.005, hspace=0.005, left=0.01, right=0.99, top=0.99, bottom=0.01)
433
+
434
+ # Save the plot
435
+ if file_name is None:
436
+ file_name = f'feat_dim{dim-9}-{dim}'
437
+ if feat_type:
438
+ feat_type_str = '-'.join(feat_type)
439
+ file_name = file_name + f'_{feat_type_str}'
440
+ save_path = os.path.join(save_dir, file_name + '.png')
441
+ plt.savefig(save_path, bbox_inches='tight', pad_inches=0)
442
+ plt.close()
443
+
444
+ return save_path
445
+
446
+
447
+ def adjust_norm(image: torch.Tensor) -> torch.Tensor:
448
+
449
+ inv_normalize = tvf.Normalize(
450
+ mean=[-1, -1, -1],
451
+ std=[1/0.5, 1/0.5, 1/0.5]
452
+ )
453
+
454
+ correct_normalize = tvf.Normalize(
455
+ mean=[0.485, 0.456, 0.406],
456
+ std=[0.229, 0.224, 0.225]
457
+ )
458
+ image = inv_normalize(image)
459
+ image = correct_normalize(image)
460
+
461
+ return image
462
+
463
+ def adjust_midas_norm(image: torch.Tensor) -> torch.Tensor:
464
+
465
+ inv_normalize = tvf.Normalize(
466
+ mean=[-1, -1, -1],
467
+ std=[1/0.5, 1/0.5, 1/0.5]
468
+ )
469
+
470
+ correct_normalize = tvf.Normalize(
471
+ mean=[0.5, 0.5, 0.5],
472
+ std=[0.5, 0.5, 0.5]
473
+ )
474
+
475
+ image = inv_normalize(image)
476
+ image = correct_normalize(image)
477
+
478
+ return image
479
+
480
+ def adjust_clip_norm(image: torch.Tensor) -> torch.Tensor:
481
+
482
+ inv_normalize = tvf.Normalize(
483
+ mean=[-1, -1, -1],
484
+ std=[1/0.5, 1/0.5, 1/0.5]
485
+ )
486
+
487
+ correct_normalize = tvf.Normalize(
488
+ mean=[0.48145466, 0.4578275, 0.40821073],
489
+ std=[0.26862954, 0.26130258, 0.27577711]
490
+ )
491
+
492
+ image = inv_normalize(image)
493
+ image = correct_normalize(image)
494
+
495
+ return image
496
+
497
+ class UnNormalize(object):
498
+ def __init__(self, mean, std):
499
+ self.mean = mean
500
+ self.std = std
501
+
502
+ def __call__(self, image):
503
+ image2 = torch.clone(image)
504
+ if len(image2.shape) == 4:
505
+ image2 = image2.permute(1, 0, 2, 3)
506
+ for t, m, s in zip(image2, self.mean, self.std):
507
+ t.mul_(s).add_(m)
508
+ return image2.permute(1, 0, 2, 3)
509
+
510
+
511
+ norm = tvf.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
512
+ unnorm = UnNormalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
513
+
514
+ midas_norm = tvf.Normalize([0.5] * 3, [0.5] * 3)
515
+ midas_unnorm = UnNormalize([0.5] * 3, [0.5] * 3)
516
+
517
+
518
+ def generate_iuv(B, H, W):
519
+ i_coords = torch.arange(B).view(B, 1, 1, 1).expand(B, H, W, 1).float() / (B - 1)
520
+ u_coords = torch.linspace(0, 1, W).view(1, 1, W, 1).expand(B, H, W, 1)
521
+ v_coords = torch.linspace(0, 1, H).view(1, H, 1, 1).expand(B, H, W, 1)
522
+ iuv_coords = torch.cat([i_coords, u_coords, v_coords], dim=-1)
523
+ return iuv_coords
524
+
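For reference, the IUV coordinates produced above encode the view index and the normalized pixel position, e.g.:

```python
# Illustrative check for generate_iuv(): per-pixel (view index, u, v) in [0, 1].
import torch

iuv = generate_iuv(B=2, H=4, W=6)
print(iuv.shape)        # torch.Size([2, 4, 6, 3])
print(iuv[0, 0, 0])     # tensor([0., 0., 0.]) -> first view, top-left pixel
print(iuv[1, -1, -1])   # tensor([1., 1., 1.]) -> last view, bottom-right pixel
```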
525
+ class FeatureExtractor:
526
+ """
527
+ Extracts and processes features from images using VFMs for per point(per pixel).
528
+ Supports multiple VFM features, dimensionality reduction, feature upsampling, and visualization.
529
+
530
+ Parameters:
531
+ images (list): List of image info.
532
+ method (str): Pointmap initialization method, chosen from ["dust3r", "mast3r"].
533
+ device (str): 'cuda'.
534
+ feat_type (list): VFM(s), chosen from ["dust3r", "mast3r", "dift", "dino_b16", "dinov2_b14", "radio", "clip_b16", "mae_b16", "midas_l16", "sam_base", "iuvrgb"].
535
+ feat_dim (int): PCA dimensions.
536
+ img_base_path (str): Training view data directory path.
537
+ model_path (str): Model path, './submodules/mast3r/checkpoints/'.
538
+ vis_feat (bool): Visualize and save feature maps.
539
+ vis_key (str): Feature type to visualize (MASt3R only), chosen from ["decfeat", "desc"].
540
+ focal_avg (bool): Use a shared (averaged) focal length.
541
+ """
542
+ def __init__(self, images, args, method):
543
+ self.images = images
544
+ self.method = method
545
+ self.device = args.device
546
+ self.feat_type = args.feat_type
547
+ self.feat_dim = args.feat_dim
548
+ self.img_base_path = args.img_base_path
549
+ # self.use_featup = args.use_featup
550
+ self.model_path = args.model_path
551
+ self.vis_feat = args.vis_feat
552
+ self.vis_key = args.vis_key
553
+ self.focal_avg = args.focal_avg
554
+
555
+ def get_dust3r_feat(self, **kw):
556
+ model_path = os.path.join(self.model_path, "DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth")
557
+ model = AsymmetricCroCo3DStereo.from_pretrained(model_path).to(self.device)
558
+ output = inference(kw['pairs'], model, self.device, batch_size=1)
559
+ scene = global_aligner(output, device=self.device, mode=GlobalAlignerMode.PointCloudOptimizer)
560
+ if self.vis_key:
561
+ assert self.vis_key == 'decfeat', f"Expected vis_key to be 'decfeat', but got {self.vis_key}"
562
+ self.vis_decfeat(kw['pairs'], output=output)
563
+
564
+ # del model, output
565
+ # torch.cuda.empty_cache()
566
+
567
+ return scene.stacked_feat
568
+
569
+ def get_mast3r_feat(self, **kw):
570
+ model_path = os.path.join(self.model_path, "MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric.pth")
571
+ model = AsymmetricMASt3R.from_pretrained(model_path).to(self.device)
572
+ cache_dir = os.path.join(self.img_base_path, "cache")
573
+ if os.path.exists(cache_dir):
574
+ os.system(f'rm -rf {cache_dir}')
575
+ scene = sparse_global_alignment(kw['train_img_list'], kw['pairs'], cache_dir,
576
+ model, lr1=0.07, niter1=500, lr2=0.014, niter2=200, device=self.device,
577
+ opt_depth='depth' in 'refine', shared_intrinsics=self.focal_avg,
578
+ matching_conf_thr=5.)
579
+ if self.vis_key:
580
+ assert self.vis_key in ['decfeat', 'desc'], f"Expected vis_key to be 'decfeat' or 'desc', but got {self.vis_key}"
581
+ self.vis_decfeat(kw['pairs'], model=model)
582
+
583
+ # del model
584
+ # torch.cuda.empty_cache()
585
+
586
+ return scene.stacked_feat
587
+
588
+ def get_feat(self, feat_type):
589
+ """
590
+ Get features using Probe3D.
591
+ """
592
+ cfg = OmegaConf.load(f"configs/backbone/{feat_type}.yaml")
593
+ model = instantiate(cfg.model, output="dense", return_multilayer=False)
594
+ model = model.to(self.device)
595
+ if 'midas' in feat_type:
596
+ image_norm = adjust_midas_norm(torch.cat([i['img'] for i in self.images])).to(self.device)
597
+ # elif 'clip' in self.feat_type:
598
+ # image_norm = adjust_clip_norm(torch.cat([i['img'] for i in self.images])).to(self.device)
599
+ else:
600
+ image_norm = adjust_norm(torch.cat([i['img'] for i in self.images])).to(self.device)
601
+
602
+ with torch.no_grad():
603
+ feats = model(image_norm).permute(0, 2, 3, 1)
604
+
605
+ # del model
606
+ # torch.cuda.empty_cache()
607
+
608
+ return feats
609
+
610
+ # def get_feat(self, feat_type):
611
+ # """
612
+ # Get features using FeatUp.
613
+ # """
614
+ # original_feat_type = feat_type
615
+ # use_norm = False if 'maskclip' in feat_type else True
616
+ # if 'featup' in original_feat_type:
617
+ # feat_type = feat_type.split('_featup')[0]
618
+ # # feat_upsampler = torch.hub.load("mhamilton723/FeatUp", feat_type, use_norm=use_norm).to(device)
619
+ # feat_upsampler = torch.hub.load("/home/chenyue/.cache/torch/hub/mhamilton723_FeatUp_main/", feat_type, use_norm=use_norm, source='local').to(self.device) ## offline
620
+ # image_norm = adjust_norm(torch.cat([i['img'] for i in self.images])).to(self.device)
621
+ # image_norm = F.interpolate(image_norm, size=(224, 224), mode='bilinear', align_corners=False)
622
+ # if 'featup' in original_feat_type:
623
+ # feats = feat_upsampler(image_norm).permute(0, 2, 3, 1)
624
+ # else:
625
+ # feats = feat_upsampler.model(image_norm).permute(0, 2, 3, 1)
626
+ # return feats
627
+
628
+ def get_iuvrgb(self):
629
+ rgb = torch.cat([i['img'] for i in self.images]).permute(0, 2, 3, 1).to(self.device)
630
+ feats = torch.cat([generate_iuv(*rgb.shape[:-1]).to(self.device), rgb], dim=-1)
631
+ return feats
632
+
633
+ def get_iuv(self):
634
+ rgb = torch.cat([i['img'] for i in self.images]).permute(0, 2, 3, 1).to(self.device)
635
+ feats = generate_iuv(*rgb.shape[:-1]).to(self.device)
636
+ return feats
637
+
638
+ def preprocess(self, feature, feat_dim, vis_feat=False, is_upsample=True):
639
+ """
640
+ Preprocess features by applying PCA, upsampling, and optionally visualizing.
641
+ """
642
+ if feat_dim:
643
+ feature = pca(feature, feat_dim)
644
+ # else:
645
+ # feature_min = feature.min(dim=0, keepdim=True).values.min(dim=1, keepdim=True).values
646
+ # feature_max = feature.max(dim=0, keepdim=True).values.max(dim=1, keepdim=True).values
647
+ # feature = (feature - feature_min) / (feature_max - feature_min + 1e-6)
648
+ # feature = feature - feature.mean(dim=[0,1,2], keepdim=True)
649
+
650
+ torch.cuda.empty_cache()
651
+
652
+ if (feature[0].shape[0:-1] != self.images[0]['true_shape'][0]).all() and is_upsample:
653
+ feature = upsampler(feature, *[s for s in self.images[0]['true_shape'][0]])
654
+
655
+ print(f"Feature map size >>> height: {feature[0].shape[0]}, width: {feature[0].shape[1]}, channels: {feature[0].shape[2]}")
656
+ if vis_feat:
657
+ save_path = visualizer(feature, self.images, self.img_base_path, feat_type=self.feat_type)
658
+ print(f"The encoder feature visualization has been saved at >>>>> {save_path}")
659
+
660
+ return feature
661
+
662
+ def vis_decfeat(self, pairs, **kw):
663
+ """
664
+ Visualize decoder or head (MASt3R only) features.
665
+ """
666
+ if 'output' in kw:
667
+ output = kw['output']
668
+ else:
669
+ output = inference(pairs, kw['model'], self.device, batch_size=1, verbose=False)
670
+ decfeat1 = output['pred1'][self.vis_key].detach()
671
+ decfeat2 = output['pred2'][self.vis_key].detach()
672
+ # decfeat1 = pca(decfeat1, 9)
673
+ # decfeat2 = pca(decfeat2, 9)
674
+ decfeat = torch.stack([decfeat1, decfeat2], dim=1).view(-1, *decfeat1.shape[1:])
675
+ decfeat = torch.cat(torch.chunk(decfeat,2)[::-1], dim=0)
676
+ decfeat = pca(decfeat, 9)
677
+ if (decfeat.shape[1:-1] != self.images[0]['true_shape'][0]).all():
678
+ decfeat = upsampler(decfeat, *[s for s in self.images[0]['true_shape'][0]])
679
+ pair_images = [im for p in pairs[3:] + pairs[:3] for im in p]
680
+ save_path = mv_visualizer(decfeat, pair_images, self.img_base_path,
681
+ feat_type=self.feat_type, file_name=f'{self.vis_key}_pcaall_dim0-9')
682
+ print(f"The decoder feature visualization has been saved at >>>>> {save_path}")
683
+
684
+ def forward(self, **kw):
685
+ feat_dim = self.feat_dim
686
+ vis_feat = self.vis_feat and len(self.feat_type) == 1
687
+ is_upsample = len(self.feat_type) == 1
688
+
689
+ all_feats = {}
690
+ for feat_type in self.feat_type:
691
+ if feat_type == self.method:
692
+ feats = kw['scene'].stacked_feat
693
+ elif feat_type in ['dust3r', 'mast3r']:
694
+ feats = getattr(self, f"get_{feat_type}_feat")(**kw)
695
+ elif feat_type in ['iuv', 'iuvrgb']:
696
+ feats = getattr(self, f"get_{feat_type}")()
697
+ feat_dim = None
698
+ else:
699
+ feats = self.get_feat(feat_type)
700
+
701
+ # feats = to_numpy(self.preprocess(feats))
702
+ all_feats[feat_type] = self.preprocess(feats.detach().clone(), feat_dim, vis_feat, is_upsample)
703
+
704
+ if len(self.feat_type) > 1:
705
+ all_feats = {k: (v - v.min()) / (v.max() - v.min()) for k, v in all_feats.items()}
706
+
707
+ target_size = tuple(s // 16 for s in self.images[0]['true_shape'][0][:2])
708
+ tmp_feats = []
709
+ kickoff = []
710
+
711
+ for k, v in all_feats.items():
712
+ if k in ['iuv', 'iuvrgb']:
713
+ # self.feat_dim -= v.shape[-1]
714
+ kickoff.append(v)
715
+ else:
716
+ if v.shape[1:3] != target_size:
717
+ v = F.interpolate(v.permute(0, 3, 1, 2), size=target_size,
718
+ mode='bilinear', align_corners=False).permute(0, 2, 3, 1)
719
+ tmp_feats.append(v)
720
+
721
+ feats = self.preprocess(torch.cat(tmp_feats, dim=-1), self.feat_dim, self.vis_feat and not kickoff)
722
+
723
+ if kickoff:
724
+ feats = torch.cat([feats] + kickoff, dim=-1)
725
+ feats = self.preprocess(feats, self.feat_dim, self.vis_feat, is_upsample=False)
726
+
727
+ else:
728
+ feats = all_feats[self.feat_type[0]]
729
+
730
+ torch.cuda.empty_cache()
731
+ return to_numpy(feats)
732
+
733
+ def __call__(self, **kw):
734
+ return self.forward(**kw)
735
+
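A rough, illustrative sketch of driving FeatureExtractor; the images, pairs and scene objects are assumed to come from the DUSt3R/MASt3R loading and alignment code used elsewhere in this repo, and the argument values are placeholders rather than recommended settings:

```python
# Rough usage sketch for FeatureExtractor (placeholder arguments, not a reference setup).
from types import SimpleNamespace

args = SimpleNamespace(
    device="cuda",
    feat_type=["dino_b16"],       # one or more VFM names
    feat_dim=256,                 # PCA target dimension
    img_base_path="./data/scene",
    model_path="./submodules/mast3r/checkpoints/",
    vis_feat=True,
    vis_key=None,
    focal_avg=False,
)
extractor = FeatureExtractor(images, args, method="dust3r")
feats = extractor(scene=scene, pairs=pairs, train_img_list=None)  # numpy array of per-pixel features
```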
736
+
737
+ class InitMethod:
738
+ """
739
+ Initialize pointmap and camera param via DUSt3R or MASt3R.
740
+ """
741
+ def __init__(self, args):
742
+ self.method = args.method
743
+ self.n_views = args.n_views
744
+ self.device = args.device
745
+ self.img_base_path = args.img_base_path
746
+ self.focal_avg = args.focal_avg
747
+ self.tsdf_thresh = args.tsdf_thresh
748
+ self.min_conf_thr = args.min_conf_thr
749
+ if self.method == 'dust3r':
750
+ self.model_path = os.path.join(args.model_path, "DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth")
751
+ else:
752
+ self.model_path = os.path.join(args.model_path, "MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric.pth")
753
+
754
+ def get_dust3r(self):
755
+ return AsymmetricCroCo3DStereo.from_pretrained(self.model_path).to(self.device)
756
+
757
+ def get_mast3r(self):
758
+ return AsymmetricMASt3R.from_pretrained(self.model_path).to(self.device)
759
+
760
+ def infer_dust3r(self, **kw):
761
+ output = inference(kw['pairs'], kw['model'], self.device, batch_size=1)
762
+ scene = global_aligner(output, device=self.device, mode=GlobalAlignerMode.PointCloudOptimizer)
763
+ loss = compute_global_alignment(scene=scene, init="mst", niter=300, schedule='linear', lr=0.01,
764
+ focal_avg=self.focal_avg, known_focal=kw.get('known_focal', None))
765
+ scene = scene.clean_pointcloud()
766
+ return scene
767
+
768
+ def infer_mast3r(self, **kw):
769
+ cache_dir = os.path.join(self.img_base_path, "cache")
770
+ if os.path.exists(cache_dir):
771
+ os.system(f'rm -rf {cache_dir}')
772
+
773
+ scene = sparse_global_alignment(kw['train_img_list'], kw['pairs'], cache_dir,
774
+ kw['model'], lr1=0.07, niter1=500, lr2=0.014, niter2=200, device=self.device,
775
+ opt_depth='depth' in 'refine', shared_intrinsics=self.focal_avg,
776
+ matching_conf_thr=5.)
777
+ return scene
778
+
779
+ def get_dust3r_info(self, scene):
780
+ imgs = to_numpy(scene.imgs)
781
+ focals = scene.get_focals()
782
+ poses = to_numpy(scene.get_im_poses())
783
+ pts3d = to_numpy(scene.get_pts3d())
784
+ # pts3d = to_numpy(scene.get_planes3d())
785
+ scene.min_conf_thr = float(scene.conf_trf(torch.tensor(1.0)))
786
+ confidence_masks = to_numpy(scene.get_masks())
787
+ intrinsics = to_numpy(scene.get_intrinsics())
788
+ return imgs, focals, poses, intrinsics, pts3d, confidence_masks
789
+
790
+ def get_mast3r_info(self, scene):
791
+ imgs = to_numpy(scene.imgs)
792
+ focals = scene.get_focals()[:,None]
793
+ poses = to_numpy(scene.get_im_poses())
794
+ intrinsics = to_numpy(scene.intrinsics)
795
+ tsdf = TSDFPostProcess(scene, TSDF_thresh=self.tsdf_thresh)
796
+ pts3d, _, confs = to_numpy(tsdf.get_dense_pts3d(clean_depth=True))
797
+ pts3d = [arr.reshape((*imgs[0].shape[:2], 3)) for arr in pts3d]
798
+ confidence_masks = np.array(to_numpy([c > self.min_conf_thr for c in confs]))
799
+ return imgs, focals, poses, intrinsics, pts3d, confidence_masks
800
+
801
+ def get_dust3r_depth(self, scene):
802
+ return to_numpy(scene.get_depthmaps())
803
+
804
+ def get_mast3r_depth(self, scene):
805
+ imgs = to_numpy(scene.imgs)
806
+ tsdf = TSDFPostProcess(scene, TSDF_thresh=self.tsdf_thresh)
807
+ _, depthmaps, _ = to_numpy(tsdf.get_dense_pts3d(clean_depth=True))
808
+ depthmaps = [arr.reshape((*imgs[0].shape[:2], 3)) for arr in depthmaps]
809
+ return depthmaps
810
+
811
+ def get_model(self):
812
+ return getattr(self, f"get_{self.method}")()
813
+
814
+ def infer(self, **kw):
815
+ return getattr(self, f"infer_{self.method}")(**kw)
816
+
817
+ def get_info(self, scene):
818
+ return getattr(self, f"get_{self.method}_info")(scene)
819
+
820
+ def get_depth(self, scene):
821
+ return getattr(self, f"get_{self.method}_depth")(scene)
822
+
823
+
824
+
825
+
826
+
827
+
utils/general_utils.py ADDED
@@ -0,0 +1,133 @@
1
+ #
2
+ # Copyright (C) 2023, Inria
3
+ # GRAPHDECO research group, https://team.inria.fr/graphdeco
4
+ # All rights reserved.
5
+ #
6
+ # This software is free for non-commercial, research and evaluation use
7
+ # under the terms of the LICENSE.md file.
8
+ #
9
+ # For inquiries contact [email protected]
10
+ #
11
+
12
+ import torch
13
+ import sys
14
+ from datetime import datetime
15
+ import numpy as np
16
+ import random
17
+
18
+ def inverse_sigmoid(x):
19
+ return torch.log(x/(1-x))
20
+
21
+ def PILtoTorch(pil_image, resolution):
22
+ resized_image_PIL = pil_image.resize(resolution)
23
+ resized_image = torch.from_numpy(np.array(resized_image_PIL)) / 255.0
24
+ if len(resized_image.shape) == 3:
25
+ return resized_image.permute(2, 0, 1)
26
+ else:
27
+ return resized_image.unsqueeze(dim=-1).permute(2, 0, 1)
28
+
29
+ def get_expon_lr_func(
30
+ lr_init, lr_final, lr_delay_steps=0, lr_delay_mult=1.0, max_steps=1000000
31
+ ):
32
+ """
33
+ Copied from Plenoxels
34
+
35
+ Continuous learning rate decay function. Adapted from JaxNeRF
36
+ The returned rate is lr_init when step=0 and lr_final when step=max_steps, and
37
+ is log-linearly interpolated elsewhere (equivalent to exponential decay).
38
+ If lr_delay_steps>0 then the learning rate will be scaled by some smooth
39
+ function of lr_delay_mult, such that the initial learning rate is
40
+ lr_init*lr_delay_mult at the beginning of optimization but will be eased back
41
+ to the normal learning rate when steps>lr_delay_steps.
42
+ :param conf: config subtree 'lr' or similar
43
+ :param max_steps: int, the number of steps during optimization.
44
+ :return HoF which takes step as input
45
+ """
46
+
47
+ def helper(step):
48
+ if step < 0 or (lr_init == 0.0 and lr_final == 0.0):
49
+ # Disable this parameter
50
+ return 0.0
51
+ if lr_delay_steps > 0:
52
+ # A kind of reverse cosine decay.
53
+ delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin(
54
+ 0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1)
55
+ )
56
+ else:
57
+ delay_rate = 1.0
58
+ t = np.clip(step / max_steps, 0, 1)
59
+ log_lerp = np.exp(np.log(lr_init) * (1 - t) + np.log(lr_final) * t)
60
+ return delay_rate * log_lerp
61
+
62
+ return helper
63
+
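A quick worked example of this schedule (log-linear interpolation between lr_init and lr_final):

```python
# Worked example: log-linear decay from 1e-3 to 1e-5 over 1000 steps.
lr_fn = get_expon_lr_func(lr_init=1e-3, lr_final=1e-5, max_steps=1000)
print(lr_fn(0))      # 1e-3
print(lr_fn(500))    # 1e-4 (geometric mean of the two endpoints)
print(lr_fn(1000))   # 1e-5
```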
64
+ def strip_lowerdiag(L):
65
+ uncertainty = torch.zeros((L.shape[0], 6), dtype=torch.float, device="cuda")
66
+
67
+ uncertainty[:, 0] = L[:, 0, 0]
68
+ uncertainty[:, 1] = L[:, 0, 1]
69
+ uncertainty[:, 2] = L[:, 0, 2]
70
+ uncertainty[:, 3] = L[:, 1, 1]
71
+ uncertainty[:, 4] = L[:, 1, 2]
72
+ uncertainty[:, 5] = L[:, 2, 2]
73
+ return uncertainty
74
+
75
+ def strip_symmetric(sym):
76
+ return strip_lowerdiag(sym)
77
+
78
+ def build_rotation(r):
79
+ norm = torch.sqrt(r[:,0]*r[:,0] + r[:,1]*r[:,1] + r[:,2]*r[:,2] + r[:,3]*r[:,3])
80
+
81
+ q = r / norm[:, None]
82
+
83
+ R = torch.zeros((q.size(0), 3, 3), device='cuda')
84
+
85
+ r = q[:, 0]
86
+ x = q[:, 1]
87
+ y = q[:, 2]
88
+ z = q[:, 3]
89
+
90
+ R[:, 0, 0] = 1 - 2 * (y*y + z*z)
91
+ R[:, 0, 1] = 2 * (x*y - r*z)
92
+ R[:, 0, 2] = 2 * (x*z + r*y)
93
+ R[:, 1, 0] = 2 * (x*y + r*z)
94
+ R[:, 1, 1] = 1 - 2 * (x*x + z*z)
95
+ R[:, 1, 2] = 2 * (y*z - r*x)
96
+ R[:, 2, 0] = 2 * (x*z - r*y)
97
+ R[:, 2, 1] = 2 * (y*z + r*x)
98
+ R[:, 2, 2] = 1 - 2 * (x*x + y*y)
99
+ return R
100
+
101
+ def build_scaling_rotation(s, r):
102
+ L = torch.zeros((s.shape[0], 3, 3), dtype=torch.float, device="cuda")
103
+ R = build_rotation(r)
104
+
105
+ L[:,0,0] = s[:,0]
106
+ L[:,1,1] = s[:,1]
107
+ L[:,2,2] = s[:,2]
108
+
109
+ L = R @ L
110
+ return L
111
+
112
+ def safe_state(silent):
113
+ old_f = sys.stdout
114
+ class F:
115
+ def __init__(self, silent):
116
+ self.silent = silent
117
+
118
+ def write(self, x):
119
+ if not self.silent:
120
+ if x.endswith("\n"):
121
+ old_f.write(x.replace("\n", " [{}]\n".format(str(datetime.now().strftime("%d/%m %H:%M:%S")))))
122
+ else:
123
+ old_f.write(x)
124
+
125
+ def flush(self):
126
+ old_f.flush()
127
+
128
+ sys.stdout = F(silent)
129
+
130
+ random.seed(0)
131
+ np.random.seed(0)
132
+ torch.manual_seed(0)
133
+ torch.cuda.set_device(torch.device("cuda:0"))
utils/graphics_utils.py ADDED
@@ -0,0 +1,210 @@
1
+ #
2
+ # Copyright (C) 2023, Inria
3
+ # GRAPHDECO research group, https://team.inria.fr/graphdeco
4
+ # All rights reserved.
5
+ #
6
+ # This software is free for non-commercial, research and evaluation use
7
+ # under the terms of the LICENSE.md file.
8
+ #
9
+ # For inquiries contact [email protected]
10
+ #
11
+
12
+ import torch
13
+ import math
14
+ import numpy as np
15
+ from typing import NamedTuple
16
+
17
+ import torch.nn.functional as F
18
+ from torch import Tensor
19
+
20
+ class BasicPointCloud(NamedTuple):
21
+ points : np.array
22
+ colors : np.array
23
+ normals : np.array
24
+ features: np.array
25
+
26
+ def geom_transform_points(points, transf_matrix):
27
+ P, _ = points.shape
28
+ ones = torch.ones(P, 1, dtype=points.dtype, device=points.device)
29
+ points_hom = torch.cat([points, ones], dim=1)
30
+ points_out = torch.matmul(points_hom, transf_matrix.unsqueeze(0))
31
+
32
+ denom = points_out[..., 3:] + 0.0000001
33
+ return (points_out[..., :3] / denom).squeeze(dim=0)
34
+
35
+ def getWorld2View(R, t):
36
+ Rt = np.zeros((4, 4))
37
+ Rt[:3, :3] = R.transpose()
38
+ Rt[:3, 3] = t
39
+ Rt[3, 3] = 1.0
40
+ return np.float32(Rt)
41
+
42
+ def getWorld2View2(R, t, translate=np.array([.0, .0, .0]), scale=1.0):
43
+ Rt = np.zeros((4, 4))
44
+ Rt[:3, :3] = R.transpose()
45
+ Rt[:3, 3] = t
46
+ Rt[3, 3] = 1.0
47
+
48
+ C2W = np.linalg.inv(Rt)
49
+ cam_center = C2W[:3, 3]
50
+ cam_center = (cam_center + translate) * scale
51
+ C2W[:3, 3] = cam_center
52
+ Rt = np.linalg.inv(C2W)
53
+ return np.float32(Rt)
54
+
55
+ def getWorld2View2_torch(R, t, translate=torch.tensor([0.0, 0.0, 0.0]), scale=1.0):
56
+ translate = torch.tensor(translate, dtype=torch.float32)
57
+
58
+ # Initialize the transformation matrix
59
+ Rt = torch.zeros((4, 4), dtype=torch.float32)
60
+ Rt[:3, :3] = R.t() # Transpose of R
61
+ Rt[:3, 3] = t
62
+ Rt[3, 3] = 1.0
63
+
64
+ # Compute the inverse to get the camera-to-world transformation
65
+ C2W = torch.linalg.inv(Rt)
66
+ cam_center = C2W[:3, 3]
67
+ cam_center = (cam_center + translate) * scale
68
+ C2W[:3, 3] = cam_center
69
+
70
+ # Invert again to get the world-to-view transformation
71
+ Rt = torch.linalg.inv(C2W)
72
+
73
+ return Rt
74
+
75
+ def getProjectionMatrix(znear, zfar, fovX, fovY):
76
+ tanHalfFovY = math.tan((fovY / 2))
77
+ tanHalfFovX = math.tan((fovX / 2))
78
+
79
+ top = tanHalfFovY * znear
80
+ bottom = -top
81
+ right = tanHalfFovX * znear
82
+ left = -right
83
+
84
+ P = torch.zeros(4, 4)
85
+
86
+ z_sign = 1.0
87
+
88
+ P[0, 0] = 2.0 * znear / (right - left)
89
+ P[1, 1] = 2.0 * znear / (top - bottom)
90
+ P[0, 2] = (right + left) / (right - left)
91
+ P[1, 2] = (top + bottom) / (top - bottom)
92
+ P[3, 2] = z_sign
93
+ P[2, 2] = z_sign * zfar / (zfar - znear)
94
+ P[2, 3] = -(zfar * znear) / (zfar - znear)
95
+ return P
96
+
97
+ def fov2focal(fov, pixels):
98
+ return pixels / (2 * math.tan(fov / 2))
99
+
100
+ def focal2fov(focal, pixels):
101
+ return 2*math.atan(pixels/(2*focal))
102
+
103
+ def resize_render(view, size=None):
104
+ image_size = size if size is not None else max(view.image_width, view.image_height)
105
+ view.original_image = torch.zeros((3, image_size, image_size), device=view.original_image.device)
106
+ focal_length_x = fov2focal(view.FoVx, view.image_width)
107
+ focal_length_y = fov2focal(view.FoVy, view.image_height)
108
+ view.image_width = image_size
109
+ view.image_height = image_size
110
+ view.FoVx = focal2fov(focal_length_x, image_size)
111
+ view.FoVy = focal2fov(focal_length_y, image_size)
112
+ return view
113
+
114
+ def make_video_divisble(
115
+ video: torch.Tensor | np.ndarray, block_size=16
116
+ ) -> torch.Tensor | np.ndarray:
117
+ H, W = video.shape[1:3]
118
+ H_new = H - H % block_size
119
+ W_new = W - W % block_size
120
+ return video[:, :H_new, :W_new]
121
+
122
+
123
+ def depth_to_points(
124
+ depths: Tensor, camtoworlds: Tensor, Ks: Tensor, z_depth: bool = True
125
+ ) -> Tensor:
126
+ """Convert depth maps to 3D points
127
+
128
+ Args:
129
+ depths: Depth maps [..., H, W, 1]
130
+ camtoworlds: Camera-to-world transformation matrices [..., 4, 4]
131
+ Ks: Camera intrinsics [..., 3, 3]
132
+ z_depth: Whether the depth is in z-depth (True) or ray depth (False)
133
+
134
+ Returns:
135
+ points: 3D points in the world coordinate system [..., H, W, 3]
136
+ """
137
+ assert depths.shape[-1] == 1, f"Invalid depth shape: {depths.shape}"
138
+ assert camtoworlds.shape[-2:] == (
139
+ 4,
140
+ 4,
141
+ ), f"Invalid viewmats shape: {camtoworlds.shape}"
142
+ assert Ks.shape[-2:] == (3, 3), f"Invalid Ks shape: {Ks.shape}"
143
+ assert (
144
+ depths.shape[:-3] == camtoworlds.shape[:-2] == Ks.shape[:-2]
145
+ ), f"Shape mismatch! depths: {depths.shape}, viewmats: {camtoworlds.shape}, Ks: {Ks.shape}"
146
+
147
+ device = depths.device
148
+ height, width = depths.shape[-3:-1]
149
+
150
+ x, y = torch.meshgrid(
151
+ torch.arange(width, device=device),
152
+ torch.arange(height, device=device),
153
+ indexing="xy",
154
+ ) # [H, W]
155
+
156
+ fx = Ks[..., 0, 0] # [...]
157
+ fy = Ks[..., 1, 1] # [...]
158
+ cx = Ks[..., 0, 2] # [...]
159
+ cy = Ks[..., 1, 2] # [...]
160
+
161
+ # camera directions in camera coordinates
162
+ camera_dirs = F.pad(
163
+ torch.stack(
164
+ [
165
+ (x - cx[..., None, None] + 0.5) / fx[..., None, None],
166
+ (y - cy[..., None, None] + 0.5) / fy[..., None, None],
167
+ ],
168
+ dim=-1,
169
+ ),
170
+ (0, 1),
171
+ value=1.0,
172
+ ) # [..., H, W, 3]
173
+
174
+ # ray directions in world coordinates
175
+ directions = torch.einsum(
176
+ "...ij,...hwj->...hwi", camtoworlds[..., :3, :3], camera_dirs
177
+ ) # [..., H, W, 3]
178
+ origins = camtoworlds[..., :3, -1] # [..., 3]
179
+
180
+ if not z_depth:
181
+ directions = F.normalize(directions, dim=-1)
182
+
183
+ points = origins[..., None, None, :] + depths * directions
184
+ return points
185
+
186
+
187
+ def depth_to_normal(
188
+ depths: Tensor, camtoworlds: Tensor, Ks: Tensor, z_depth: bool = True
189
+ ) -> Tensor:
190
+ """Convert depth maps to surface normals
191
+
192
+ Args:
193
+ depths: Depth maps [..., H, W, 1]
194
+ camtoworlds: Camera-to-world transformation matrices [..., 4, 4]
195
+ Ks: Camera intrinsics [..., 3, 3]
196
+ z_depth: Whether the depth is in z-depth (True) or ray depth (False)
197
+
198
+ Returns:
199
+ normals: Surface normals in the world coordinate system [..., H, W, 3]
200
+ """
201
+ points = depth_to_points(depths, camtoworlds, Ks, z_depth=z_depth) # [..., H, W, 3]
202
+ dx = torch.cat(
203
+ [points[..., 2:, 1:-1, :] - points[..., :-2, 1:-1, :]], dim=-3
204
+ ) # [..., H-2, W-2, 3]
205
+ dy = torch.cat(
206
+ [points[..., 1:-1, 2:, :] - points[..., 1:-1, :-2, :]], dim=-2
207
+ ) # [..., H-2, W-2, 3]
208
+ normals = F.normalize(torch.cross(dx, dy, dim=-1), dim=-1) # [..., H-2, W-2, 3]
209
+ normals = F.pad(normals, (0, 0, 1, 1, 1, 1), value=0.0) # [..., H, W, 3]
210
+ return normals
utils/image_utils.py ADDED
@@ -0,0 +1,118 @@
1
+ #
2
+ # Copyright (C) 2023, Inria
3
+ # GRAPHDECO research group, https://team.inria.fr/graphdeco
4
+ # All rights reserved.
5
+ #
6
+ # This software is free for non-commercial, research and evaluation use
7
+ # under the terms of the LICENSE.md file.
8
+ #
9
+ # For inquiries contact [email protected]
10
+ #
11
+
12
+ import torch
13
+
14
+ def mse(img1, img2):
15
+ return (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True)
16
+
17
+ def psnr(img1, img2):
18
+ mse = (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True)
19
+ return 20 * torch.log10(1.0 / torch.sqrt(mse))
20
+
21
+ def masked_psnr(img1, img2, mask):
22
+ mse = ((((img1 - img2)) ** 2) * mask).sum() / (3. * mask.sum())
23
+ return 20 * torch.log10(1.0 / torch.sqrt(mse))
24
+
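For reference, a quick check of the PSNR helper (images in [0, 1], batched along dim 0):

```python
# Quick check of psnr(): identical images give +inf, a small perturbation gives ~40 dB.
import torch

img_gt = torch.rand(1, 3, 64, 64)
img_noisy = (img_gt + 0.01 * torch.randn_like(img_gt)).clamp(0, 1)
print(psnr(img_gt, img_gt))       # inf
print(psnr(img_gt, img_noisy))    # roughly 40 dB
```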
25
+
26
+ def accuracy_torch(gt_points, rec_points, gt_normals=None, rec_normals=None, batch_size=5000):
27
+ n_points = rec_points.shape[0]
28
+ all_distances = []
29
+ all_indices = []
30
+
31
+ for i in range(0, n_points, batch_size):
32
+ end_idx = min(i + batch_size, n_points)
33
+ batch_points = rec_points[i:end_idx]
34
+
35
+ distances = torch.cdist(batch_points, gt_points) # (batch_size, M)
36
+ batch_distances, batch_indices = torch.min(distances, dim=1) # (batch_size,)
37
+
38
+ all_distances.append(batch_distances)
39
+ all_indices.append(batch_indices)
40
+
41
+ distances = torch.cat(all_distances)
42
+ indices = torch.cat(all_indices)
43
+
44
+ acc = torch.mean(distances)
45
+ acc_median = torch.median(distances)
46
+
47
+ if gt_normals is not None and rec_normals is not None:
48
+ normal_dot = torch.sum(gt_normals[indices] * rec_normals, dim=-1)
49
+ normal_dot = torch.abs(normal_dot)
50
+ return acc, acc_median, torch.mean(normal_dot), torch.median(normal_dot)
51
+
52
+ return acc, acc_median
53
+
54
+ def completion_torch(gt_points, rec_points, gt_normals=None, rec_normals=None, batch_size=5000):
55
+
56
+ n_points = gt_points.shape[0]
57
+ all_distances = []
58
+ all_indices = []
59
+
60
+ for i in range(0, n_points, batch_size):
61
+ end_idx = min(i + batch_size, n_points)
62
+ batch_points = gt_points[i:end_idx]
63
+
64
+ distances = torch.cdist(batch_points, rec_points) # (batch_size, M)
65
+ batch_distances, batch_indices = torch.min(distances, dim=1) # (batch_size,)
66
+
67
+ all_distances.append(batch_distances)
68
+ all_indices.append(batch_indices)
69
+
70
+ distances = torch.cat(all_distances)
71
+ indices = torch.cat(all_indices)
72
+
73
+ comp = torch.mean(distances)
74
+ comp_median = torch.median(distances)
75
+
76
+ if gt_normals is not None and rec_normals is not None:
77
+ normal_dot = torch.sum(gt_normals * rec_normals[indices], dim=-1)
78
+ normal_dot = torch.abs(normal_dot)
79
+ return comp, comp_median, torch.mean(normal_dot), torch.median(normal_dot)
80
+
81
+ return comp, comp_median
82
+
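These two are the usual Chamfer-style metrics: accuracy measures reconstruction-to-GT distances and completion measures GT-to-reconstruction distances. A minimal sketch on synthetic point clouds:

```python
# Minimal sketch of the accuracy / completion metrics.
import torch

gt_points = torch.rand(2000, 3)
rec_points = gt_points + 0.01 * torch.randn_like(gt_points)        # noisy reconstruction

acc_mean, acc_median = accuracy_torch(gt_points, rec_points)       # rec -> gt distances
comp_mean, comp_median = completion_torch(gt_points, rec_points)   # gt -> rec distances
print(acc_mean.item(), comp_mean.item())   # both small, on the order of the added noise
```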
83
+ def accuracy_per_point(gt_points, rec_points, batch_size=5000):
84
+ n_points = rec_points.shape[0]
85
+ all_distances = []
86
+ all_indices = []
87
+
88
+ for i in range(0, n_points, batch_size):
89
+ end_idx = min(i + batch_size, n_points)
90
+ batch_points = rec_points[i:end_idx]
91
+
92
+ distances = torch.cdist(batch_points, gt_points) # (batch_size, M)
93
+ batch_distances, batch_indices = torch.min(distances, dim=1) # (batch_size,)
94
+
95
+ all_distances.append(batch_distances)
96
+ all_indices.append(batch_indices)
97
+
98
+ distances = torch.cat(all_distances)
99
+ return distances
100
+
101
+ def completion_per_point(gt_points, rec_points, batch_size=5000):
102
+
103
+ n_points = gt_points.shape[0]
104
+ all_distances = []
105
+ all_indices = []
106
+
107
+ for i in range(0, n_points, batch_size):
108
+ end_idx = min(i + batch_size, n_points)
109
+ batch_points = gt_points[i:end_idx]
110
+
111
+ distances = torch.cdist(batch_points, rec_points) # (batch_size, M)
112
+ batch_distances, batch_indices = torch.min(distances, dim=1) # (batch_size,)
113
+
114
+ all_distances.append(batch_distances)
115
+ all_indices.append(batch_indices)
116
+
117
+ distances = torch.cat(all_distances)
118
+ return distances
utils/loss_utils.py ADDED
@@ -0,0 +1,247 @@
1
+ #
2
+ # Copyright (C) 2023, Inria
3
+ # GRAPHDECO research group, https://team.inria.fr/graphdeco
4
+ # All rights reserved.
5
+ #
6
+ # This software is free for non-commercial, research and evaluation use
7
+ # under the terms of the LICENSE.md file.
8
+ #
9
+ # For inquiries contact [email protected]
10
+ #
11
+
12
+ import torch
13
+ import torch.nn.functional as F
14
+ from torch.autograd import Variable
15
+ from math import exp
16
+ import einops
17
+
18
+ def l1_loss(network_output, gt):
19
+ return torch.abs((network_output - gt)).mean()
20
+
21
+ def l2_loss(network_output, gt):
22
+ return ((network_output - gt) ** 2).mean()
23
+
24
+ def gaussian(window_size, sigma):
25
+ gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])
26
+ return gauss / gauss.sum()
27
+
28
+ def create_window(window_size, channel):
29
+ _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
30
+ _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
31
+ window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
32
+ return window
33
+
34
+ def masked_ssim(img1, img2, mask):
35
+ ssim_map = ssim(img1, img2, get_ssim_map=True)
36
+ return (ssim_map * mask).sum() / (3. * mask.sum())
37
+
38
+
39
+ def ssim(img1, img2, window_size=11, size_average=True, get_ssim_map=False):
40
+ channel = img1.size(-3)
41
+ window = create_window(window_size, channel)
42
+
43
+ if img1.is_cuda:
44
+ window = window.cuda(img1.get_device())
45
+ window = window.type_as(img1)
46
+
47
+ return _ssim(img1, img2, window, window_size, channel, size_average, get_ssim_map)
48
+
49
+ def _ssim(img1, img2, window, window_size, channel, size_average=True, get_ssim_map=False):
50
+ mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
51
+ mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
52
+
53
+ mu1_sq = mu1.pow(2)
54
+ mu2_sq = mu2.pow(2)
55
+ mu1_mu2 = mu1 * mu2
56
+
57
+ sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
58
+ sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
59
+ sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2
60
+
61
+ C1 = 0.01 ** 2
62
+ C2 = 0.03 ** 2
63
+
64
+ ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
65
+
66
+ if get_ssim_map:
67
+ return ssim_map
68
+ elif size_average:
69
+ return ssim_map.mean()
70
+ else:
71
+ return ssim_map.mean(1).mean(1).mean(1)
72
+
73
+
74
+ # --- Projections ---
75
+
76
+ def homogenize_points(points):
77
+ """Append a '1' along the final dimension of the tensor (i.e. convert xyz->xyz1)"""
78
+ return torch.cat([points, torch.ones_like(points[..., :1])], dim=-1)
79
+
80
+
81
+ def normalize_homogenous_points(points):
82
+ """Normalize the point vectors"""
83
+ return points / points[..., -1:]
84
+
85
+
86
+ def pixel_space_to_camera_space(pixel_space_points, depth, intrinsics):
87
+ """
88
+ Convert pixel space points to camera space points.
89
+
90
+ Args:
91
+ pixel_space_points (torch.Tensor): Pixel space points with shape (h, w, 2)
92
+ depth (torch.Tensor): Depth map with shape (b, v, h, w, 1)
93
+ intrinsics (torch.Tensor): Camera intrinsics with shape (b, v, 3, 3)
94
+
95
+ Returns:
96
+ torch.Tensor: Camera space points with shape (b, v, h, w, 3).
97
+ """
98
+ pixel_space_points = homogenize_points(pixel_space_points)
99
+ camera_space_points = torch.einsum('b v i j , h w j -> b v h w i', intrinsics.inverse(), pixel_space_points)
100
+ camera_space_points = camera_space_points * depth
101
+ return camera_space_points
102
+
103
+
104
+ def camera_space_to_world_space(camera_space_points, c2w):
105
+ """
106
+ Convert camera space points to world space points.
107
+
108
+ Args:
109
+ camera_space_points (torch.Tensor): Camera space points with shape (b, v, h, w, 3)
110
+ c2w (torch.Tensor): Camera to world extrinsics matrix with shape (b, v, 4, 4)
111
+
112
+ Returns:
113
+ torch.Tensor: World space points with shape (b, v, h, w, 3).
114
+ """
115
+ camera_space_points = homogenize_points(camera_space_points)
116
+ world_space_points = torch.einsum('b v i j , b v h w j -> b v h w i', c2w, camera_space_points)
117
+ return world_space_points[..., :3]
118
+
119
+
120
+ def camera_space_to_pixel_space(camera_space_points, intrinsics):
121
+ """
122
+ Convert camera space points to pixel space points.
123
+
124
+ Args:
125
+ camera_space_points (torch.Tensor): Camera space points with shape (b, v1, v2, h, w, 3)
126
+ intrinsics (torch.Tensor): Camera intrinsics with shape (b, v2, 3, 3)
127
+
128
+ Returns:
129
+ torch.Tensor: Pixel space points with shape (b, v1, v2, h, w, 2).
130
+ """
131
+ camera_space_points = normalize_homogenous_points(camera_space_points)
132
+ pixel_space_points = torch.einsum('b u i j , b v u h w j -> b v u h w i', intrinsics, camera_space_points)
133
+ return pixel_space_points[..., :2]
134
+
135
+
136
+ def world_space_to_camera_space(world_space_points, c2w):
137
+ """
138
+ Convert world space points to camera space points.
139
+
140
+ Args:
141
+ world_space_points (torch.Tensor): World space points with shape (b, v1, h, w, 3)
142
+ c2w (torch.Tensor): Camera to world extrinsics matrix with shape (b, v2, 4, 4)
143
+
144
+ Returns:
145
+ torch.Tensor: Camera space points with shape (b, v1, v2, h, w, 3).
146
+ """
147
+ world_space_points = homogenize_points(world_space_points)
148
+ camera_space_points = torch.einsum('b u i j , b v h w j -> b v u h w i', c2w.inverse(), world_space_points)
149
+ return camera_space_points[..., :3]
150
+
151
+
152
+ def unproject_depth(depth, intrinsics, c2w):
153
+ """
154
+ Turn the depth map into a 3D point cloud in world space
155
+
156
+ Args:
157
+ depth: (b, v, h, w, 1)
158
+ intrinsics: (b, v, 3, 3)
159
+ c2w: (b, v, 4, 4)
160
+
161
+ Returns:
162
+ torch.Tensor: World space points with shape (b, v, h, w, 3).
163
+ """
164
+
165
+ # Compute indices of pixels
166
+ h, w = depth.shape[-3], depth.shape[-2]
167
+ x_grid, y_grid = torch.meshgrid(
168
+ torch.arange(w, device=depth.device, dtype=torch.float32),
169
+ torch.arange(h, device=depth.device, dtype=torch.float32),
170
+ indexing='xy'
171
+ ) # (h, w), (h, w)
172
+
173
+ # Compute coordinates of pixels in camera space
174
+ pixel_space_points = torch.stack((x_grid, y_grid), dim=-1) # (..., h, w, 2)
175
+ camera_points = pixel_space_to_camera_space(pixel_space_points, depth, intrinsics) # (..., h, w, 3)
176
+
177
+ # Convert points to world space
178
+ world_points = camera_space_to_world_space(camera_points, c2w) # (..., h, w, 3)
179
+
180
+ return world_points
181
+
182
+
183
+ @torch.no_grad()
184
+ def calculate_in_frustum_mask(depth_1, intrinsics_1, c2w_1, depth_2, intrinsics_2, c2w_2, atol=1e-2):
185
+ """
186
+ A function that takes in the depth, intrinsics and c2w matrices of two sets
187
+ of views, and then works out which of the pixels in the first set of views
188
+ has a direct corresponding pixel in any of views in the second set
189
+
190
+ Args:
191
+ depth_1: (b, v1, h, w)
192
+ intrinsics_1: (b, v1, 3, 3)
193
+ c2w_1: (b, v1, 4, 4)
194
+ depth_2: (b, v2, h, w)
195
+ intrinsics_2: (b, v2, 3, 3)
196
+ c2w_2: (b, v2, 4, 4)
197
+
198
+ Returns:
199
+ torch.Tensor: Boolean mask with shape (b, v1, h, w).
200
+ """
201
+
202
+ _, v1, h, w = depth_1.shape
203
+ _, v2, _, _ = depth_2.shape
204
+
205
+ # Unproject the depth to get the 3D points in world space
206
+ points_3d = unproject_depth(depth_1[..., None], intrinsics_1, c2w_1) # (b, v1, h, w, 3)
207
+
208
+ # Project the 3D points into the pixel space of all the second views simultaneously
209
+ camera_points = world_space_to_camera_space(points_3d, c2w_2) # (b, v1, v2, h, w, 3)
210
+ points_2d = camera_space_to_pixel_space(camera_points, intrinsics_2) # (b, v1, v2, h, w, 2)
211
+
212
+ # Calculate the depth of each point
213
+ rendered_depth = camera_points[..., 2] # (b, v1, v2, h, w)
214
+
215
+ # We use three conditions to determine if a point should be masked
216
+
217
+ # Condition 1: Check if the points are in the frustum of any of the v2 views
218
+ in_frustum_mask = (
219
+ (points_2d[..., 0] > 0) &
220
+ (points_2d[..., 0] < w) &
221
+ (points_2d[..., 1] > 0) &
222
+ (points_2d[..., 1] < h)
223
+ ) # (b, v1, v2, h, w)
224
+ in_frustum_mask = in_frustum_mask.any(dim=-3) # (b, v1, h, w)
225
+
226
+ # Condition 2: Check if the points have non-zero (i.e. valid) depth in the input view
227
+ non_zero_depth = depth_1 > 1e-6
228
+
229
+ # Condition 3: Check if the points have matching depth to any of the v2
230
+ # views. torch.nn.functional.grid_sample expects the input coordinates to
231
+ # be normalized to the range [-1, 1], so we normalize first
232
+ points_2d[..., 0] /= w
233
+ points_2d[..., 1] /= h
234
+ points_2d = points_2d * 2 - 1
235
+ matching_depth = torch.ones_like(rendered_depth, dtype=torch.bool)
236
+ for b in range(depth_1.shape[0]):
237
+ for i in range(v1):
238
+ for j in range(v2):
239
+ depth = einops.rearrange(depth_2[b, j], 'h w -> 1 1 h w')
240
+ coords = einops.rearrange(points_2d[b, i, j], 'h w c -> 1 h w c')
241
+ sampled_depths = torch.nn.functional.grid_sample(depth, coords, align_corners=False)[0, 0]
242
+ matching_depth[b, i, j] = torch.isclose(rendered_depth[b, i, j], sampled_depths, atol=atol)
243
+
244
+ matching_depth = matching_depth.any(dim=-3) # (..., v1, h, w)
245
+
246
+ mask = in_frustum_mask & non_zero_depth & matching_depth
247
+ return mask
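A sketch of how the image terms above are typically combined into a photometric loss; the 0.2 SSIM weight is an assumption for illustration, not something this file fixes:

import torch

render = torch.rand(1, 3, 128, 128)  # hypothetical rendered image batch, values in [0, 1]
gt = torch.rand(1, 3, 128, 128)      # hypothetical ground-truth image batch
lambda_dssim = 0.2                    # assumed trade-off weight between the two terms
loss = (1.0 - lambda_dssim) * l1_loss(render, gt) + lambda_dssim * (1.0 - ssim(render, gt))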
utils/pose_utils.py ADDED
@@ -0,0 +1,570 @@
1
+ import math
2
+ import numpy as np
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from typing import Tuple
6
+ from utils.stepfun import sample_np, sample
7
+ import scipy
8
+
9
+
10
+ def quad2rotation(q):
11
+ """
12
+ Convert quaternions to rotation matrices in batch. All operations are in PyTorch, so gradients can propagate.
13
+
14
+ Args:
15
+ q (tensor, batch_size*4): quaternion.
16
+
17
+ Returns:
18
+ rot_mat (tensor, batch_size*3*3): rotation.
19
+ """
20
+ # bs = quad.shape[0]
21
+ # qr, qi, qj, qk = quad[:, 0], quad[:, 1], quad[:, 2], quad[:, 3]
22
+ # two_s = 2.0 / (quad * quad).sum(-1)
23
+ # rot_mat = torch.zeros(bs, 3, 3).to(quad.get_device())
24
+ # rot_mat[:, 0, 0] = 1 - two_s * (qj**2 + qk**2)
25
+ # rot_mat[:, 0, 1] = two_s * (qi * qj - qk * qr)
26
+ # rot_mat[:, 0, 2] = two_s * (qi * qk + qj * qr)
27
+ # rot_mat[:, 1, 0] = two_s * (qi * qj + qk * qr)
28
+ # rot_mat[:, 1, 1] = 1 - two_s * (qi**2 + qk**2)
29
+ # rot_mat[:, 1, 2] = two_s * (qj * qk - qi * qr)
30
+ # rot_mat[:, 2, 0] = two_s * (qi * qk - qj * qr)
31
+ # rot_mat[:, 2, 1] = two_s * (qj * qk + qi * qr)
32
+ # rot_mat[:, 2, 2] = 1 - two_s * (qi**2 + qj**2)
33
+ # return rot_mat
34
+ if not isinstance(q, torch.Tensor):
35
+ q = torch.tensor(q).cuda()
36
+
37
+ norm = torch.sqrt(
38
+ q[:, 0] * q[:, 0] + q[:, 1] * q[:, 1] + q[:, 2] * q[:, 2] + q[:, 3] * q[:, 3]
39
+ )
40
+ q = q / norm[:, None]
41
+ rot = torch.zeros((q.size(0), 3, 3)).to(q)
42
+ r = q[:, 0]
43
+ x = q[:, 1]
44
+ y = q[:, 2]
45
+ z = q[:, 3]
46
+ rot[:, 0, 0] = 1 - 2 * (y * y + z * z)
47
+ rot[:, 0, 1] = 2 * (x * y - r * z)
48
+ rot[:, 0, 2] = 2 * (x * z + r * y)
49
+ rot[:, 1, 0] = 2 * (x * y + r * z)
50
+ rot[:, 1, 1] = 1 - 2 * (x * x + z * z)
51
+ rot[:, 1, 2] = 2 * (y * z - r * x)
52
+ rot[:, 2, 0] = 2 * (x * z - r * y)
53
+ rot[:, 2, 1] = 2 * (y * z + r * x)
54
+ rot[:, 2, 2] = 1 - 2 * (x * x + y * y)
55
+ return rot
56
+
57
+ def get_camera_from_tensor(inputs):
58
+ """
59
+ Convert quaternion and translation to transformation matrix.
60
+
61
+ """
62
+ if not isinstance(inputs, torch.Tensor):
63
+ inputs = torch.tensor(inputs).cuda()
64
+
65
+ N = len(inputs.shape)
66
+ if N == 1:
67
+ inputs = inputs.unsqueeze(0)
68
+ # quad, T = inputs[:, :4], inputs[:, 4:]
69
+ # # normalize quad
70
+ # quad = F.normalize(quad)
71
+ # R = quad2rotation(quad)
72
+ # RT = torch.cat([R, T[:, :, None]], 2)
73
+ # # Add homogenous row
74
+ # homogenous_row = torch.tensor([0, 0, 0, 1]).cuda()
75
+ # RT = torch.cat([RT, homogenous_row[None, None, :].repeat(N, 1, 1)], 1)
76
+ # if N == 1:
77
+ # RT = RT[0]
78
+ # return RT
79
+
80
+ quad, T = inputs[:, :4], inputs[:, 4:]
81
+ w2c = torch.eye(4).to(inputs).float()
82
+ w2c[:3, :3] = quad2rotation(quad)
83
+ w2c[:3, 3] = T
84
+ return w2c
85
+
86
+ def quadmultiply(q1, q2):
87
+ """
88
+ Multiply two quaternions together using quaternion arithmetic
89
+ """
90
+ # Extract scalar and vector parts of the quaternions
91
+ w1, x1, y1, z1 = q1.unbind(dim=-1)
92
+ w2, x2, y2, z2 = q2.unbind(dim=-1)
93
+ # Calculate the quaternion product
94
+ result_quaternion = torch.stack(
95
+ [
96
+ w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2,
97
+ w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2,
98
+ w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2,
99
+ w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2,
100
+ ],
101
+ dim=-1,
102
+ )
103
+
104
+ return result_quaternion
105
+
106
+ def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
107
+ """
108
+ Returns torch.sqrt(torch.max(0, x))
109
+ but with a zero subgradient where x is 0.
110
+ Source: https://pytorch3d.readthedocs.io/en/latest/_modules/pytorch3d/transforms/rotation_conversions.html#matrix_to_quaternion
111
+ """
112
+ ret = torch.zeros_like(x)
113
+ positive_mask = x > 0
114
+ ret[positive_mask] = torch.sqrt(x[positive_mask])
115
+ return ret
116
+
117
+ def rotation2quad(matrix: torch.Tensor) -> torch.Tensor:
118
+ """
119
+ Convert rotations given as rotation matrices to quaternions.
120
+
121
+ Args:
122
+ matrix: Rotation matrices as tensor of shape (..., 3, 3).
123
+
124
+ Returns:
125
+ quaternions with real part first, as tensor of shape (..., 4).
126
+ Source: https://pytorch3d.readthedocs.io/en/latest/_modules/pytorch3d/transforms/rotation_conversions.html#matrix_to_quaternion
127
+ """
128
+ if matrix.size(-1) != 3 or matrix.size(-2) != 3:
129
+ raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
130
+
131
+ if not isinstance(matrix, torch.Tensor):
132
+ matrix = torch.tensor(matrix).cuda()
133
+
134
+ batch_dim = matrix.shape[:-2]
135
+ m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(
136
+ matrix.reshape(batch_dim + (9,)), dim=-1
137
+ )
138
+
139
+ q_abs = _sqrt_positive_part(
140
+ torch.stack(
141
+ [
142
+ 1.0 + m00 + m11 + m22,
143
+ 1.0 + m00 - m11 - m22,
144
+ 1.0 - m00 + m11 - m22,
145
+ 1.0 - m00 - m11 + m22,
146
+ ],
147
+ dim=-1,
148
+ )
149
+ )
150
+
151
+ # we produce the desired quaternion multiplied by each of r, i, j, k
152
+ quat_by_rijk = torch.stack(
153
+ [
154
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
155
+ # `int`.
156
+ torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),
157
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
158
+ # `int`.
159
+ torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),
160
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
161
+ # `int`.
162
+ torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1),
163
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
164
+ # `int`.
165
+ torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1),
166
+ ],
167
+ dim=-2,
168
+ )
169
+
170
+ # We floor here at 0.1 but the exact level is not important; if q_abs is small,
171
+ # the candidate won't be picked.
172
+ flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
173
+ quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))
174
+
175
+ # if not for numerical problems, quat_candidates[i] should be same (up to a sign),
176
+ # forall i; we pick the best-conditioned one (with the largest denominator)
177
+
178
+ return quat_candidates[
179
+ F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :
180
+ ].reshape(batch_dim + (4,))
181
+
182
+
183
+ def get_tensor_from_camera(RT, Tquad=False):
184
+ """
185
+ Convert transformation matrix to quaternion and translation.
186
+
187
+ """
188
+ # gpu_id = -1
189
+ # if type(RT) == torch.Tensor:
190
+ # if RT.get_device() != -1:
191
+ # gpu_id = RT.get_device()
192
+ # RT = RT.detach().cpu()
193
+ # RT = RT.numpy()
194
+ # from mathutils import Matrix
195
+ #
196
+ # R, T = RT[:3, :3], RT[:3, 3]
197
+ # rot = Matrix(R)
198
+ # quad = rot.to_quaternion()
199
+ # if Tquad:
200
+ # tensor = np.concatenate([T, quad], 0)
201
+ # else:
202
+ # tensor = np.concatenate([quad, T], 0)
203
+ # tensor = torch.from_numpy(tensor).float()
204
+ # if gpu_id != -1:
205
+ # tensor = tensor.to(gpu_id)
206
+ # return tensor
207
+
208
+ if not isinstance(RT, torch.Tensor):
209
+ RT = torch.tensor(RT).cuda()
210
+
211
+ rot = RT[:3, :3].unsqueeze(0).detach()
212
+ quat = rotation2quad(rot).squeeze()
213
+ tran = RT[:3, 3].detach()
214
+
215
+ return torch.cat([quat, tran])
216
+
217
+ def normalize(x):
218
+ return x / np.linalg.norm(x)
219
+
220
+
221
+ def viewmatrix(lookdir, up, position, subtract_position=False):
222
+ """Construct lookat view matrix."""
223
+ vec2 = normalize((lookdir - position) if subtract_position else lookdir)
224
+ vec0 = normalize(np.cross(up, vec2))
225
+ vec1 = normalize(np.cross(vec2, vec0))
226
+ m = np.stack([vec0, vec1, vec2, position], axis=1)
227
+ return m
228
+
229
+
230
+ def poses_avg(poses):
231
+ """New pose using average position, z-axis, and up vector of input poses."""
232
+ position = poses[:, :3, 3].mean(0)
233
+ z_axis = poses[:, :3, 2].mean(0)
234
+ up = poses[:, :3, 1].mean(0)
235
+ cam2world = viewmatrix(z_axis, up, position)
236
+ return cam2world
237
+
238
+
239
+ def focus_point_fn(poses):
240
+ """Calculate nearest point to all focal axes in poses."""
241
+ directions, origins = poses[:, :3, 2:3], poses[:, :3, 3:4]
242
+ m = np.eye(3) - directions * np.transpose(directions, [0, 2, 1])
243
+ mt_m = np.transpose(m, [0, 2, 1]) @ m
244
+ focus_pt = np.linalg.inv(mt_m.mean(0)) @ (mt_m @ origins).mean(0)[:, 0]
245
+ return focus_pt
246
+
247
+
248
+ def pad_poses(p):
249
+ """Pad [..., 3, 4] pose matrices with a homogeneous bottom row [0,0,0,1]."""
250
+ bottom = np.broadcast_to([0, 0, 0, 1.], p[..., :1, :4].shape)
251
+ return np.concatenate([p[..., :3, :4], bottom], axis=-2)
252
+
253
+ def unpad_poses(p):
254
+ """Remove the homogeneous bottom row from [..., 4, 4] pose matrices."""
255
+ return p[..., :3, :4]
256
+
257
+ def transform_poses_pca(poses):
258
+ """Transforms poses so principal components lie on XYZ axes.
259
+
260
+ Args:
261
+ poses: a (N, 3, 4) array containing the cameras' camera to world transforms.
262
+
263
+ Returns:
264
+ A tuple (poses, transform), with the transformed poses and the applied
265
+ camera_to_world transforms.
266
+ """
267
+ t = poses[:, :3, 3]
268
+ t_mean = t.mean(axis=0)
269
+ t = t - t_mean
270
+
271
+ eigval, eigvec = np.linalg.eig(t.T @ t)
272
+ # Sort eigenvectors in order of largest to smallest eigenvalue.
273
+ inds = np.argsort(eigval)[::-1]
274
+ eigvec = eigvec[:, inds]
275
+ rot = eigvec.T
276
+ if np.linalg.det(rot) < 0:
277
+ rot = np.diag(np.array([1, 1, -1])) @ rot
278
+
279
+ transform = np.concatenate([rot, rot @ -t_mean[:, None]], -1)
280
+ poses_recentered = unpad_poses(transform @ pad_poses(poses))
281
+ transform = np.concatenate([transform, np.eye(4)[3:]], axis=0)
282
+
283
+ # Flip coordinate system if z component of y-axis is negative
284
+ if poses_recentered.mean(axis=0)[2, 1] < 0:
285
+ poses_recentered = np.diag(np.array([1, -1, -1])) @ poses_recentered
286
+ transform = np.diag(np.array([1, -1, -1, 1])) @ transform
287
+
288
+ # Just make sure it's in the [-1, 1]^3 cube
289
+ scale_factor = 1. / np.max(np.abs(poses_recentered[:, :3, 3]))
290
+ poses_recentered[:, :3, 3] *= scale_factor
291
+ transform = np.diag(np.array([scale_factor] * 3 + [1])) @ transform
292
+ return poses_recentered, transform
293
+
294
+
295
+ def recenter_poses(poses: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
296
+ """Recenter poses around the origin."""
297
+ cam2world = poses_avg(poses)
298
+ transform = np.linalg.inv(pad_poses(cam2world))
299
+ poses = transform @ pad_poses(poses)
300
+ return unpad_poses(poses), transform
301
+
302
+ def generate_ellipse_path(views, n_frames=600, const_speed=True, z_variation=0., z_phase=0.):
303
+ poses = []
304
+ for view in views:
305
+ tmp_view = np.eye(4)
306
+ tmp_view[:3] = np.concatenate([view.R.T, view.T[:, None]], 1)
307
+ tmp_view = np.linalg.inv(tmp_view)
308
+ tmp_view[:, 1:3] *= -1
309
+ poses.append(tmp_view)
310
+ poses = np.stack(poses, 0)
311
+ poses, transform = transform_poses_pca(poses)
312
+
313
+
314
+ # Calculate the focal point for the path (cameras point toward this).
315
+ center = focus_point_fn(poses)
316
+ # Path height sits at z=0 (in middle of zero-mean capture pattern).
317
+ offset = np.array([center[0] , center[1], 0 ])
318
+ # Calculate scaling for ellipse axes based on input camera positions.
319
+ sc = np.percentile(np.abs(poses[:, :3, 3] - offset), 90, axis=0)
320
+
321
+ # Use ellipse that is symmetric about the focal point in xy.
322
+ low = -sc + offset
323
+ high = sc + offset
324
+ # Optional height variation need not be symmetric
325
+ z_low = np.percentile((poses[:, :3, 3]), 10, axis=0)
326
+ z_high = np.percentile((poses[:, :3, 3]), 90, axis=0)
327
+
328
+
329
+ def get_positions(theta):
330
+ # Interpolate between bounds with trig functions to get ellipse in x-y.
331
+ # Optionally also interpolate in z to change camera height along path.
332
+ return np.stack([
333
+ (low[0] + (high - low)[0] * (np.cos(theta) * .5 + .5)),
334
+ (low[1] + (high - low)[1] * (np.sin(theta) * .5 + .5)),
335
+ z_variation * (z_low[2] + (z_high - z_low)[2] *
336
+ (np.cos(theta + 2 * np.pi * z_phase) * .5 + .5)),
337
+ ], -1)
338
+
339
+ theta = np.linspace(0, 2. * np.pi, n_frames + 1, endpoint=True)
340
+ positions = get_positions(theta)
341
+
342
+ if const_speed:
343
+ # Resample theta angles so that the velocity is closer to constant.
344
+ lengths = np.linalg.norm(positions[1:] - positions[:-1], axis=-1)
345
+ theta = sample_np(None, theta, np.log(lengths), n_frames + 1)
346
+ positions = get_positions(theta)
347
+
348
+ # Throw away duplicated last position.
349
+ positions = positions[:-1]
350
+
351
+ # Set path's up vector to axis closest to average of input pose up vectors.
352
+ avg_up = poses[:, :3, 1].mean(0)
353
+ avg_up = avg_up / np.linalg.norm(avg_up)
354
+ ind_up = np.argmax(np.abs(avg_up))
355
+ up = np.eye(3)[ind_up] * np.sign(avg_up[ind_up])
356
+ # up = normalize(poses[:, :3, 1].sum(0))
357
+
358
+ render_poses = []
359
+ for p in positions:
360
+ render_pose = np.eye(4)
361
+ render_pose[:3] = viewmatrix(p - center, up, p)
362
+ render_pose = np.linalg.inv(transform) @ render_pose
363
+ render_pose[:3, 1:3] *= -1
364
+ render_poses.append(np.linalg.inv(render_pose))
365
+ return render_poses
366
+
367
+
368
+
369
+ def generate_spiral_path(poses_arr,
370
+ n_frames: int = 180,
371
+ n_rots: int = 2,
372
+ zrate: float = .5) -> np.ndarray:
373
+ """Calculates a forward facing spiral path for rendering."""
374
+ poses = poses_arr[:, :-2].reshape([-1, 3, 5])
375
+ bounds = poses_arr[:, -2:]
376
+ fix_rotation = np.array([
377
+ [0, -1, 0, 0],
378
+ [1, 0, 0, 0],
379
+ [0, 0, 1, 0],
380
+ [0, 0, 0, 1],
381
+ ], dtype=np.float32)
382
+ poses = poses[:, :3, :4] @ fix_rotation
383
+
384
+ scale = 1. / (bounds.min() * .75)
385
+ poses[:, :3, 3] *= scale
386
+ bounds *= scale
387
+ poses, transform = recenter_poses(poses)
388
+
389
+ close_depth, inf_depth = bounds.min() * .9, bounds.max() * 5.
390
+ dt = .75
391
+ focal = 1 / (((1 - dt) / close_depth + dt / inf_depth))
392
+
393
+ # Get radii for spiral path using 90th percentile of camera positions.
394
+ positions = poses[:, :3, 3]
395
+ radii = np.percentile(np.abs(positions), 90, 0)
396
+ radii = np.concatenate([radii, [1.]])
397
+
398
+ # Generate poses for spiral path.
399
+ render_poses = []
400
+ cam2world = poses_avg(poses)
401
+ up = poses[:, :3, 1].mean(0)
402
+ for theta in np.linspace(0., 2. * np.pi * n_rots, n_frames, endpoint=False):
403
+ t = radii * [np.cos(theta), -np.sin(theta), -np.sin(theta * zrate), 1.]
404
+ position = cam2world @ t
405
+ lookat = cam2world @ [0, 0, -focal, 1.]
406
+ z_axis = position - lookat
407
+ render_pose = np.eye(4)
408
+ render_pose[:3] = viewmatrix(z_axis, up, position)
409
+ render_pose = np.linalg.inv(transform) @ render_pose
410
+ render_pose[:3, 1:3] *= -1
411
+ render_pose[:3, 3] /= scale
412
+ render_poses.append(np.linalg.inv(render_pose))
413
+ render_poses = np.stack(render_poses, axis=0)
414
+ return render_poses
415
+
416
+
417
+
418
+ def generate_interpolated_path(
419
+ views,
420
+ n_interp,
421
+ spline_degree = 5,
422
+ smoothness = 0.03,
423
+ rot_weight = 0.1,
424
+ lock_up = False,
425
+ fixed_up_vector = None,
426
+ lookahead_i = None,
427
+ frames_per_colmap = None,
428
+ const_speed = False,
429
+ n_buffer = None,
430
+ periodic = False,
431
+ n_interp_as_total = False,
432
+ ):
433
+ """Creates a smooth spline path between input keyframe camera poses.
434
+
435
+ Spline is calculated with poses in format (position, lookat-point, up-point).
436
+ Args:
437
+ views: input keyframe camera views; converted internally to (n, 4, 4) pose matrices.
438
+ n_interp: returned path will have n_interp * (n - 1) total poses.
439
+ spline_degree: polynomial degree of B-spline.
440
+ smoothness: parameter for spline smoothing, 0 forces exact interpolation.
441
+ rot_weight: relative weighting of rotation/translation in spline solve.
442
+ lock_up: if True, force the given up vector to be used and allow the lookat point to vary.
443
+ fixed_up_vector: replace the interpolated `up` with a fixed vector.
444
+ lookahead_i: force the look direction to look at the pose `i` frames ahead.
445
+ frames_per_colmap: conversion factor for the desired average velocity.
446
+ const_speed: renormalize spline to have constant delta between each pose.
447
+ n_buffer: Number of buffer frames to insert at the start and end of the
448
+ path. Helps keep the ends of a spline path straight.
449
+ periodic: make the spline path periodic (perfect loop).
450
+ n_interp_as_total: use n_interp as total number of poses in path rather than
451
+ the number of poses to interpolate between each input.
452
+
453
+ Returns:
454
+ Array of new camera poses with shape (n_interp * (n - 1), 3, 4), or
455
+ (n_interp, 3, 4) if n_interp_as_total is set.
456
+ """
457
+ poses = []
458
+ for view in views:
459
+ tmp_view = np.eye(4)
460
+ tmp_view[:3] = np.concatenate([view.R.T, view.T[:, None]], 1)
461
+ tmp_view = np.linalg.inv(tmp_view)
462
+ tmp_view[:, 1:3] *= -1
463
+ poses.append(tmp_view)
464
+ poses = np.stack(poses, 0)
465
+
466
+ def poses_to_points(poses, dist):
467
+ """Converts from pose matrices to (position, lookat, up) format."""
468
+ pos = poses[:, :3, -1]
469
+ lookat = poses[:, :3, -1] - dist * poses[:, :3, 2]
470
+ up = poses[:, :3, -1] + dist * poses[:, :3, 1]
471
+ return np.stack([pos, lookat, up], 1)
472
+
473
+ def points_to_poses(points):
474
+ """Converts from (position, lookat, up) format to pose matrices."""
475
+ poses = []
476
+ for i in range(len(points)):
477
+ pos, lookat_point, up_point = points[i]
478
+ if lookahead_i is not None:
479
+ if i + lookahead_i < len(points):
480
+ lookat = pos - points[i + lookahead_i][0]
481
+ else:
482
+ lookat = pos - lookat_point
483
+ up = (up_point - pos) if fixed_up_vector is None else fixed_up_vector
484
+ poses.append(viewmatrix(lookat, up, pos))
485
+ return np.array(poses)
486
+
487
+ def insert_buffer_poses(poses, n_buffer):
488
+ """Insert extra poses at the start and end of the path."""
489
+
490
+ def average_distance(points):
491
+ distances = np.linalg.norm(points[1:] - points[0:-1], axis=-1)
492
+ return np.mean(distances)
493
+
494
+ def shift(pose, dz):
495
+ result = np.copy(pose)
496
+ z = result[:3, 2]
497
+ z /= np.linalg.norm(z)
498
+ # Move along forward-backward axis. -z is forward.
499
+ result[:3, 3] += z * dz
500
+ return result
501
+
502
+ dz = average_distance(poses[:, :3, 3])
503
+ prefix = np.stack([shift(poses[0], (i + 1) * dz) for i in range(n_buffer)])
504
+ prefix = prefix[::-1] # reverse order
505
+ suffix = np.stack(
506
+ [shift(poses[-1], -(i + 1) * dz) for i in range(n_buffer)]
507
+ )
508
+ result = np.concatenate([prefix, poses, suffix])
509
+ return result
510
+
511
+ def remove_buffer_poses(poses, u, n_frames, u_keyframes, n_buffer):
512
+ u_keyframes = u_keyframes[n_buffer:-n_buffer]
513
+ mask = (u >= u_keyframes[0]) & (u <= u_keyframes[-1])
514
+ poses = poses[mask]
515
+ u = u[mask]
516
+ n_frames = len(poses)
517
+ return poses, u, n_frames, u_keyframes
518
+
519
+ def interp(points, u, k, s):
520
+ """Runs multidimensional B-spline interpolation on the input points."""
521
+ sh = points.shape
522
+ pts = np.reshape(points, (sh[0], -1))
523
+ k = min(k, sh[0] - 1)
524
+ tck, u_keyframes = scipy.interpolate.splprep(pts.T, k=k, s=s, per=periodic)
525
+ new_points = np.array(scipy.interpolate.splev(u, tck))
526
+ new_points = np.reshape(new_points.T, (len(u), sh[1], sh[2]))
527
+ return new_points, u_keyframes
528
+
529
+
530
+ if n_buffer is not None:
531
+ poses = insert_buffer_poses(poses, n_buffer)
532
+ points = poses_to_points(poses, dist=rot_weight)
533
+ if n_interp_as_total:
534
+ n_frames = n_interp + 1 # Add extra since final pose is discarded.
535
+ else:
536
+ n_frames = n_interp * (points.shape[0] - 1)
537
+ u = np.linspace(0, 1, n_frames, endpoint=True)
538
+ new_points, u_keyframes = interp(points, u=u, k=spline_degree, s=smoothness)
539
+ poses = points_to_poses(new_points)
540
+ if n_buffer is not None:
541
+ poses, u, n_frames, u_keyframes = remove_buffer_poses(
542
+ poses, u, n_frames, u_keyframes, n_buffer
543
+ )
544
+ # poses, transform = transform_poses_pca(poses)
545
+ if frames_per_colmap is not None:
546
+ # Recalculate the number of frames to achieve desired average velocity.
547
+ positions = poses[:, :3, -1]
548
+ lengths = np.linalg.norm(positions[1:] - positions[:-1], axis=-1)
549
+ total_length_colmap = lengths.sum()
550
+ print('old n_frames:', n_frames)
551
+ print('total_length_colmap:', total_length_colmap)
552
+ n_frames = int(total_length_colmap * frames_per_colmap)
553
+ print('new n_frames:', n_frames)
554
+ u = np.linspace(
555
+ np.min(u_keyframes), np.max(u_keyframes), n_frames, endpoint=True
556
+ )
557
+ new_points, _ = interp(points, u=u, k=spline_degree, s=smoothness)
558
+ poses = points_to_poses(new_points)
559
+
560
+ if const_speed:
561
+ # Resample timesteps so that the velocity is nearly constant.
562
+ positions = poses[:, :3, -1]
563
+ lengths = np.linalg.norm(positions[1:] - positions[:-1], axis=-1)
564
+ u = sample(None, u, np.log(lengths), n_frames + 1)
565
+ new_points, _ = interp(points, u=u, k=spline_degree, s=smoothness)
566
+ poses = points_to_poses(new_points)
567
+
568
+ # return poses[:-1], u[:-1], u_keyframes
569
+ return poses[:-1]
570
+
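A small sketch of the quaternion/matrix round trip used when camera poses are optimized as 7D vectors; the example pose is made up:

import torch

# Hypothetical world-to-camera matrix: identity rotation plus a small translation.
w2c = torch.eye(4)
w2c[:3, 3] = torch.tensor([0.1, 0.0, 0.5])
# Pose -> 7D (quaternion + translation) vector, e.g. when optimizing poses directly ...
pose_vec = get_tensor_from_camera(w2c)
# ... and back to a 4x4 matrix; the round trip matches w2c up to numerical precision.
w2c_back = get_camera_from_tensor(pose_vec)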
utils/sh_utils.py ADDED
@@ -0,0 +1,118 @@
1
+ # Copyright 2021 The PlenOctree Authors.
2
+ # Redistribution and use in source and binary forms, with or without
3
+ # modification, are permitted provided that the following conditions are met:
4
+ #
5
+ # 1. Redistributions of source code must retain the above copyright notice,
6
+ # this list of conditions and the following disclaimer.
7
+ #
8
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
9
+ # this list of conditions and the following disclaimer in the documentation
10
+ # and/or other materials provided with the distribution.
11
+ #
12
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
13
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
14
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
15
+ # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
16
+ # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
17
+ # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
18
+ # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
19
+ # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
20
+ # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
21
+ # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
22
+ # POSSIBILITY OF SUCH DAMAGE.
23
+
24
+ import torch
25
+
26
+ C0 = 0.28209479177387814
27
+ C1 = 0.4886025119029199
28
+ C2 = [
29
+ 1.0925484305920792,
30
+ -1.0925484305920792,
31
+ 0.31539156525252005,
32
+ -1.0925484305920792,
33
+ 0.5462742152960396
34
+ ]
35
+ C3 = [
36
+ -0.5900435899266435,
37
+ 2.890611442640554,
38
+ -0.4570457994644658,
39
+ 0.3731763325901154,
40
+ -0.4570457994644658,
41
+ 1.445305721320277,
42
+ -0.5900435899266435
43
+ ]
44
+ C4 = [
45
+ 2.5033429417967046,
46
+ -1.7701307697799304,
47
+ 0.9461746957575601,
48
+ -0.6690465435572892,
49
+ 0.10578554691520431,
50
+ -0.6690465435572892,
51
+ 0.47308734787878004,
52
+ -1.7701307697799304,
53
+ 0.6258357354491761,
54
+ ]
55
+
56
+
57
+ def eval_sh(deg, sh, dirs):
58
+ """
59
+ Evaluate spherical harmonics at unit directions
60
+ using hardcoded SH polynomials.
61
+ Works with torch/np/jnp.
62
+ ... Can be 0 or more batch dimensions.
63
+ Args:
64
+ deg: int SH degree. Currently, 0-4 supported
65
+ sh: jnp.ndarray SH coeffs [..., C, (deg + 1) ** 2]
66
+ dirs: jnp.ndarray unit directions [..., 3]
67
+ Returns:
68
+ [..., C]
69
+ """
70
+ assert deg <= 4 and deg >= 0
71
+ coeff = (deg + 1) ** 2
72
+ assert sh.shape[-1] >= coeff
73
+
74
+ result = C0 * sh[..., 0]
75
+ if deg > 0:
76
+ x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3]
77
+ result = (result -
78
+ C1 * y * sh[..., 1] +
79
+ C1 * z * sh[..., 2] -
80
+ C1 * x * sh[..., 3])
81
+
82
+ if deg > 1:
83
+ xx, yy, zz = x * x, y * y, z * z
84
+ xy, yz, xz = x * y, y * z, x * z
85
+ result = (result +
86
+ C2[0] * xy * sh[..., 4] +
87
+ C2[1] * yz * sh[..., 5] +
88
+ C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] +
89
+ C2[3] * xz * sh[..., 7] +
90
+ C2[4] * (xx - yy) * sh[..., 8])
91
+
92
+ if deg > 2:
93
+ result = (result +
94
+ C3[0] * y * (3 * xx - yy) * sh[..., 9] +
95
+ C3[1] * xy * z * sh[..., 10] +
96
+ C3[2] * y * (4 * zz - xx - yy)* sh[..., 11] +
97
+ C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] +
98
+ C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] +
99
+ C3[5] * z * (xx - yy) * sh[..., 14] +
100
+ C3[6] * x * (xx - 3 * yy) * sh[..., 15])
101
+
102
+ if deg > 3:
103
+ result = (result + C4[0] * xy * (xx - yy) * sh[..., 16] +
104
+ C4[1] * yz * (3 * xx - yy) * sh[..., 17] +
105
+ C4[2] * xy * (7 * zz - 1) * sh[..., 18] +
106
+ C4[3] * yz * (7 * zz - 3) * sh[..., 19] +
107
+ C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] +
108
+ C4[5] * xz * (7 * zz - 3) * sh[..., 21] +
109
+ C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] +
110
+ C4[7] * xz * (xx - 3 * yy) * sh[..., 23] +
111
+ C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24])
112
+ return result
113
+
114
+ def RGB2SH(rgb):
115
+ return (rgb - 0.5) / C0
116
+
117
+ def SH2RGB(sh):
118
+ return sh * C0 + 0.5
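A brief sketch of evaluating a Gaussian's view-dependent colour from its SH coefficients; the degree-1 setup and values are chosen purely for illustration:

import torch

# One Gaussian, 3 colour channels, degree-1 SH -> (1 + 1) ** 2 = 4 coefficients per channel.
sh = torch.zeros(1, 3, 4)
sh[..., 0] = RGB2SH(torch.tensor([1.0, 0.5, 0.25]))  # DC term initialised from an RGB colour
dirs = torch.tensor([[0.0, 0.0, 1.0]])               # unit viewing direction per Gaussian
rgb = eval_sh(1, sh, dirs)                            # (1, 3) view-dependent colour
rgb = (rgb + 0.5).clamp(0.0, 1.0)                     # same 0.5 offset convention as SH2RGB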
utils/stepfun.py ADDED
@@ -0,0 +1,403 @@
1
+ # from internal import math
2
+ import numpy as np
3
+ import torch
4
+
5
+
6
+ def searchsorted(a, v):
7
+ """Find indices where v should be inserted into a to maintain order.
8
+
9
+ Args:
10
+ a: tensor, the sorted reference points that we are scanning to see where v
11
+ should lie.
12
+ v: tensor, the query points that we are pretending to insert into a. Does
13
+ not need to be sorted. All but the last dimensions should match or expand
14
+ to those of a, the last dimension can differ.
15
+
16
+ Returns:
17
+ (idx_lo, idx_hi), where a[idx_lo] <= v < a[idx_hi], unless v is out of the
18
+ range [a[0], a[-1]] in which case idx_lo and idx_hi are both the first or
19
+ last index of a.
20
+ """
21
+ i = torch.arange(a.shape[-1], device=a.device)
22
+ v_ge_a = v[..., None, :] >= a[..., :, None]
23
+ idx_lo = torch.max(torch.where(v_ge_a, i[..., :, None], i[..., :1, None]), -2).values
24
+ idx_hi = torch.min(torch.where(~v_ge_a, i[..., :, None], i[..., -1:, None]), -2).values
25
+ return idx_lo, idx_hi
26
+
27
+
28
+ def query(tq, t, y, outside_value=0):
29
+ """Look up the values of the step function (t, y) at locations tq."""
30
+ idx_lo, idx_hi = searchsorted(t, tq)
31
+ yq = torch.where(idx_lo == idx_hi, torch.full_like(idx_hi, outside_value),
32
+ torch.take_along_dim(y, idx_lo, dim=-1))
33
+ return yq
34
+
35
+
36
+ def inner_outer(t0, t1, y1):
37
+ """Construct inner and outer measures on (t1, y1) for t0."""
38
+ cy1 = torch.cat([torch.zeros_like(y1[..., :1]),
39
+ torch.cumsum(y1, dim=-1)],
40
+ dim=-1)
41
+ idx_lo, idx_hi = searchsorted(t1, t0)
42
+
43
+ cy1_lo = torch.take_along_dim(cy1, idx_lo, dim=-1)
44
+ cy1_hi = torch.take_along_dim(cy1, idx_hi, dim=-1)
45
+
46
+ y0_outer = cy1_hi[..., 1:] - cy1_lo[..., :-1]
47
+ y0_inner = torch.where(idx_hi[..., :-1] <= idx_lo[..., 1:],
48
+ cy1_lo[..., 1:] - cy1_hi[..., :-1], torch.zeros_like(idx_lo[..., 1:]))
49
+ return y0_inner, y0_outer
50
+
51
+
52
+ def lossfun_outer(t, w, t_env, w_env):
53
+ """The proposal weight should be an upper envelope on the nerf weight."""
54
+ eps = torch.finfo(t.dtype).eps
55
+ # eps = 1e-3
56
+
57
+ _, w_outer = inner_outer(t, t_env, w_env)
58
+ # We assume w_inner <= w <= w_outer. We don't penalize w_inner because it's
59
+ # more effective to pull w_outer up than it is to push w_inner down.
60
+ # Scaled half-quadratic loss that gives a constant gradient at w_outer = 0.
61
+ return (w - w_outer).clamp_min(0) ** 2 / (w + eps)
62
+
63
+
64
+ def weight_to_pdf(t, w):
65
+ """Turn a vector of weights that sums to 1 into a PDF that integrates to 1."""
66
+ eps = torch.finfo(t.dtype).eps
67
+ return w / (t[..., 1:] - t[..., :-1]).clamp_min(eps)
68
+
69
+
70
+ def pdf_to_weight(t, p):
71
+ """Turn a PDF that integrates to 1 into a vector of weights that sums to 1."""
72
+ return p * (t[..., 1:] - t[..., :-1])
73
+
74
+
75
+ def max_dilate(t, w, dilation, domain=(-torch.inf, torch.inf)):
76
+ """Dilate (via max-pooling) a non-negative step function."""
77
+ t0 = t[..., :-1] - dilation
78
+ t1 = t[..., 1:] + dilation
79
+ t_dilate, _ = torch.sort(torch.cat([t, t0, t1], dim=-1), dim=-1)
80
+ t_dilate = torch.clip(t_dilate, *domain)
81
+ w_dilate = torch.max(
82
+ torch.where(
83
+ (t0[..., None, :] <= t_dilate[..., None])
84
+ & (t1[..., None, :] > t_dilate[..., None]),
85
+ w[..., None, :],
86
+ torch.zeros_like(w[..., None, :]),
87
+ ), dim=-1).values[..., :-1]
88
+ return t_dilate, w_dilate
89
+
90
+
91
+ def max_dilate_weights(t,
92
+ w,
93
+ dilation,
94
+ domain=(-torch.inf, torch.inf),
95
+ renormalize=False):
96
+ """Dilate (via max-pooling) a set of weights."""
97
+ eps = torch.finfo(w.dtype).eps
98
+ # eps = 1e-3
99
+
100
+ p = weight_to_pdf(t, w)
101
+ t_dilate, p_dilate = max_dilate(t, p, dilation, domain=domain)
102
+ w_dilate = pdf_to_weight(t_dilate, p_dilate)
103
+ if renormalize:
104
+ w_dilate /= torch.sum(w_dilate, dim=-1, keepdim=True).clamp_min(eps)
105
+ return t_dilate, w_dilate
106
+
107
+
108
+ def integrate_weights(w):
109
+ """Compute the cumulative sum of w, assuming all weight vectors sum to 1.
110
+
111
+ The output's size on the last dimension is one greater than that of the input,
112
+ because we're computing the integral corresponding to the endpoints of a step
113
+ function, not the integral of the interior/bin values.
114
+
115
+ Args:
116
+ w: Tensor, which will be integrated along the last axis. This is assumed to
117
+ sum to 1 along the last axis, and this function will (silently) break if
118
+ that is not the case.
119
+
120
+ Returns:
121
+ cw0: Tensor, the integral of w, where cw0[..., 0] = 0 and cw0[..., -1] = 1
122
+ """
123
+ cw = torch.cumsum(w[..., :-1], dim=-1).clamp_max(1)
124
+ shape = cw.shape[:-1] + (1,)
125
+ # Ensure that the CDF starts with exactly 0 and ends with exactly 1.
126
+ cw0 = torch.cat([torch.zeros(shape, device=cw.device), cw,
127
+ torch.ones(shape, device=cw.device)], dim=-1)
128
+ return cw0
129
+
130
+
131
+ def integrate_weights_np(w):
132
+ """Compute the cumulative sum of w, assuming all weight vectors sum to 1.
133
+
134
+ The output's size on the last dimension is one greater than that of the input,
135
+ because we're computing the integral corresponding to the endpoints of a step
136
+ function, not the integral of the interior/bin values.
137
+
138
+ Args:
139
+ w: Tensor, which will be integrated along the last axis. This is assumed to
140
+ sum to 1 along the last axis, and this function will (silently) break if
141
+ that is not the case.
142
+
143
+ Returns:
144
+ cw0: Tensor, the integral of w, where cw0[..., 0] = 0 and cw0[..., -1] = 1
145
+ """
146
+ cw = np.minimum(1, np.cumsum(w[..., :-1], axis=-1))
147
+ shape = cw.shape[:-1] + (1,)
148
+ # Ensure that the CDF starts with exactly 0 and ends with exactly 1.
149
+ cw0 = np.concatenate([np.zeros(shape), cw,
150
+ np.ones(shape)], axis=-1)
151
+ return cw0
152
+
153
+
154
+ def invert_cdf(u, t, w_logits):
155
+ """Invert the CDF defined by (t, w) at the points specified by u in [0, 1)."""
156
+ # Compute the PDF and CDF for each weight vector.
157
+ w = torch.softmax(w_logits, dim=-1)
158
+ cw = integrate_weights(w)
159
+ # Interpolate into the inverse CDF.
160
+ t_new = math.sorted_interp(u, cw, t)
161
+ return t_new
162
+
163
+
164
+ def invert_cdf_np(u, t, w_logits):
165
+ """Invert the CDF defined by (t, w) at the points specified by u in [0, 1)."""
166
+ # Compute the PDF and CDF for each weight vector.
167
+ w = np.exp(w_logits) / np.exp(w_logits).sum(axis=-1, keepdims=True)
168
+ cw = integrate_weights_np(w)
169
+ # Interpolate into the inverse CDF.
170
+ interp_fn = np.interp
171
+ t_new = interp_fn(u, cw, t)
172
+ return t_new
173
+
174
+
175
+ def sample(rand,
176
+ t,
177
+ w_logits,
178
+ num_samples,
179
+ single_jitter=False,
180
+ deterministic_center=False):
181
+ """Piecewise-Constant PDF sampling from a step function.
182
+
183
+ Args:
184
+ rand: random number generator (or None for `linspace` sampling).
185
+ t: [..., num_bins + 1], bin endpoint coordinates (must be sorted)
186
+ w_logits: [..., num_bins], logits corresponding to bin weights
187
+ num_samples: int, the number of samples.
188
+ single_jitter: bool, if True, jitter every sample along each ray by the same
189
+ amount in the inverse CDF. Otherwise, jitter each sample independently.
190
+ deterministic_center: bool, if False, when `rand` is None return samples that
191
+ linspace the entire PDF. If True, skip the front and back of the linspace
192
+ so that the centers of each PDF interval are returned.
193
+
194
+ Returns:
195
+ t_samples: [batch_size, num_samples].
196
+ """
197
+ eps = torch.finfo(t.dtype).eps
198
+ # eps = 1e-3
199
+
200
+ device = t.device
201
+
202
+ # Draw uniform samples.
203
+ if not rand:
204
+ if deterministic_center:
205
+ pad = 1 / (2 * num_samples)
206
+ u = torch.linspace(pad, 1. - pad - eps, num_samples, device=device)
207
+ else:
208
+ u = torch.linspace(0, 1. - eps, num_samples, device=device)
209
+ u = torch.broadcast_to(u, t.shape[:-1] + (num_samples,))
210
+ else:
211
+ # `u` is in [0, 1) --- it can be zero, but it can never be 1.
212
+ u_max = eps + (1 - eps) / num_samples
213
+ max_jitter = (1 - u_max) / (num_samples - 1) - eps
214
+ d = 1 if single_jitter else num_samples
215
+ u = torch.linspace(0, 1 - u_max, num_samples, device=device) + \
216
+ torch.rand(t.shape[:-1] + (d,), device=device) * max_jitter
217
+
218
+ return invert_cdf(u, t, w_logits)
219
+
220
+
221
+ def sample_np(rand,
222
+ t,
223
+ w_logits,
224
+ num_samples,
225
+ single_jitter=False,
226
+ deterministic_center=False):
227
+ """
228
+ numpy version of sample()
229
+ """
230
+ eps = np.finfo(np.float32).eps
231
+
232
+ # Draw uniform samples.
233
+ if not rand:
234
+ if deterministic_center:
235
+ pad = 1 / (2 * num_samples)
236
+ u = np.linspace(pad, 1. - pad - eps, num_samples)
237
+ else:
238
+ u = np.linspace(0, 1. - eps, num_samples)
239
+ u = np.broadcast_to(u, t.shape[:-1] + (num_samples,))
240
+ else:
241
+ # `u` is in [0, 1) --- it can be zero, but it can never be 1.
242
+ u_max = eps + (1 - eps) / num_samples
243
+ max_jitter = (1 - u_max) / (num_samples - 1) - eps
244
+ d = 1 if single_jitter else num_samples
245
+ u = np.linspace(0, 1 - u_max, num_samples) + \
246
+ np.random.rand(*t.shape[:-1], d) * max_jitter
247
+
248
+ return invert_cdf_np(u, t, w_logits)
249
+
250
+
251
+ def sample_intervals(rand,
252
+ t,
253
+ w_logits,
254
+ num_samples,
255
+ single_jitter=False,
256
+ domain=(-torch.inf, torch.inf)):
257
+ """Sample *intervals* (rather than points) from a step function.
258
+
259
+ Args:
260
+ rand: random number generator (or None for `linspace` sampling).
261
+ t: [..., num_bins + 1], bin endpoint coordinates (must be sorted)
262
+ w_logits: [..., num_bins], logits corresponding to bin weights
263
+ num_samples: int, the number of intervals to sample.
264
+ single_jitter: bool, if True, jitter every sample along each ray by the same
265
+ amount in the inverse CDF. Otherwise, jitter each sample independently.
266
+ domain: (minval, maxval), the range of valid values for `t`.
267
+
268
+ Returns:
269
+ t_samples: [batch_size, num_samples].
270
+ """
271
+ if num_samples <= 1:
272
+ raise ValueError(f'num_samples must be > 1, is {num_samples}.')
273
+
274
+ # Sample a set of points from the step function.
275
+ centers = sample(
276
+ rand,
277
+ t,
278
+ w_logits,
279
+ num_samples,
280
+ single_jitter,
281
+ deterministic_center=True)
282
+
283
+ # The intervals we return will span the midpoints of each adjacent sample.
284
+ mid = (centers[..., 1:] + centers[..., :-1]) / 2
285
+
286
+ # Each first/last fencepost is the reflection of the first/last midpoint
287
+ # around the first/last sampled center. We clamp to the limits of the input
288
+ # domain, provided by the caller.
289
+ minval, maxval = domain
290
+ first = (2 * centers[..., :1] - mid[..., :1]).clamp_min(minval)
291
+ last = (2 * centers[..., -1:] - mid[..., -1:]).clamp_max(maxval)
292
+
293
+ t_samples = torch.cat([first, mid, last], dim=-1)
294
+ return t_samples
295
+
296
+
297
+ def lossfun_distortion(t, w):
298
+ """Compute iint w[i] w[j] |t[i] - t[j]| di dj."""
299
+ # The loss incurred between all pairs of intervals.
300
+ ut = (t[..., 1:] + t[..., :-1]) / 2
301
+ dut = torch.abs(ut[..., :, None] - ut[..., None, :])
302
+ loss_inter = torch.sum(w * torch.sum(w[..., None, :] * dut, dim=-1), dim=-1)
303
+
304
+ # The loss incurred within each individual interval with itself.
305
+ loss_intra = torch.sum(w ** 2 * (t[..., 1:] - t[..., :-1]), dim=-1) / 3
306
+
307
+ return loss_inter + loss_intra
308
+
309
+
310
+ def interval_distortion(t0_lo, t0_hi, t1_lo, t1_hi):
311
+ """Compute mean(abs(x-y); x in [t0_lo, t0_hi], y in [t1_lo, t1_hi])."""
312
+ # Distortion when the intervals do not overlap.
313
+ d_disjoint = torch.abs((t1_lo + t1_hi) / 2 - (t0_lo + t0_hi) / 2)
314
+
315
+ # Distortion when the intervals overlap.
316
+ d_overlap = (2 *
317
+ (torch.minimum(t0_hi, t1_hi) ** 3 - torch.maximum(t0_lo, t1_lo) ** 3) +
318
+ 3 * (t1_hi * t0_hi * torch.abs(t1_hi - t0_hi) +
319
+ t1_lo * t0_lo * torch.abs(t1_lo - t0_lo) + t1_hi * t0_lo *
320
+ (t0_lo - t1_hi) + t1_lo * t0_hi *
321
+ (t1_lo - t0_hi))) / (6 * (t0_hi - t0_lo) * (t1_hi - t1_lo))
322
+
323
+ # Are the two intervals not overlapping?
324
+ are_disjoint = (t0_lo > t1_hi) | (t1_lo > t0_hi)
325
+
326
+ return torch.where(are_disjoint, d_disjoint, d_overlap)
327
+
328
+
329
+ def weighted_percentile(t, w, ps):
330
+ """Compute the weighted percentiles of a step function. w's must sum to 1."""
331
+ cw = integrate_weights(w)
332
+ # We want to interpolate into the integrated weights according to `ps`.
333
+ fn = lambda cw_i, t_i: math.sorted_interp(torch.tensor(ps, device=t.device) / 100, cw_i, t_i)
334
+ # Vmap fn to an arbitrary number of leading dimensions.
335
+ cw_mat = cw.reshape([-1, cw.shape[-1]])
336
+ t_mat = t.reshape([-1, t.shape[-1]])
337
+ wprctile_mat = fn(cw_mat, t_mat) # TODO
338
+ wprctile = wprctile_mat.reshape(cw.shape[:-1] + (len(ps),))
339
+ return wprctile
340
+
341
+
342
+ def resample(t, tp, vp, use_avg=False):
343
+ """Resample a step function defined by (tp, vp) into intervals t.
344
+
345
+ Args:
346
+ t: tensor with shape (..., n+1), the endpoints to resample into.
347
+ tp: tensor with shape (..., m+1), the endpoints of the step function being
348
+ resampled.
349
+ vp: tensor with shape (..., m), the values of the step function being
350
+ resampled.
351
+ use_avg: bool, if False, return the sum of the step function for each
352
+ interval in `t`. If True, return the average, weighted by the width of
353
+ each interval in `t`.
354
+ eps: float, a small value to prevent division by zero when use_avg=True.
355
+
356
+ Returns:
357
+ v: tensor with shape (..., n), the values of the resampled step function.
358
+ """
359
+ eps = torch.finfo(t.dtype).eps
360
+ # eps = 1e-3
361
+
362
+ if use_avg:
363
+ wp = torch.diff(tp, dim=-1)
364
+ v_numer = resample(t, tp, vp * wp, use_avg=False)
365
+ v_denom = resample(t, tp, wp, use_avg=False)
366
+ v = v_numer / v_denom.clamp_min(eps)
367
+ return v
368
+
369
+ acc = torch.cumsum(vp, dim=-1)
370
+ acc0 = torch.cat([torch.zeros(acc.shape[:-1] + (1,), device=acc.device), acc], dim=-1)
371
+ acc0_resampled = math.sorted_interp(t, tp, acc0) # TODO
372
+ v = torch.diff(acc0_resampled, dim=-1)
373
+ return v
374
+
375
+
376
+ def resample_np(t, tp, vp, use_avg=False):
377
+ """
378
+ numpy version of resample
379
+ """
380
+ eps = np.finfo(t.dtype).eps
381
+ if use_avg:
382
+ wp = np.diff(tp, axis=-1)
383
+ v_numer = resample_np(t, tp, vp * wp, use_avg=False)
384
+ v_denom = resample_np(t, tp, wp, use_avg=False)
385
+ v = v_numer / np.maximum(eps, v_denom)
386
+ return v
387
+
388
+ acc = np.cumsum(vp, axis=-1)
389
+ acc0 = np.concatenate([np.zeros(acc.shape[:-1] + (1,)), acc], axis=-1)
390
+ acc0_resampled = np.vectorize(np.interp, signature='(n),(m),(m)->(n)')(t, tp, acc0)
391
+ v = np.diff(acc0_resampled, axis=-1)
392
+ return v
393
+
394
+
395
+ def blur_stepfun(x, y, r):
396
+ xr, xr_idx = torch.sort(torch.cat([x - r, x + r], dim=-1))
397
+ y1 = (torch.cat([y, torch.zeros_like(y[..., :1])], dim=-1) -
398
+ torch.cat([torch.zeros_like(y[..., :1]), y], dim=-1)) / (2 * r)
399
+ y2 = torch.cat([y1, -y1], dim=-1).take_along_dim(xr_idx[..., :-1], dim=-1)
400
+ yr = torch.cumsum((xr[..., 1:] - xr[..., :-1]) *
401
+ torch.cumsum(y2, dim=-1), dim=-1).clamp_min(0)
402
+ yr = torch.cat([torch.zeros_like(yr[..., :1]), yr], dim=-1)
403
+ return xr, yr
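pose_utils.py above feeds sample_np with log segment lengths to re-space a camera path so it is traversed at roughly constant speed; a toy version of that pattern, with made-up positions:

import numpy as np

t = np.linspace(0.0, 1.0, 11)  # 11 endpoints -> 10 path segments
# Hypothetical positions that bunch up near the start of the path.
positions = np.stack([t ** 2, np.zeros_like(t), np.zeros_like(t)], axis=-1)
lengths = np.linalg.norm(positions[1:] - positions[:-1], axis=-1)
# Long segments receive proportionally more of the redrawn samples, so stepping uniformly
# through the resampled parameter traverses the path at near-constant speed.
t_const_speed = sample_np(None, t, np.log(lengths), 11)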
utils/system_utils.py ADDED
@@ -0,0 +1,28 @@
1
+ #
2
+ # Copyright (C) 2023, Inria
3
+ # GRAPHDECO research group, https://team.inria.fr/graphdeco
4
+ # All rights reserved.
5
+ #
6
+ # This software is free for non-commercial, research and evaluation use
7
+ # under the terms of the LICENSE.md file.
8
+ #
9
+ # For inquiries contact [email protected]
10
+ #
11
+
12
+ from errno import EEXIST
13
+ from os import makedirs, path
14
+ import os
15
+
16
+ def mkdir_p(folder_path):
17
+ # Creates a directory; equivalent to using mkdir -p on the command line
18
+ try:
19
+ makedirs(folder_path)
20
+ except OSError as exc: # Python >2.5
21
+ if exc.errno == EEXIST and path.isdir(folder_path):
22
+ pass
23
+ else:
24
+ raise
25
+
26
+ def searchForMaxIteration(folder):
27
+ saved_iters = [int(fname.split("_")[-1]) for fname in os.listdir(folder)]
28
+ return max(saved_iters)
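A sketch of how searchForMaxIteration is typically used to pick up the newest checkpoint; the directory layout below is an assumption for illustration:

import os

ckpt_root = os.path.join("output", "scene", "point_cloud")  # assumed layout: iteration_7000, iteration_30000, ...
if os.path.isdir(ckpt_root):
    latest = searchForMaxIteration(ckpt_root)                # largest trailing "_<number>" among subfolders
    ply_path = os.path.join(ckpt_root, f"iteration_{latest}", "point_cloud.ply")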
utils/trajectories.py ADDED
@@ -0,0 +1,243 @@
1
+ import numpy as np
2
+ import roma
3
+ import torch
4
+ import torch.nn.functional as F
5
+
6
+
7
+ def rt_to_mat4(
8
+ R: torch.Tensor, t: torch.Tensor, s: torch.Tensor | None = None
9
+ ) -> torch.Tensor:
10
+ """
11
+ Args:
12
+ R (torch.Tensor): (..., 3, 3).
13
+ t (torch.Tensor): (..., 3).
14
+ s (torch.Tensor): (...,).
15
+
16
+ Returns:
17
+ torch.Tensor: (..., 4, 4)
18
+ """
19
+ mat34 = torch.cat([R, t[..., None]], dim=-1)
20
+ if s is None:
21
+ bottom = (
22
+ mat34.new_tensor([[0.0, 0.0, 0.0, 1.0]])
23
+ .reshape((1,) * (mat34.dim() - 2) + (1, 4))
24
+ .expand(mat34.shape[:-2] + (1, 4))
25
+ )
26
+ else:
27
+ bottom = F.pad(1.0 / s[..., None, None], (3, 0), value=0.0)
28
+ mat4 = torch.cat([mat34, bottom], dim=-2)
29
+ return mat4
30
+
31
+ def get_avg_w2c(w2cs: torch.Tensor):
32
+ c2ws = torch.linalg.inv(w2cs)
33
+ # 1. Compute the center
34
+ center = c2ws[:, :3, -1].mean(0)
35
+ # 2. Compute the z axis
36
+ z = F.normalize(c2ws[:, :3, 2].mean(0), dim=-1)
37
+ # 3. Compute axis y' (no need to normalize as it's not the final output)
38
+ y_ = c2ws[:, :3, 1].mean(0) # (3)
39
+ # 4. Compute the x axis
40
+ x = F.normalize(torch.cross(y_, z, dim=-1), dim=-1) # (3)
41
+ # 5. Compute the y axis (as z and x are normalized, y is already of norm 1)
42
+ y = torch.cross(z, x, dim=-1) # (3)
43
+ avg_c2w = rt_to_mat4(torch.stack([x, y, z], 1), center)
44
+ avg_w2c = torch.linalg.inv(avg_c2w)
45
+ return avg_w2c
46
+
47
+
48
+ # def get_lookat(origins: torch.Tensor, viewdirs: torch.Tensor) -> torch.Tensor:
49
+ # """Calculate the intersection point of multiple camera rays as the lookat point.
50
+
51
+ # Use the center of camera positions as a reference point for the lookat,
52
+ # then move forward along the average view direction by a certain distance.
53
+ # """
54
+ # # Calculate the center of camera positions
55
+ # center = origins.mean(dim=0)
56
+
57
+ # # Calculate average view direction
58
+ # mean_dir = F.normalize(viewdirs.mean(dim=0), dim=-1)
59
+
60
+ # # Calculate average distance to the center point
61
+ # avg_dist = torch.norm(origins - center, dim=-1).mean()
62
+
63
+ # # Move forward along the average view direction
64
+ # lookat = center + mean_dir * avg_dist
65
+
66
+ # return lookat
67
+
68
+ def get_lookat(origins: torch.Tensor, viewdirs: torch.Tensor) -> torch.Tensor:
69
+ """Triangulate a set of rays to find a single lookat point.
70
+
71
+ Args:
72
+ origins (torch.Tensor): A (N, 3) array of ray origins.
73
+ viewdirs (torch.Tensor): A (N, 3) array of ray view directions.
74
+
75
+ Returns:
76
+ torch.Tensor: A (3,) lookat point.
77
+ """
78
+
79
+ viewdirs = torch.nn.functional.normalize(viewdirs, dim=-1)
80
+ eye = torch.eye(3, device=origins.device, dtype=origins.dtype)[None]
81
+ # Calculate projection matrix I - rr^T
82
+ I_min_cov = eye - (viewdirs[..., None] * viewdirs[..., None, :])
83
+ # Compute sum of projections
84
+ sum_proj = I_min_cov.matmul(origins[..., None]).sum(dim=-3)
85
+ # Solve for the intersection point using least squares
86
+ lookat = torch.linalg.lstsq(I_min_cov.sum(dim=-3), sum_proj).solution[..., 0]
87
+ # Check NaNs.
88
+ assert not torch.any(torch.isnan(lookat))
89
+ return lookat
90
+
91
+
92
+ def get_lookat_w2cs(positions: torch.Tensor, lookat: torch.Tensor, up: torch.Tensor):
93
+ """
94
+ Args:
95
+ positions: (N, 3) tensor of camera positions
96
+ lookat: (3,) tensor of lookat point
97
+ up: (3,) tensor of up vector
98
+
99
+ Returns:
100
+ w2cs: (N, 3, 3) tensor of world to camera rotation matrices
101
+ """
102
+ forward_vectors = F.normalize(lookat - positions, dim=-1)
103
+ right_vectors = F.normalize(torch.cross(forward_vectors, up[None], dim=-1), dim=-1)
104
+ down_vectors = F.normalize(
105
+ torch.cross(forward_vectors, right_vectors, dim=-1), dim=-1
106
+ )
107
+ Rs = torch.stack([right_vectors, down_vectors, forward_vectors], dim=-1)
108
+ w2cs = torch.linalg.inv(rt_to_mat4(Rs, positions))
109
+ return w2cs
110
+
111
+
112
+ def get_arc_w2cs(
113
+ ref_w2c: torch.Tensor,
114
+ lookat: torch.Tensor,
115
+ up: torch.Tensor,
116
+ num_frames: int,
117
+ degree: float,
118
+ **_,
119
+ ) -> torch.Tensor:
120
+ ref_position = torch.linalg.inv(ref_w2c)[:3, 3]
121
+ thetas = (
122
+ torch.sin(
123
+ torch.linspace(0.0, torch.pi * 2.0, num_frames + 1, device=ref_w2c.device)[
124
+ :-1
125
+ ]
126
+ )
127
+ * (degree / 2.0)
128
+ / 180.0
129
+ * torch.pi
130
+ )
131
+ positions = torch.einsum(
132
+ "nij,j->ni",
133
+ roma.rotvec_to_rotmat(thetas[:, None] * up[None]),
134
+ ref_position - lookat,
135
+ )
136
+ return get_lookat_w2cs(positions, lookat, up)
137
+
138
+
139
+ def get_lemniscate_w2cs(
140
+ ref_w2c: torch.Tensor,
141
+ lookat: torch.Tensor,
142
+ up: torch.Tensor,
143
+ num_frames: int,
144
+ degree: float,
145
+ **_,
146
+ ) -> torch.Tensor:
147
+ ref_c2w = torch.linalg.inv(ref_w2c)
148
+ a = torch.linalg.norm(ref_c2w[:3, 3] - lookat) * np.tan(degree / 360 * np.pi)
149
+ # Lemniscate curve in camera space. Starting at the origin.
150
+ thetas = (
151
+ torch.linspace(0, 2 * torch.pi, num_frames + 1, device=ref_w2c.device)[:-1]
152
+ + torch.pi / 2
153
+ )
154
+
155
+ positions = torch.stack(
156
+ [
157
+ a * torch.cos(thetas) / (1 + torch.sin(thetas) ** 2),
158
+ a * torch.cos(thetas) * torch.sin(thetas) / (1 + torch.sin(thetas) ** 2),
159
+ torch.zeros(num_frames, device=ref_w2c.device),
160
+ ],
161
+ dim=-1,
162
+ )
163
+ # Transform to world space.
164
+ positions = torch.einsum(
165
+ "ij,nj->ni", ref_c2w[:3], F.pad(positions, (0, 1), value=1.0)
166
+ )
167
+ return get_lookat_w2cs(positions, lookat, up)
168
+
169
+
170
+ def get_spiral_w2cs(
171
+ ref_w2c: torch.Tensor,
172
+ lookat: torch.Tensor,
173
+ up: torch.Tensor,
174
+ num_frames: int,
175
+ rads: float | torch.Tensor,
176
+ zrate: float,
177
+ rots: int,
178
+ **_,
179
+ ) -> torch.Tensor:
180
+ ref_c2w = torch.linalg.inv(ref_w2c)
181
+ thetas = torch.linspace(
182
+ 0, 2 * torch.pi * rots, num_frames + 1, device=ref_w2c.device
183
+ )[:-1]
184
+ # Spiral curve in camera space. Starting at the origin.
185
+ if isinstance(rads, torch.Tensor):
186
+ rads = rads.reshape(-1, 3).to(ref_w2c.device)
187
+ positions = (
188
+ torch.stack(
189
+ [
190
+ torch.cos(thetas),
191
+ -torch.sin(thetas),
192
+ -torch.sin(thetas * zrate),
193
+ ],
194
+ dim=-1,
195
+ )
196
+ * rads
197
+ )
198
+ # Transform to world space.
199
+ positions = torch.einsum(
200
+ "ij,nj->ni", ref_c2w[:3], F.pad(positions, (0, 1), value=1.0)
201
+ )
202
+ return get_lookat_w2cs(positions, lookat, up)
203
+
204
+
205
+ def get_wander_w2cs(ref_w2c, focal_length, num_frames, max_disp, **_):
206
+ device = ref_w2c.device
207
+ c2w = np.linalg.inv(ref_w2c.detach().cpu().numpy())
208
+ max_disp = max_disp
209
+
210
+ max_trans = max_disp / focal_length
211
+ output_poses = []
212
+
213
+ for i in range(num_frames):
214
+ x_trans = max_trans * np.sin(2.0 * np.pi * float(i) / float(num_frames))
215
+ y_trans = 0.0
216
+ z_trans = max_trans * np.cos(2.0 * np.pi * float(i) / float(num_frames)) / 2.0
217
+
218
+ i_pose = np.concatenate(
219
+ [
220
+ np.concatenate(
221
+ [
222
+ np.eye(3),
223
+ np.array([x_trans, y_trans, z_trans])[:, np.newaxis],
224
+ ],
225
+ axis=1,
226
+ ),
227
+ np.array([0.0, 0.0, 0.0, 1.0])[np.newaxis, :],
228
+ ],
229
+ axis=0,
230
+ )
231
+
232
+ i_pose = np.linalg.inv(i_pose)
233
+
234
+ ref_pose = np.concatenate(
235
+ [c2w[:3, :4], np.array([0.0, 0.0, 0.0, 1.0])[np.newaxis, :]], axis=0
236
+ )
237
+
238
+ render_pose = np.dot(ref_pose, i_pose)
239
+ output_poses.append(render_pose)
240
+ output_poses = torch.from_numpy(np.array(output_poses, dtype=np.float32)).to(device)
241
+ w2cs = torch.linalg.inv(output_poses)
242
+
243
+ return w2cs
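A minimal usage sketch of the trajectory helpers above, assuming the module is importable as utils.trajectories and that torch and roma are installed; the four input cameras are hypothetical placeholders arranged on a circle looking at the origin.

import torch
from utils.trajectories import get_avg_w2c, get_arc_w2cs, get_lookat_w2cs

# Hypothetical input: four cameras on a circle, all looking at the origin.
angles = torch.linspace(0.0, torch.pi / 4, 4)
positions = 3.0 * torch.stack([torch.cos(angles), torch.zeros(4), torch.sin(angles)], dim=-1)
lookat = torch.zeros(3)
up = torch.tensor([0.0, 1.0, 0.0])
train_w2cs = get_lookat_w2cs(positions, lookat, up)   # (4, 4, 4) world-to-camera poses

# Average reference camera, then a 30-degree arc of 60 novel poses around the up axis.
ref_w2c = get_avg_w2c(train_w2cs)
arc_w2cs = get_arc_w2cs(ref_w2c, lookat, up, num_frames=60, degree=30.0)
print(arc_w2cs.shape)  # torch.Size([60, 4, 4])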
utils/utils_poses/ATE/align_trajectory.py ADDED
@@ -0,0 +1,80 @@
1
+ #!/usr/bin/env python2
2
+ # -*- coding: utf-8 -*-
3
+
4
+ import numpy as np
5
+ import utils.utils_poses.ATE.transformations as tfs
6
+
7
+
8
+ def get_best_yaw(C):
9
+ '''
10
+ maximize trace(Rz(theta) * C)
11
+ '''
12
+ assert C.shape == (3, 3)
13
+
14
+ A = C[0, 1] - C[1, 0]
15
+ B = C[0, 0] + C[1, 1]
16
+ theta = np.pi / 2 - np.arctan2(B, A)
17
+
18
+ return theta
19
+
20
+
21
+ def rot_z(theta):
22
+ R = tfs.rotation_matrix(theta, [0, 0, 1])
23
+ R = R[0:3, 0:3]
24
+
25
+ return R
26
+
27
+
28
+ def align_umeyama(model, data, known_scale=False, yaw_only=False):
29
+ """Implementation of the paper: S. Umeyama, Least-Squares Estimation
30
+ of Transformation Parameters Between Two Point Patterns,
31
+ IEEE Trans. Pattern Anal. Mach. Intell., vol. 13, no. 4, 1991.
32
+
33
+ model = s * R * data + t
34
+
35
+ Input:
36
+ model -- first trajectory (nx3), numpy array type
37
+ data -- second trajectory (nx3), numpy array type
38
+
39
+ Output:
40
+ s -- scale factor (scalar)
41
+ R -- rotation matrix (3x3)
42
+ t -- translation vector (3x1)
43
+ t_error -- translational error per point (1xn)
44
+
45
+ """
46
+
47
+ # subtract mean
48
+ mu_M = model.mean(0)
49
+ mu_D = data.mean(0)
50
+ model_zerocentered = model - mu_M
51
+ data_zerocentered = data - mu_D
52
+ n = np.shape(model)[0]
53
+
54
+ # correlation
55
+ C = 1.0/n*np.dot(model_zerocentered.transpose(), data_zerocentered)
56
+ sigma2 = 1.0/n*np.multiply(data_zerocentered, data_zerocentered).sum()
57
+ U_svd, D_svd, V_svd = np.linalg.linalg.svd(C)
58
+
59
+ D_svd = np.diag(D_svd)
60
+ V_svd = np.transpose(V_svd)
61
+
62
+ S = np.eye(3)
63
+ if(np.linalg.det(U_svd)*np.linalg.det(V_svd) < 0):
64
+ S[2, 2] = -1
65
+
66
+ if yaw_only:
67
+ rot_C = np.dot(data_zerocentered.transpose(), model_zerocentered)
68
+ theta = get_best_yaw(rot_C)
69
+ R = rot_z(theta)
70
+ else:
71
+ R = np.dot(U_svd, np.dot(S, np.transpose(V_svd)))
72
+
73
+ if known_scale:
74
+ s = 1
75
+ else:
76
+ s = 1.0/sigma2*np.trace(np.dot(D_svd, S))
77
+
78
+ t = mu_M-s*np.dot(R, mu_D)
79
+
80
+ return s, R, t
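A quick sanity-check sketch for align_umeyama, assuming the file is importable as utils.utils_poses.ATE.align_trajectory; the trajectories are synthetic and the transform parameters are hypothetical.

import numpy as np
from utils.utils_poses.ATE.align_trajectory import align_umeyama

rng = np.random.default_rng(0)
data = rng.normal(size=(100, 3))                        # "estimated" trajectory
Q, _ = np.linalg.qr(rng.normal(size=(3, 3)))            # random orthogonal matrix
if np.linalg.det(Q) < 0:                                # make it a proper rotation
    Q[:, 0] *= -1.0
model = 2.0 * data @ Q.T + np.array([1.0, -2.0, 0.5])   # model = s * R * data + t

s, R, t = align_umeyama(model, data)
print(np.isclose(s, 2.0), np.allclose(R, Q), np.allclose(t, [1.0, -2.0, 0.5]))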
utils/utils_poses/ATE/align_utils.py ADDED
@@ -0,0 +1,144 @@
1
+ #!/usr/bin/env python2
2
+ # -*- coding: utf-8 -*-
3
+
4
+ import numpy as np
5
+
6
+ import utils.utils_poses.ATE.transformations as tfs
7
+ import utils.utils_poses.ATE.align_trajectory as align
8
+
9
+
10
+ def _getIndices(n_aligned, total_n):
11
+ if n_aligned == -1:
12
+ idxs = np.arange(0, total_n)
13
+ else:
14
+ assert n_aligned <= total_n and n_aligned >= 1
15
+ idxs = np.arange(0, n_aligned)
16
+ return idxs
17
+
18
+
19
+ def alignPositionYawSingle(p_es, p_gt, q_es, q_gt):
20
+ '''
21
+ calculate the 4DOF transformation: yaw R and translation t so that:
22
+ gt = R * est + t
23
+ '''
24
+
25
+ p_es_0, q_es_0 = p_es[0, :], q_es[0, :]
26
+ p_gt_0, q_gt_0 = p_gt[0, :], q_gt[0, :]
27
+ g_rot = tfs.quaternion_matrix(q_gt_0)
28
+ g_rot = g_rot[0:3, 0:3]
29
+ est_rot = tfs.quaternion_matrix(q_es_0)
30
+ est_rot = est_rot[0:3, 0:3]
31
+
32
+ C_R = np.dot(est_rot, g_rot.transpose())
33
+ theta = align.get_best_yaw(C_R)
34
+ R = align.rot_z(theta)
35
+ t = p_gt_0 - np.dot(R, p_es_0)
36
+
37
+ return R, t
38
+
39
+
40
+ def alignPositionYaw(p_es, p_gt, q_es, q_gt, n_aligned=1):
41
+ if n_aligned == 1:
42
+ R, t = alignPositionYawSingle(p_es, p_gt, q_es, q_gt)
43
+ return R, t
44
+ else:
45
+ idxs = _getIndices(n_aligned, p_es.shape[0])
46
+ est_pos = p_es[idxs, 0:3]
47
+ gt_pos = p_gt[idxs, 0:3]
48
+ _, R, t = align.align_umeyama(gt_pos, est_pos, known_scale=True,
49
+ yaw_only=True) # note the order
50
+ t = np.array(t)
51
+ t = t.reshape((3, ))
52
+ R = np.array(R)
53
+ return R, t
54
+
55
+
56
+ # align by a SE3 transformation
57
+ def alignSE3Single(p_es, p_gt, q_es, q_gt):
58
+ '''
59
+ Calculate SE3 transformation R and t so that:
60
+ gt = R * est + t
61
+ Using only the first poses of est and gt
62
+ '''
63
+
64
+ p_es_0, q_es_0 = p_es[0, :], q_es[0, :]
65
+ p_gt_0, q_gt_0 = p_gt[0, :], q_gt[0, :]
66
+
67
+ g_rot = tfs.quaternion_matrix(q_gt_0)
68
+ g_rot = g_rot[0:3, 0:3]
69
+ est_rot = tfs.quaternion_matrix(q_es_0)
70
+ est_rot = est_rot[0:3, 0:3]
71
+
72
+ R = np.dot(g_rot, np.transpose(est_rot))
73
+ t = p_gt_0 - np.dot(R, p_es_0)
74
+
75
+ return R, t
76
+
77
+
78
+ def alignSE3(p_es, p_gt, q_es, q_gt, n_aligned=-1):
79
+ '''
80
+ Calculate SE3 transformation R and t so that:
81
+ gt = R * est + t
82
+ '''
83
+ if n_aligned == 1:
84
+ R, t = alignSE3Single(p_es, p_gt, q_es, q_gt)
85
+ return R, t
86
+ else:
87
+ idxs = _getIndices(n_aligned, p_es.shape[0])
88
+ est_pos = p_es[idxs, 0:3]
89
+ gt_pos = p_gt[idxs, 0:3]
90
+ s, R, t = align.align_umeyama(gt_pos, est_pos,
91
+ known_scale=True) # note the order
92
+ t = np.array(t)
93
+ t = t.reshape((3, ))
94
+ R = np.array(R)
95
+ return R, t
96
+
97
+
98
+ # align by similarity transformation
99
+ def alignSIM3(p_es, p_gt, q_es, q_gt, n_aligned=-1):
100
+ '''
101
+ calculate s, R, t so that:
102
+ gt = R * s * est + t
103
+ '''
104
+ idxs = _getIndices(n_aligned, p_es.shape[0])
105
+ est_pos = p_es[idxs, 0:3]
106
+ gt_pos = p_gt[idxs, 0:3]
107
+ s, R, t = align.align_umeyama(gt_pos, est_pos) # note the order
108
+ return s, R, t
109
+
110
+
111
+ # a general interface
112
+ def alignTrajectory(p_es, p_gt, q_es, q_gt, method, n_aligned=-1):
113
+ '''
114
+ calculate s, R, t so that:
115
+ gt = R * s * est + t
116
+ method can be: sim3, se3, posyaw, none;
117
+ n_aligned: -1 means using all the frames
118
+ '''
119
+ assert p_es.shape[1] == 3
120
+ assert p_gt.shape[1] == 3
121
+ assert q_es.shape[1] == 4
122
+ assert q_gt.shape[1] == 4
123
+
124
+ s = 1
125
+ R = None
126
+ t = None
127
+ if method == 'sim3':
128
+ assert n_aligned >= 2 or n_aligned == -1, "sim3 uses at least 2 frames"
129
+ s, R, t = alignSIM3(p_es, p_gt, q_es, q_gt, n_aligned)
130
+ elif method == 'se3':
131
+ R, t = alignSE3(p_es, p_gt, q_es, q_gt, n_aligned)
132
+ elif method == 'posyaw':
133
+ R, t = alignPositionYaw(p_es, p_gt, q_es, q_gt, n_aligned)
134
+ elif method == 'none':
135
+ R = np.identity(3)
136
+ t = np.zeros((3, ))
137
+ else:
138
+ assert False, 'unknown alignment method'
139
+
140
+ return s, R, t
141
+
142
+
143
+ if __name__ == '__main__':
144
+ pass
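A short sketch of the general interface, assuming the file is importable as utils.utils_poses.ATE.align_utils; positions and quaternions here are synthetic, using the [x, y, z, w] convention of the bundled transformations module (the sim3 branch aligns positions only).

import numpy as np
from utils.utils_poses.ATE.align_utils import alignTrajectory

rng = np.random.default_rng(1)
p_es = rng.normal(size=(50, 3))                        # estimated positions
q_es = np.tile([0.0, 0.0, 0.0, 1.0], (50, 1))          # identity orientations
p_gt = 1.5 * p_es + np.array([0.2, 0.0, -0.3])         # ground truth = s * p_es + t
q_gt = q_es.copy()

s, R, t = alignTrajectory(p_es, p_gt, q_es, q_gt, method='sim3')
p_aligned = s * (p_es @ R.T) + t
print(np.allclose(p_aligned, p_gt))                    # True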
utils/utils_poses/ATE/compute_trajectory_errors.py ADDED
@@ -0,0 +1,89 @@
1
+ #!/usr/bin/env python2
2
+
3
+ import os
4
+ import numpy as np
5
+
6
+ import utils.utils_poses.ATE.trajectory_utils as tu
7
+ import utils.utils_poses.ATE.transformations as tf
8
+
9
+
10
+ def compute_relative_error(p_es, q_es, p_gt, q_gt, T_cm, dist, max_dist_diff,
11
+ accum_distances=[],
12
+ scale=1.0):
13
+
14
+ if len(accum_distances) == 0:
15
+ accum_distances = tu.get_distance_from_start(p_gt)
16
+ comparisons = tu.compute_comparison_indices_length(
17
+ accum_distances, dist, max_dist_diff)
18
+
19
+ n_samples = len(comparisons)
20
+ print('number of samples = {0} '.format(n_samples))
21
+ if n_samples < 2:
22
+ print("Too few samples! Will not compute.")
23
+ return np.array([]), np.array([]), np.array([]), np.array([]), np.array([]),\
24
+ np.array([]), np.array([])
25
+
26
+ T_mc = np.linalg.inv(T_cm)
27
+ errors = []
28
+ for idx, c in enumerate(comparisons):
29
+ if not c == -1:
30
+ T_c1 = tu.get_rigid_body_trafo(q_es[idx, :], p_es[idx, :])
31
+ T_c2 = tu.get_rigid_body_trafo(q_es[c, :], p_es[c, :])
32
+ T_c1_c2 = np.dot(np.linalg.inv(T_c1), T_c2)
33
+ T_c1_c2[:3, 3] *= scale
34
+
35
+ T_m1 = tu.get_rigid_body_trafo(q_gt[idx, :], p_gt[idx, :])
36
+ T_m2 = tu.get_rigid_body_trafo(q_gt[c, :], p_gt[c, :])
37
+ T_m1_m2 = np.dot(np.linalg.inv(T_m1), T_m2)
38
+
39
+ T_m1_m2_in_c1 = np.dot(T_cm, np.dot(T_m1_m2, T_mc))
40
+ T_error_in_c2 = np.dot(np.linalg.inv(T_m1_m2_in_c1), T_c1_c2)
41
+ T_c2_rot = np.eye(4)
42
+ T_c2_rot[0:3, 0:3] = T_c2[0:3, 0:3]
43
+ T_error_in_w = np.dot(T_c2_rot, np.dot(
44
+ T_error_in_c2, np.linalg.inv(T_c2_rot)))
45
+ errors.append(T_error_in_w)
46
+
47
+ error_trans_norm = []
48
+ error_trans_perc = []
49
+ error_yaw = []
50
+ error_gravity = []
51
+ e_rot = []
52
+ e_rot_deg_per_m = []
53
+ for e in errors:
54
+ tn = np.linalg.norm(e[0:3, 3])
55
+ error_trans_norm.append(tn)
56
+ error_trans_perc.append(tn / dist * 100)
57
+ ypr_angles = tf.euler_from_matrix(e, 'rzyx')
58
+ e_rot.append(tu.compute_angle(e))
59
+ error_yaw.append(abs(ypr_angles[0])*180.0/np.pi)
60
+ error_gravity.append(
61
+ np.sqrt(ypr_angles[1]**2+ypr_angles[2]**2)*180.0/np.pi)
62
+ e_rot_deg_per_m.append(e_rot[-1] / dist)
63
+ return errors, np.array(error_trans_norm), np.array(error_trans_perc),\
64
+ np.array(error_yaw), np.array(error_gravity), np.array(e_rot),\
65
+ np.array(e_rot_deg_per_m)
66
+
67
+
68
+ def compute_absolute_error(p_es_aligned, q_es_aligned, p_gt, q_gt):
69
+ e_trans_vec = (p_gt-p_es_aligned)
70
+ e_trans = np.sqrt(np.sum(e_trans_vec**2, 1))
71
+
72
+
73
+ # orientation error
74
+ e_rot = np.zeros((len(e_trans,)))
75
+ e_ypr = np.zeros(np.shape(p_es_aligned))
76
+ for i in range(np.shape(p_es_aligned)[0]):
77
+ R_we = tf.matrix_from_quaternion(q_es_aligned[i, :])
78
+ R_wg = tf.matrix_from_quaternion(q_gt[i, :])
79
+ e_R = np.dot(R_we, np.linalg.inv(R_wg))
80
+ e_ypr[i, :] = tf.euler_from_matrix(e_R, 'rzyx')
81
+ e_rot[i] = np.rad2deg(np.linalg.norm(tf.logmap_so3(e_R[:3, :3])))
82
+ # scale drift
83
+ motion_gt = np.diff(p_gt, axis=0)
84
+ motion_es = np.diff(p_es_aligned, axis=0)
85
+ dist_gt = np.sqrt(np.sum(np.multiply(motion_gt, motion_gt), 1))
86
+ dist_es = np.sqrt(np.sum(np.multiply(motion_es, motion_es), 1))
87
+ e_scale_perc = np.abs((np.divide(dist_es, dist_gt)-1.0) * 100)
88
+ # ate = np.sqrt(np.mean(np.asarray(e_trans) ** 2))
89
+ return e_trans, e_trans_vec, e_rot, e_ypr, e_scale_perc
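A hedged usage sketch for the absolute-error routine, assuming the quaternion helpers it references (matrix_from_quaternion, euler_from_matrix, logmap_so3) are provided by the bundled transformations module; the trajectories below are synthetic.

import numpy as np
from utils.utils_poses.ATE.compute_trajectory_errors import compute_absolute_error

rng = np.random.default_rng(2)
p_gt = rng.normal(size=(30, 3))                        # ground-truth positions
p_es = p_gt + 0.01 * rng.normal(size=(30, 3))          # aligned estimate with small noise
q_id = np.tile([0.0, 0.0, 0.0, 1.0], (30, 1))          # identity orientations, [x, y, z, w]

e_trans, e_trans_vec, e_rot, e_ypr, e_scale_perc = compute_absolute_error(
    p_es, q_id, p_gt, q_id)
print('ATE RMSE: %.4f' % np.sqrt(np.mean(e_trans ** 2)))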
utils/utils_poses/ATE/results_writer.py ADDED
@@ -0,0 +1,75 @@
1
+ #!/usr/bin/env python2
2
+ import os
3
+ # import yaml
4
+ import numpy as np
5
+
6
+
7
+ def compute_statistics(data_vec):
8
+ stats = dict()
9
+ if len(data_vec) > 0:
10
+ stats['rmse'] = float(
11
+ np.sqrt(np.dot(data_vec, data_vec) / len(data_vec)))
12
+ stats['mean'] = float(np.mean(data_vec))
13
+ stats['median'] = float(np.median(data_vec))
14
+ stats['std'] = float(np.std(data_vec))
15
+ stats['min'] = float(np.min(data_vec))
16
+ stats['max'] = float(np.max(data_vec))
17
+ stats['num_samples'] = int(len(data_vec))
18
+ else:
19
+ stats['rmse'] = 0
20
+ stats['mean'] = 0
21
+ stats['median'] = 0
22
+ stats['std'] = 0
23
+ stats['min'] = 0
24
+ stats['max'] = 0
25
+ stats['num_samples'] = 0
26
+
27
+ return stats
28
+
29
+
30
+ # def update_and_save_stats(new_stats, label, yaml_filename):
31
+ # stats = dict()
32
+ # if os.path.exists(yaml_filename):
33
+ # stats = yaml.load(open(yaml_filename, 'r'), Loader=yaml.FullLoader)
34
+ # stats[label] = new_stats
35
+ #
36
+ # with open(yaml_filename, 'w') as outfile:
37
+ # outfile.write(yaml.dump(stats, default_flow_style=False))
38
+ #
39
+ # return
40
+ #
41
+ #
42
+ # def compute_and_save_statistics(data_vec, label, yaml_filename):
43
+ # new_stats = compute_statistics(data_vec)
44
+ # update_and_save_stats(new_stats, label, yaml_filename)
45
+ #
46
+ # return new_stats
47
+ #
48
+ #
49
+ # def write_tex_table(list_values, rows, cols, outfn):
50
+ # '''
51
+ # write list_values[row_idx][col_idx] to a table that is ready to be pasted
52
+ # into latex source
53
+ #
54
+ # list_values is a list of row values
55
+ #
56
+ # The value should be string of desired format
57
+ # '''
58
+ #
59
+ # assert len(rows) >= 1
60
+ # assert len(cols) >= 1
61
+ #
62
+ # with open(outfn, 'w') as f:
63
+ # # write header
64
+ # f.write(' & ')
65
+ # for col_i in cols[:-1]:
66
+ # f.write(col_i + ' & ')
67
+ # f.write(' ' + cols[-1]+'\n')
68
+ #
69
+ # # write each row
70
+ # for row_idx, row_i in enumerate(list_values):
71
+ # f.write(rows[row_idx] + ' & ')
72
+ # row_values = list_values[row_idx]
73
+ # for col_idx in range(len(row_values) - 1):
74
+ # f.write(row_values[col_idx] + ' & ')
75
+ # f.write(' ' + row_values[-1]+' \n')
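A tiny example of the statistics helper, assuming the file is importable as utils.utils_poses.ATE.results_writer; the error values are hypothetical.

import numpy as np
from utils.utils_poses.ATE.results_writer import compute_statistics

errors = np.array([0.010, 0.020, 0.015, 0.030])        # per-frame translation errors (m)
stats = compute_statistics(errors)
print(stats['rmse'], stats['mean'], stats['num_samples'])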
utils/utils_poses/ATE/trajectory_utils.py ADDED
@@ -0,0 +1,46 @@
1
+ #!/usr/bin/env python2
2
+ """
3
+ @author: Christian Forster
4
+ """
5
+
6
+ import os
7
+ import numpy as np
8
+ import utils.utils_poses.ATE.transformations as tf
9
+
10
+
11
+ def get_rigid_body_trafo(quat, trans):
12
+ T = tf.quaternion_matrix(quat)
13
+ T[0:3, 3] = trans
14
+ return T
15
+
16
+
17
+ def get_distance_from_start(gt_translation):
18
+ distances = np.diff(gt_translation[:, 0:3], axis=0)
19
+ distances = np.sqrt(np.sum(np.multiply(distances, distances), 1))
20
+ distances = np.cumsum(distances)
21
+ distances = np.concatenate(([0], distances))
22
+ return distances
23
+
24
+
25
+ def compute_comparison_indices_length(distances, dist, max_dist_diff):
26
+ max_idx = len(distances)
27
+ comparisons = []
28
+ for idx, d in enumerate(distances):
29
+ best_idx = -1
30
+ error = max_dist_diff
31
+ for i in range(idx, max_idx):
32
+ if np.abs(distances[i]-(d+dist)) < error:
33
+ best_idx = i
34
+ error = np.abs(distances[i] - (d+dist))
35
+ if best_idx != -1:
36
+ comparisons.append(best_idx)
37
+ return comparisons
38
+
39
+
40
+ def compute_angle(transform):
41
+ """
42
+ Compute the rotation angle from a 4x4 homogeneous matrix.
43
+ """
44
+ # an invitation to 3-d vision, p 27
45
+ return np.arccos(
46
+ min(1, max(-1, (np.trace(transform[0:3, 0:3]) - 1)/2)))*180.0/np.pi
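Two quick checks of these helpers, assuming trajectory_utils and the bundled transformations module import as below; the inputs are toy values.

import numpy as np
import utils.utils_poses.ATE.transformations as tf
from utils.utils_poses.ATE.trajectory_utils import compute_angle, get_distance_from_start

T = tf.rotation_matrix(np.pi / 2, [0, 0, 1])           # 90-degree rotation about z
print(compute_angle(T))                                # ~90.0 degrees

path = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [1.0, 1.0, 0.0]])
print(get_distance_from_start(path))                   # [0. 1. 2.]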
utils/utils_poses/ATE/transformations.py ADDED
@@ -0,0 +1,1974 @@
1
+ # -*- coding: utf-8 -*-
2
+ # transformations.py
3
+
4
+ # Copyright (c) 2006, Christoph Gohlke
5
+ # Copyright (c) 2006-2009, The Regents of the University of California
6
+ # All rights reserved.
7
+ #
8
+ # Redistribution and use in source and binary forms, with or without
9
+ # modification, are permitted provided that the following conditions are met:
10
+ #
11
+ # * Redistributions of source code must retain the above copyright
12
+ # notice, this list of conditions and the following disclaimer.
13
+ # * Redistributions in binary form must reproduce the above copyright
14
+ # notice, this list of conditions and the following disclaimer in the
15
+ # documentation and/or other materials provided with the distribution.
16
+ # * Neither the name of the copyright holders nor the names of any
17
+ # contributors may be used to endorse or promote products derived
18
+ # from this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
23
+ # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
24
+ # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
25
+ # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
26
+ # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
27
+ # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
28
+ # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
29
+ # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
30
+ # POSSIBILITY OF SUCH DAMAGE.
31
+
32
+ """Homogeneous Transformation Matrices and Quaternions.
33
+
34
+ A library for calculating 4x4 matrices for translating, rotating, reflecting,
35
+ scaling, shearing, projecting, orthogonalizing, and superimposing arrays of
36
+ 3D homogeneous coordinates as well as for converting between rotation matrices,
37
+ Euler angles, and quaternions. Also includes an Arcball control object and
38
+ functions to decompose transformation matrices.
39
+
40
+ :Authors:
41
+ `Christoph Gohlke <http://www.lfd.uci.edu/~gohlke/>`__,
42
+ Laboratory for Fluorescence Dynamics, University of California, Irvine
43
+
44
+ :Version: 20090418
45
+
46
+ Requirements
47
+ ------------
48
+
49
+ * `Python 2.6 <http://www.python.org>`__
50
+ * `Numpy 1.3 <http://numpy.scipy.org>`__
51
+ * `transformations.c 20090418 <http://www.lfd.uci.edu/~gohlke/>`__
52
+ (optional implementation of some functions in C)
53
+
54
+ Notes
55
+ -----
56
+
57
+ Matrices (M) can be inverted using numpy.linalg.inv(M), concatenated using
58
+ numpy.dot(M0, M1), or used to transform homogeneous coordinates (v) using
59
+ numpy.dot(M, v) for shape (4, \*) "point of arrays", respectively
60
+ numpy.dot(v, M.T) for shape (\*, 4) "array of points".
61
+
62
+ Calculations are carried out with numpy.float64 precision.
63
+
64
+ This Python implementation is not optimized for speed.
65
+
66
+ Vector, point, quaternion, and matrix function arguments are expected to be
67
+ "array like", i.e. tuple, list, or numpy arrays.
68
+
69
+ Return types are numpy arrays unless specified otherwise.
70
+
71
+ Angles are in radians unless specified otherwise.
72
+
73
+ Quaternions ix+jy+kz+w are represented as [x, y, z, w].
74
+
75
+ Use the transpose of transformation matrices for OpenGL glMultMatrixd().
76
+
77
+ A triple of Euler angles can be applied/interpreted in 24 ways, which can
78
+ be specified using a 4 character string or encoded 4-tuple:
79
+
80
+ *Axes 4-string*: e.g. 'sxyz' or 'ryxy'
81
+
82
+ - first character : rotations are applied to 's'tatic or 'r'otating frame
83
+ - remaining characters : successive rotation axis 'x', 'y', or 'z'
84
+
85
+ *Axes 4-tuple*: e.g. (0, 0, 0, 0) or (1, 1, 1, 1)
86
+
87
+ - inner axis: code of axis ('x':0, 'y':1, 'z':2) of rightmost matrix.
88
+ - parity : even (0) if inner axis 'x' is followed by 'y', 'y' is followed
89
+ by 'z', or 'z' is followed by 'x'. Otherwise odd (1).
90
+ - repetition : first and last axis are same (1) or different (0).
91
+ - frame : rotations are applied to static (0) or rotating (1) frame.
92
+
93
+ References
94
+ ----------
95
+
96
+ (1) Matrices and transformations. Ronald Goldman.
97
+ In "Graphics Gems I", pp 472-475. Morgan Kaufmann, 1990.
98
+ (2) More matrices and transformations: shear and pseudo-perspective.
99
+ Ronald Goldman. In "Graphics Gems II", pp 320-323. Morgan Kaufmann, 1991.
100
+ (3) Decomposing a matrix into simple transformations. Spencer Thomas.
101
+ In "Graphics Gems II", pp 320-323. Morgan Kaufmann, 1991.
102
+ (4) Recovering the data from the transformation matrix. Ronald Goldman.
103
+ In "Graphics Gems II", pp 324-331. Morgan Kaufmann, 1991.
104
+ (5) Euler angle conversion. Ken Shoemake.
105
+ In "Graphics Gems IV", pp 222-229. Morgan Kaufmann, 1994.
106
+ (6) Arcball rotation control. Ken Shoemake.
107
+ In "Graphics Gems IV", pp 175-192. Morgan Kaufmann, 1994.
108
+ (7) Representing attitude: Euler angles, unit quaternions, and rotation
109
+ vectors. James Diebel. 2006.
110
+ (8) A discussion of the solution for the best rotation to relate two sets
111
+ of vectors. W Kabsch. Acta Cryst. 1978. A34, 827-828.
112
+ (9) Closed-form solution of absolute orientation using unit quaternions.
113
+ BKP Horn. J Opt Soc Am A. 1987. 4(4), 629-642.
114
+ (10) Quaternions. Ken Shoemake.
115
+ http://www.sfu.ca/~jwa3/cmpt461/files/quatut.pdf
116
+ (11) From quaternion to matrix and back. JMP van Waveren. 2005.
117
+ http://www.intel.com/cd/ids/developer/asmo-na/eng/293748.htm
118
+ (12) Uniform random rotations. Ken Shoemake.
119
+ In "Graphics Gems III", pp 124-132. Morgan Kaufmann, 1992.
120
+
121
+
122
+ Examples
123
+ --------
124
+
125
+ >>> alpha, beta, gamma = 0.123, -1.234, 2.345
126
+ >>> origin, xaxis, yaxis, zaxis = (0, 0, 0), (1, 0, 0), (0, 1, 0), (0, 0, 1)
127
+ >>> I = identity_matrix()
128
+ >>> Rx = rotation_matrix(alpha, xaxis)
129
+ >>> Ry = rotation_matrix(beta, yaxis)
130
+ >>> Rz = rotation_matrix(gamma, zaxis)
131
+ >>> R = concatenate_matrices(Rx, Ry, Rz)
132
+ >>> euler = euler_from_matrix(R, 'rxyz')
133
+ >>> numpy.allclose([alpha, beta, gamma], euler)
134
+ True
135
+ >>> Re = euler_matrix(alpha, beta, gamma, 'rxyz')
136
+ >>> is_same_transform(R, Re)
137
+ True
138
+ >>> al, be, ga = euler_from_matrix(Re, 'rxyz')
139
+ >>> is_same_transform(Re, euler_matrix(al, be, ga, 'rxyz'))
140
+ True
141
+ >>> qx = quaternion_about_axis(alpha, xaxis)
142
+ >>> qy = quaternion_about_axis(beta, yaxis)
143
+ >>> qz = quaternion_about_axis(gamma, zaxis)
144
+ >>> q = quaternion_multiply(qx, qy)
145
+ >>> q = quaternion_multiply(q, qz)
146
+ >>> Rq = quaternion_matrix(q)
147
+ >>> is_same_transform(R, Rq)
148
+ True
149
+ >>> S = scale_matrix(1.23, origin)
150
+ >>> T = translation_matrix((1, 2, 3))
151
+ >>> Z = shear_matrix(beta, xaxis, origin, zaxis)
152
+ >>> R = random_rotation_matrix(numpy.random.rand(3))
153
+ >>> M = concatenate_matrices(T, R, Z, S)
154
+ >>> scale, shear, angles, trans, persp = decompose_matrix(M)
155
+ >>> numpy.allclose(scale, 1.23)
156
+ True
157
+ >>> numpy.allclose(trans, (1, 2, 3))
158
+ True
159
+ >>> numpy.allclose(shear, (0, math.tan(beta), 0))
160
+ True
161
+ >>> is_same_transform(R, euler_matrix(axes='sxyz', *angles))
162
+ True
163
+ >>> M1 = compose_matrix(scale, shear, angles, trans, persp)
164
+ >>> is_same_transform(M, M1)
165
+ True
166
+
167
+ """
168
+
169
+ from __future__ import division
170
+
171
+ import warnings
172
+ import math
173
+
174
+ import numpy
175
+
176
+ # Documentation in HTML format can be generated with Epydoc
177
+ __docformat__ = "restructuredtext en"
178
+
179
+
180
+ def skew(v):
181
+ """Returns the skew-symmetric matrix of a vector
182
+ cfo, 2015/08/13
183
+
184
+ """
185
+ return numpy.array([[0, -v[2], v[1]],
186
+ [v[2], 0, -v[0]],
187
+ [-v[1], v[0], 0]], dtype=numpy.float64)
188
+
189
+
190
+ def unskew(R):
191
+ """Returns the coordinates of a skew-symmetric matrix
192
+ cfo, 2015/08/13
193
+
194
+ """
195
+ return numpy.array([R[2, 1], R[0, 2], R[1, 0]], dtype=numpy.float64)
196
+
197
+
198
+ def first_order_rotation(rotvec):
199
+ """First order approximation of a rotation: I + skew(rotvec)
200
+ cfo, 2015/08/13
201
+
202
+ """
203
+ R = numpy.zeros((3, 3), dtype=numpy.float64)
204
+ R[0, 0] = 1.0
205
+ R[1, 0] = rotvec[2]
206
+ R[2, 0] = -rotvec[1]
207
+ R[0, 1] = -rotvec[2]
208
+ R[1, 1] = 1.0
209
+ R[2, 1] = rotvec[0]
210
+ R[0, 2] = rotvec[1]
211
+ R[1, 2] = -rotvec[0]
212
+ R[2, 2] = 1.0
213
+ return R
214
+
215
+
216
+ def axis_angle(axis, theta):
217
+ """Compute a rotation matrix from an axis and an angle.
218
+ Returns 3x3 Matrix.
219
+ Is the same as transformations.rotation_matrix(theta, axis).
220
+ cfo, 2015/08/13
221
+
222
+ """
223
+ if theta*theta > _EPS:
224
+ wx = axis[0]
225
+ wy = axis[1]
226
+ wz = axis[2]
227
+ costheta = numpy.cos(theta)
228
+ sintheta = numpy.sin(theta)
229
+ c_1 = 1.0 - costheta
230
+ wx_sintheta = wx * sintheta
231
+ wy_sintheta = wy * sintheta
232
+ wz_sintheta = wz * sintheta
233
+ C00 = c_1 * wx * wx
234
+ C01 = c_1 * wx * wy
235
+ C02 = c_1 * wx * wz
236
+ C11 = c_1 * wy * wy
237
+ C12 = c_1 * wy * wz
238
+ C22 = c_1 * wz * wz
239
+ R = numpy.zeros((3, 3), dtype=numpy.float64)
240
+ R[0, 0] = costheta + C00
241
+ R[1, 0] = wz_sintheta + C01
242
+ R[2, 0] = -wy_sintheta + C02
243
+ R[0, 1] = -wz_sintheta + C01
244
+ R[1, 1] = costheta + C11
245
+ R[2, 1] = wx_sintheta + C12
246
+ R[0, 2] = wy_sintheta + C02
247
+ R[1, 2] = -wx_sintheta + C12
248
+ R[2, 2] = costheta + C22
249
+ return R
250
+ else:
251
+ return first_order_rotation(axis*theta)
252
+
253
+
254
+ def expmap_so3(rotvec):
255
+ """Exponential map at identity.
256
+ Create a rotation from canonical coordinates using Rodrigues' formula.
257
+ cfo, 2015/08/13
258
+
259
+ """
260
+ theta = numpy.linalg.norm(rotvec)
261
+ axis = rotvec/theta
262
+ return axis_angle(axis, theta)
263
+
264
+
265
+ def logmap_so3(R):
266
+ """Logmap at the identity.
267
+ Returns canonical coordinates of rotation.
268
+ cfo, 2015/08/13
269
+
270
+ """
271
+ R11 = R[0, 0]
272
+ R12 = R[0, 1]
273
+ R13 = R[0, 2]
274
+ R21 = R[1, 0]
275
+ R22 = R[1, 1]
276
+ R23 = R[1, 2]
277
+ R31 = R[2, 0]
278
+ R32 = R[2, 1]
279
+ R33 = R[2, 2]
280
+ tr = numpy.trace(R)
281
+ omega = numpy.empty((3,), dtype=numpy.float64)
282
+
283
+ # when trace == -1, i.e., when theta = +-pi, +-3pi, +-5pi, we do something
284
+ # special
285
+ if(numpy.abs(tr + 1.0) < 1e-10):
286
+ if(numpy.abs(R33 + 1.0) > 1e-10):
287
+ omega = (numpy.pi / numpy.sqrt(2.0 + 2.0 * R33)) * \
288
+ numpy.array([R13, R23, 1.0+R33])
289
+ elif(numpy.abs(R22 + 1.0) > 1e-10):
290
+ omega = (numpy.pi / numpy.sqrt(2.0 + 2.0 * R22)) * \
291
+ numpy.array([R12, 1.0+R22, R32])
292
+ else:
293
+ omega = (numpy.pi / numpy.sqrt(2.0 + 2.0 * R11)) * \
294
+ numpy.array([1.0+R11, R21, R31])
295
+ else:
296
+ magnitude = 1.0
297
+ tr_3 = tr - 3.0
298
+ if tr_3 < -1e-7:
299
+ theta = numpy.arccos((tr - 1.0) / 2.0)
300
+ magnitude = theta / (2.0 * numpy.sin(theta))
301
+ else:
302
+ # when theta near 0, +-2pi, +-4pi, etc. (trace near 3.0)
303
+ # use Taylor expansion: theta \approx 1/2-(t-3)/12 + O((t-3)^2)
304
+ magnitude = 0.5 - tr_3 / 12.0
305
+
306
+ omega = magnitude * numpy.array([R32 - R23, R13 - R31, R21 - R12])
307
+
308
+ return omega
309
+
310
+
311
+
312
+ def right_jacobian_so3(rotvec):
313
+ """Right Jacobian for Exponential map in SO(3)
314
+ Equation (10.86) and following equations in G.S. Chirikjian, "Stochastic
315
+ Models, Information Theory, and Lie Groups", Volume 2, 2008.
316
+
317
+ > expmap_so3(thetahat + omega) \approx expmap_so3(thetahat) * expmap_so3(Jr * omega)
318
+ where Jr = right_jacobian_so3(thetahat);
319
+ This maps a perturbation in the tangent space (omega) to a perturbation
320
+ on the manifold (expmap_so3(Jr * omega))
321
+ cfo, 2015/08/13
322
+
323
+ """
324
+
325
+ theta2 = numpy.dot(rotvec, rotvec)
326
+ if theta2 <= _EPS:
327
+ return numpy.identity(3, dtype=numpy.float64)
328
+ else:
329
+ theta = numpy.sqrt(theta2)
330
+ Y = skew(rotvec) / theta
331
+ I_3x3 = numpy.identity(3, dtype=numpy.float64)
332
+ J_r = I_3x3 - ((1.0 - numpy.cos(theta)) / theta) * Y + \
333
+ (1.0 - numpy.sin(theta) / theta) * numpy.dot(Y, Y)
334
+ return J_r
335
+
336
+
337
+ def S_inv_eulerZYX_body(euler_coordinates):
338
+ """ Relates angular rates w to changes in eulerZYX coordinates.
339
+ dot(euler) = S^-1(euler_coordinates) * omega
340
+ Also called: rotation-rate matrix. (E in Lupton paper)
341
+ cfo, 2015/08/13
342
+
343
+ """
344
+ y = euler_coordinates[1]
345
+ z = euler_coordinates[2]
346
+ E = numpy.zeros((3, 3))
347
+ E[0, 1] = numpy.sin(z)/numpy.cos(y)
348
+ E[0, 2] = numpy.cos(z)/numpy.cos(y)
349
+ E[1, 1] = numpy.cos(z)
350
+ E[1, 2] = -numpy.sin(z)
351
+ E[2, 0] = 1.0
352
+ E[2, 1] = numpy.sin(z)*numpy.sin(y)/numpy.cos(y)
353
+ E[2, 2] = numpy.cos(z)*numpy.sin(y)/numpy.cos(y)
354
+ return E
355
+
356
+
357
+ def S_inv_eulerZYX_body_deriv(euler_coordinates, omega):
358
+ """ Compute dE(euler_coordinates)*omega/deuler_coordinates
359
+ cfo, 2015/08/13
360
+
361
+ """
362
+
363
+ y = euler_coordinates[1]
364
+ z = euler_coordinates[2]
365
+
366
+ """
367
+ w1 = omega[0]; w2 = omega[1]; w3 = omega[2]
368
+ J = numpy.zeros((3,3))
369
+ J[0,0] = 0
370
+ J[0,1] = math.tan(y) / math.cos(y) * (math.sin(z) * w2 + math.cos(z) * w3)
371
+ J[0,2] = w2/math.cos(y)*math.cos(z) - w3/math.cos(y)*math.sin(z)
372
+ J[1,0] = 0
373
+ J[1,1] = 0
374
+ J[1,2] = -w2*math.sin(z) - w3*math.cos(z)
375
+ J[2,0] = w1
376
+ J[2,1] = 1.0/math.cos(y)**2 * (w2 * math.sin(z) + w3 * math.cos(z))
377
+ J[2,2] = w2*math.tan(y)*math.cos(z) - w3*math.tan(y)*math.sin(z)
378
+
379
+ """
380
+
381
+ # second version, x = psi, y = theta, z = phi
382
+ # J_x = numpy.zeros((3,3))
383
+ J_y = numpy.zeros((3, 3))
384
+ J_z = numpy.zeros((3, 3))
385
+
386
+ # dE^-1/dtheta
387
+ J_y[0, 1] = math.tan(y)/math.cos(y)*math.sin(z)
388
+ J_y[0, 2] = math.tan(y)/math.cos(y)*math.cos(z)
389
+ J_y[2, 1] = math.sin(z)/(math.cos(y))**2
390
+ J_y[2, 2] = math.cos(z)/(math.cos(y))**2
391
+
392
+ # dE^-1/dphi
393
+ J_z[0, 1] = math.cos(z)/math.cos(y)
394
+ J_z[0, 2] = -math.sin(z)/math.cos(y)
395
+ J_z[1, 1] = -math.sin(z)
396
+ J_z[1, 2] = -math.cos(z)
397
+ J_z[2, 1] = math.cos(z)*math.tan(y)
398
+ J_z[2, 2] = -math.sin(z)*math.tan(y)
399
+
400
+ J = numpy.zeros((3, 3))
401
+ J[:, 1] = numpy.dot(J_y, omega)
402
+ J[:, 2] = numpy.dot(J_z, omega)
403
+
404
+ return J
405
+
406
+
407
+ def identity_matrix():
408
+ """Return 4x4 identity/unit matrix.
409
+
410
+ >>> I = identity_matrix()
411
+ >>> numpy.allclose(I, numpy.dot(I, I))
412
+ True
413
+ >>> numpy.sum(I), numpy.trace(I)
414
+ (4.0, 4.0)
415
+ >>> numpy.allclose(I, numpy.identity(4, dtype=numpy.float64))
416
+ True
417
+
418
+ """
419
+ return numpy.identity(4, dtype=numpy.float64)
420
+
421
+
422
+ def translation_matrix(direction):
423
+ """Return matrix to translate by direction vector.
424
+
425
+ >>> v = numpy.random.random(3) - 0.5
426
+ >>> numpy.allclose(v, translation_matrix(v)[:3, 3])
427
+ True
428
+
429
+ """
430
+ M = numpy.identity(4)
431
+ M[:3, 3] = direction[:3]
432
+ return M
433
+
434
+
435
+ def translation_from_matrix(matrix):
436
+ """Return translation vector from translation matrix.
437
+
438
+ >>> v0 = numpy.random.random(3) - 0.5
439
+ >>> v1 = translation_from_matrix(translation_matrix(v0))
440
+ >>> numpy.allclose(v0, v1)
441
+ True
442
+
443
+ """
444
+ return numpy.array(matrix, copy=False)[:3, 3].copy()
445
+
446
+
447
+ def convert_3x3_to_4x4(matrix_3x3):
448
+ M = numpy.identity(4)
449
+ M[:3, :3] = matrix_3x3
450
+ return M
451
+
452
+
453
+ def reflection_matrix(point, normal):
454
+ """Return matrix to mirror at plane defined by point and normal vector.
455
+
456
+ >>> v0 = numpy.random.random(4) - 0.5
457
+ >>> v0[3] = 1.0
458
+ >>> v1 = numpy.random.random(3) - 0.5
459
+ >>> R = reflection_matrix(v0, v1)
460
+ >>> numpy.allclose(2., numpy.trace(R))
461
+ True
462
+ >>> numpy.allclose(v0, numpy.dot(R, v0))
463
+ True
464
+ >>> v2 = v0.copy()
465
+ >>> v2[:3] += v1
466
+ >>> v3 = v0.copy()
467
+ >>> v2[:3] -= v1
468
+ >>> numpy.allclose(v2, numpy.dot(R, v3))
469
+ True
470
+
471
+ """
472
+ normal = unit_vector(normal[:3])
473
+ M = numpy.identity(4)
474
+ M[:3, :3] -= 2.0 * numpy.outer(normal, normal)
475
+ M[:3, 3] = (2.0 * numpy.dot(point[:3], normal)) * normal
476
+ return M
477
+
478
+
479
+ def reflection_from_matrix(matrix):
480
+ """Return mirror plane point and normal vector from reflection matrix.
481
+
482
+ >>> v0 = numpy.random.random(3) - 0.5
483
+ >>> v1 = numpy.random.random(3) - 0.5
484
+ >>> M0 = reflection_matrix(v0, v1)
485
+ >>> point, normal = reflection_from_matrix(M0)
486
+ >>> M1 = reflection_matrix(point, normal)
487
+ >>> is_same_transform(M0, M1)
488
+ True
489
+
490
+ """
491
+ M = numpy.array(matrix, dtype=numpy.float64, copy=False)
492
+ # normal: unit eigenvector corresponding to eigenvalue -1
493
+ l, V = numpy.linalg.eig(M[:3, :3])
494
+ i = numpy.where(abs(numpy.real(l) + 1.0) < 1e-8)[0]
495
+ if not len(i):
496
+ raise ValueError("no unit eigenvector corresponding to eigenvalue -1")
497
+ normal = numpy.real(V[:, i[0]]).squeeze()
498
+ # point: any unit eigenvector corresponding to eigenvalue 1
499
+ l, V = numpy.linalg.eig(M)
500
+ i = numpy.where(abs(numpy.real(l) - 1.0) < 1e-8)[0]
501
+ if not len(i):
502
+ raise ValueError("no unit eigenvector corresponding to eigenvalue 1")
503
+ point = numpy.real(V[:, i[-1]]).squeeze()
504
+ point /= point[3]
505
+ return point, normal
506
+
507
+
508
+ def rotation_matrix(angle, direction, point=None):
509
+ """Return matrix to rotate about axis defined by point and direction.
510
+
511
+ >>> angle = (random.random() - 0.5) * (2*math.pi)
512
+ >>> direc = numpy.random.random(3) - 0.5
513
+ >>> point = numpy.random.random(3) - 0.5
514
+ >>> R0 = rotation_matrix(angle, direc, point)
515
+ >>> R1 = rotation_matrix(angle-2*math.pi, direc, point)
516
+ >>> is_same_transform(R0, R1)
517
+ True
518
+ >>> R0 = rotation_matrix(angle, direc, point)
519
+ >>> R1 = rotation_matrix(-angle, -direc, point)
520
+ >>> is_same_transform(R0, R1)
521
+ True
522
+ >>> I = numpy.identity(4, numpy.float64)
523
+ >>> numpy.allclose(I, rotation_matrix(math.pi*2, direc))
524
+ True
525
+ >>> numpy.allclose(2., numpy.trace(rotation_matrix(math.pi/2,
526
+ ... direc, point)))
527
+ True
528
+
529
+ """
530
+ sina = math.sin(angle)
531
+ cosa = math.cos(angle)
532
+ direction = unit_vector(direction[:3])
533
+ # rotation matrix around unit vector
534
+ R = numpy.array(((cosa, 0.0, 0.0),
535
+ (0.0, cosa, 0.0),
536
+ (0.0, 0.0, cosa)), dtype=numpy.float64)
537
+ R += numpy.outer(direction, direction) * (1.0 - cosa)
538
+ direction *= sina
539
+ R += numpy.array(((0.0, -direction[2], direction[1]),
540
+ (direction[2], 0.0, -direction[0]),
541
+ (-direction[1], direction[0], 0.0)),
542
+ dtype=numpy.float64)
543
+ M = numpy.identity(4)
544
+ M[:3, :3] = R
545
+ if point is not None:
546
+ # rotation not around origin
547
+ point = numpy.array(point[:3], dtype=numpy.float64, copy=False)
548
+ M[:3, 3] = point - numpy.dot(R, point)
549
+ return M
550
+
551
+
552
+ def rotation_from_matrix(matrix):
553
+ """Return rotation angle and axis from rotation matrix.
554
+
555
+ >>> angle = (random.random() - 0.5) * (2*math.pi)
556
+ >>> direc = numpy.random.random(3) - 0.5
557
+ >>> point = numpy.random.random(3) - 0.5
558
+ >>> R0 = rotation_matrix(angle, direc, point)
559
+ >>> angle, direc, point = rotation_from_matrix(R0)
560
+ >>> R1 = rotation_matrix(angle, direc, point)
561
+ >>> is_same_transform(R0, R1)
562
+ True
563
+
564
+ """
565
+ R = numpy.array(matrix, dtype=numpy.float64, copy=False)
566
+ R33 = R[:3, :3]
567
+ # direction: unit eigenvector of R33 corresponding to eigenvalue of 1
568
+ l, W = numpy.linalg.eig(R33.T)
569
+ i = numpy.where(abs(numpy.real(l) - 1.0) < 1e-8)[0]
570
+ if not len(i):
571
+ raise ValueError("no unit eigenvector corresponding to eigenvalue 1")
572
+ direction = numpy.real(W[:, i[-1]]).squeeze()
573
+ # point: unit eigenvector of R33 corresponding to eigenvalue of 1
574
+ l, Q = numpy.linalg.eig(R)
575
+ i = numpy.where(abs(numpy.real(l) - 1.0) < 1e-8)[0]
576
+ if not len(i):
577
+ raise ValueError("no unit eigenvector corresponding to eigenvalue 1")
578
+ point = numpy.real(Q[:, i[-1]]).squeeze()
579
+ point /= point[3]
580
+ # rotation angle depending on direction
581
+ cosa = (numpy.trace(R33) - 1.0) / 2.0
582
+ if abs(direction[2]) > 1e-8:
583
+ sina = (R[1, 0] + (cosa-1.0)*direction[0]*direction[1]) / direction[2]
584
+ elif abs(direction[1]) > 1e-8:
585
+ sina = (R[0, 2] + (cosa-1.0)*direction[0]*direction[2]) / direction[1]
586
+ else:
587
+ sina = (R[2, 1] + (cosa-1.0)*direction[1]*direction[2]) / direction[0]
588
+ angle = math.atan2(sina, cosa)
589
+ return angle, direction, point
590
+
591
+
592
+ def scale_matrix(factor, origin=None, direction=None):
593
+ """Return matrix to scale by factor around origin in direction.
594
+
595
+ Use factor -1 for point symmetry.
596
+
597
+ >>> v = (numpy.random.rand(4, 5) - 0.5) * 20.0
598
+ >>> v[3] = 1.0
599
+ >>> S = scale_matrix(-1.234)
600
+ >>> numpy.allclose(numpy.dot(S, v)[:3], -1.234*v[:3])
601
+ True
602
+ >>> factor = random.random() * 10 - 5
603
+ >>> origin = numpy.random.random(3) - 0.5
604
+ >>> direct = numpy.random.random(3) - 0.5
605
+ >>> S = scale_matrix(factor, origin)
606
+ >>> S = scale_matrix(factor, origin, direct)
607
+
608
+ """
609
+ if direction is None:
610
+ # uniform scaling
611
+ M = numpy.array(((factor, 0.0, 0.0, 0.0),
612
+ (0.0, factor, 0.0, 0.0),
613
+ (0.0, 0.0, factor, 0.0),
614
+ (0.0, 0.0, 0.0, 1.0)), dtype=numpy.float64)
615
+ if origin is not None:
616
+ M[:3, 3] = origin[:3]
617
+ M[:3, 3] *= 1.0 - factor
618
+ else:
619
+ # nonuniform scaling
620
+ direction = unit_vector(direction[:3])
621
+ factor = 1.0 - factor
622
+ M = numpy.identity(4)
623
+ M[:3, :3] -= factor * numpy.outer(direction, direction)
624
+ if origin is not None:
625
+ M[:3, 3] = (factor * numpy.dot(origin[:3], direction)) * direction
626
+ return M
627
+
628
+
629
+ def scale_from_matrix(matrix):
630
+ """Return scaling factor, origin and direction from scaling matrix.
631
+
632
+ >>> factor = random.random() * 10 - 5
633
+ >>> origin = numpy.random.random(3) - 0.5
634
+ >>> direct = numpy.random.random(3) - 0.5
635
+ >>> S0 = scale_matrix(factor, origin)
636
+ >>> factor, origin, direction = scale_from_matrix(S0)
637
+ >>> S1 = scale_matrix(factor, origin, direction)
638
+ >>> is_same_transform(S0, S1)
639
+ True
640
+ >>> S0 = scale_matrix(factor, origin, direct)
641
+ >>> factor, origin, direction = scale_from_matrix(S0)
642
+ >>> S1 = scale_matrix(factor, origin, direction)
643
+ >>> is_same_transform(S0, S1)
644
+ True
645
+
646
+ """
647
+ M = numpy.array(matrix, dtype=numpy.float64, copy=False)
648
+ M33 = M[:3, :3]
649
+ factor = numpy.trace(M33) - 2.0
650
+ try:
651
+ # direction: unit eigenvector corresponding to eigenvalue factor
652
+ l, V = numpy.linalg.eig(M33)
653
+ i = numpy.where(abs(numpy.real(l) - factor) < 1e-8)[0][0]
654
+ direction = numpy.real(V[:, i]).squeeze()
655
+ direction /= vector_norm(direction)
656
+ except IndexError:
657
+ # uniform scaling
658
+ factor = (factor + 2.0) / 3.0
659
+ direction = None
660
+ # origin: any eigenvector corresponding to eigenvalue 1
661
+ l, V = numpy.linalg.eig(M)
662
+ i = numpy.where(abs(numpy.real(l) - 1.0) < 1e-8)[0]
663
+ if not len(i):
664
+ raise ValueError("no eigenvector corresponding to eigenvalue 1")
665
+ origin = numpy.real(V[:, i[-1]]).squeeze()
666
+ origin /= origin[3]
667
+ return factor, origin, direction
668
+
669
+
670
+ def projection_matrix(point, normal, direction=None,
671
+ perspective=None, pseudo=False):
672
+ """Return matrix to project onto plane defined by point and normal.
673
+
674
+ Using either perspective point, projection direction, or none of both.
675
+
676
+ If pseudo is True, perspective projections will preserve relative depth
677
+ such that Perspective = dot(Orthogonal, PseudoPerspective).
678
+
679
+ >>> P = projection_matrix((0, 0, 0), (1, 0, 0))
680
+ >>> numpy.allclose(P[1:, 1:], numpy.identity(4)[1:, 1:])
681
+ True
682
+ >>> point = numpy.random.random(3) - 0.5
683
+ >>> normal = numpy.random.random(3) - 0.5
684
+ >>> direct = numpy.random.random(3) - 0.5
685
+ >>> persp = numpy.random.random(3) - 0.5
686
+ >>> P0 = projection_matrix(point, normal)
687
+ >>> P1 = projection_matrix(point, normal, direction=direct)
688
+ >>> P2 = projection_matrix(point, normal, perspective=persp)
689
+ >>> P3 = projection_matrix(point, normal, perspective=persp, pseudo=True)
690
+ >>> is_same_transform(P2, numpy.dot(P0, P3))
691
+ True
692
+ >>> P = projection_matrix((3, 0, 0), (1, 1, 0), (1, 0, 0))
693
+ >>> v0 = (numpy.random.rand(4, 5) - 0.5) * 20.0
694
+ >>> v0[3] = 1.0
695
+ >>> v1 = numpy.dot(P, v0)
696
+ >>> numpy.allclose(v1[1], v0[1])
697
+ True
698
+ >>> numpy.allclose(v1[0], 3.0-v1[1])
699
+ True
700
+
701
+ """
702
+ M = numpy.identity(4)
703
+ point = numpy.array(point[:3], dtype=numpy.float64, copy=False)
704
+ normal = unit_vector(normal[:3])
705
+ if perspective is not None:
706
+ # perspective projection
707
+ perspective = numpy.array(perspective[:3], dtype=numpy.float64,
708
+ copy=False)
709
+ M[0, 0] = M[1, 1] = M[2, 2] = numpy.dot(perspective-point, normal)
710
+ M[:3, :3] -= numpy.outer(perspective, normal)
711
+ if pseudo:
712
+ # preserve relative depth
713
+ M[:3, :3] -= numpy.outer(normal, normal)
714
+ M[:3, 3] = numpy.dot(point, normal) * (perspective+normal)
715
+ else:
716
+ M[:3, 3] = numpy.dot(point, normal) * perspective
717
+ M[3, :3] = -normal
718
+ M[3, 3] = numpy.dot(perspective, normal)
719
+ elif direction is not None:
720
+ # parallel projection
721
+ direction = numpy.array(direction[:3], dtype=numpy.float64, copy=False)
722
+ scale = numpy.dot(direction, normal)
723
+ M[:3, :3] -= numpy.outer(direction, normal) / scale
724
+ M[:3, 3] = direction * (numpy.dot(point, normal) / scale)
725
+ else:
726
+ # orthogonal projection
727
+ M[:3, :3] -= numpy.outer(normal, normal)
728
+ M[:3, 3] = numpy.dot(point, normal) * normal
729
+ return M
730
+
731
+
732
+ def projection_from_matrix(matrix, pseudo=False):
733
+ """Return projection plane and perspective point from projection matrix.
734
+
735
+ Return values are same as arguments for projection_matrix function:
736
+ point, normal, direction, perspective, and pseudo.
737
+
738
+ >>> point = numpy.random.random(3) - 0.5
739
+ >>> normal = numpy.random.random(3) - 0.5
740
+ >>> direct = numpy.random.random(3) - 0.5
741
+ >>> persp = numpy.random.random(3) - 0.5
742
+ >>> P0 = projection_matrix(point, normal)
743
+ >>> result = projection_from_matrix(P0)
744
+ >>> P1 = projection_matrix(*result)
745
+ >>> is_same_transform(P0, P1)
746
+ True
747
+ >>> P0 = projection_matrix(point, normal, direct)
748
+ >>> result = projection_from_matrix(P0)
749
+ >>> P1 = projection_matrix(*result)
750
+ >>> is_same_transform(P0, P1)
751
+ True
752
+ >>> P0 = projection_matrix(point, normal, perspective=persp, pseudo=False)
753
+ >>> result = projection_from_matrix(P0, pseudo=False)
754
+ >>> P1 = projection_matrix(*result)
755
+ >>> is_same_transform(P0, P1)
756
+ True
757
+ >>> P0 = projection_matrix(point, normal, perspective=persp, pseudo=True)
758
+ >>> result = projection_from_matrix(P0, pseudo=True)
759
+ >>> P1 = projection_matrix(*result)
760
+ >>> is_same_transform(P0, P1)
761
+ True
762
+
763
+ """
764
+ M = numpy.array(matrix, dtype=numpy.float64, copy=False)
765
+ M33 = M[:3, :3]
766
+ l, V = numpy.linalg.eig(M)
767
+ i = numpy.where(abs(numpy.real(l) - 1.0) < 1e-8)[0]
768
+ if not pseudo and len(i):
769
+ # point: any eigenvector corresponding to eigenvalue 1
770
+ point = numpy.real(V[:, i[-1]]).squeeze()
771
+ point /= point[3]
772
+ # direction: unit eigenvector corresponding to eigenvalue 0
773
+ l, V = numpy.linalg.eig(M33)
774
+ i = numpy.where(abs(numpy.real(l)) < 1e-8)[0]
775
+ if not len(i):
776
+ raise ValueError("no eigenvector corresponding to eigenvalue 0")
777
+ direction = numpy.real(V[:, i[0]]).squeeze()
778
+ direction /= vector_norm(direction)
779
+ # normal: unit eigenvector of M33.T corresponding to eigenvalue 0
780
+ l, V = numpy.linalg.eig(M33.T)
781
+ i = numpy.where(abs(numpy.real(l)) < 1e-8)[0]
782
+ if len(i):
783
+ # parallel projection
784
+ normal = numpy.real(V[:, i[0]]).squeeze()
785
+ normal /= vector_norm(normal)
786
+ return point, normal, direction, None, False
787
+ else:
788
+ # orthogonal projection, where normal equals direction vector
789
+ return point, direction, None, None, False
790
+ else:
791
+ # perspective projection
792
+ i = numpy.where(abs(numpy.real(l)) > 1e-8)[0]
793
+ if not len(i):
794
+ raise ValueError(
795
+ "no eigenvector not corresponding to eigenvalue 0")
796
+ point = numpy.real(V[:, i[-1]]).squeeze()
797
+ point /= point[3]
798
+ normal = - M[3, :3]
799
+ perspective = M[:3, 3] / numpy.dot(point[:3], normal)
800
+ if pseudo:
801
+ perspective -= normal
802
+ return point, normal, None, perspective, pseudo
803
+
804
+
805
+ def clip_matrix(left, right, bottom, top, near, far, perspective=False):
806
+ """Return matrix to obtain normalized device coordinates from frustrum.
807
+
808
+ The frustrum bounds are axis-aligned along x (left, right),
809
+ y (bottom, top) and z (near, far).
810
+
811
+ Normalized device coordinates are in range [-1, 1] if coordinates are
812
+ inside the frustrum.
813
+
814
+ If perspective is True the frustrum is a truncated pyramid with the
815
+ perspective point at origin and direction along z axis, otherwise an
816
+ orthographic canonical view volume (a box).
817
+
818
+ Homogeneous coordinates transformed by the perspective clip matrix
819
+ need to be dehomogenized (divided by w coordinate).
820
+
821
+ >>> frustrum = numpy.random.rand(6)
822
+ >>> frustrum[1] += frustrum[0]
823
+ >>> frustrum[3] += frustrum[2]
824
+ >>> frustrum[5] += frustrum[4]
825
+ >>> M = clip_matrix(*frustrum, perspective=False)
826
+ >>> numpy.dot(M, [frustrum[0], frustrum[2], frustrum[4], 1.0])
827
+ array([-1., -1., -1., 1.])
828
+ >>> numpy.dot(M, [frustrum[1], frustrum[3], frustrum[5], 1.0])
829
+ array([ 1., 1., 1., 1.])
830
+ >>> M = clip_matrix(*frustrum, perspective=True)
831
+ >>> v = numpy.dot(M, [frustrum[0], frustrum[2], frustrum[4], 1.0])
832
+ >>> v / v[3]
833
+ array([-1., -1., -1., 1.])
834
+ >>> v = numpy.dot(M, [frustrum[1], frustrum[3], frustrum[4], 1.0])
835
+ >>> v / v[3]
836
+ array([ 1., 1., -1., 1.])
837
+
838
+ """
839
+ if left >= right or bottom >= top or near >= far:
840
+ raise ValueError("invalid frustrum")
841
+ if perspective:
842
+ if near <= _EPS:
843
+ raise ValueError("invalid frustrum: near <= 0")
844
+ t = 2.0 * near
845
+ M = ((-t/(right-left), 0.0, (right+left)/(right-left), 0.0),
846
+ (0.0, -t/(top-bottom), (top+bottom)/(top-bottom), 0.0),
847
+ (0.0, 0.0, -(far+near)/(far-near), t*far/(far-near)),
848
+ (0.0, 0.0, -1.0, 0.0))
849
+ else:
850
+ M = ((2.0/(right-left), 0.0, 0.0, (right+left)/(left-right)),
851
+ (0.0, 2.0/(top-bottom), 0.0, (top+bottom)/(bottom-top)),
852
+ (0.0, 0.0, 2.0/(far-near), (far+near)/(near-far)),
853
+ (0.0, 0.0, 0.0, 1.0))
854
+ return numpy.array(M, dtype=numpy.float64)
855
+
856
+
857
+ def shear_matrix(angle, direction, point, normal):
858
+ """Return matrix to shear by angle along direction vector on shear plane.
859
+
860
+ The shear plane is defined by a point and normal vector. The direction
861
+ vector must be orthogonal to the plane's normal vector.
862
+
863
+ A point P is transformed by the shear matrix into P" such that
864
+ the vector P-P" is parallel to the direction vector and its extent is
865
+ given by the angle of P-P'-P", where P' is the orthogonal projection
866
+ of P onto the shear plane.
867
+
868
+ >>> angle = (random.random() - 0.5) * 4*math.pi
869
+ >>> direct = numpy.random.random(3) - 0.5
870
+ >>> point = numpy.random.random(3) - 0.5
871
+ >>> normal = numpy.cross(direct, numpy.random.random(3))
872
+ >>> S = shear_matrix(angle, direct, point, normal)
873
+ >>> numpy.allclose(1.0, numpy.linalg.det(S))
874
+ True
875
+
876
+ """
877
+ normal = unit_vector(normal[:3])
878
+ direction = unit_vector(direction[:3])
879
+ if abs(numpy.dot(normal, direction)) > 1e-6:
880
+ raise ValueError("direction and normal vectors are not orthogonal")
881
+ angle = math.tan(angle)
882
+ M = numpy.identity(4)
883
+ M[:3, :3] += angle * numpy.outer(direction, normal)
884
+ M[:3, 3] = -angle * numpy.dot(point[:3], normal) * direction
885
+ return M
886
+
887
+
888
+ def shear_from_matrix(matrix):
889
+ """Return shear angle, direction and plane from shear matrix.
890
+
891
+ >>> angle = (random.random() - 0.5) * 4*math.pi
892
+ >>> direct = numpy.random.random(3) - 0.5
893
+ >>> point = numpy.random.random(3) - 0.5
894
+ >>> normal = numpy.cross(direct, numpy.random.random(3))
895
+ >>> S0 = shear_matrix(angle, direct, point, normal)
896
+ >>> angle, direct, point, normal = shear_from_matrix(S0)
897
+ >>> S1 = shear_matrix(angle, direct, point, normal)
898
+ >>> is_same_transform(S0, S1)
899
+ True
900
+
901
+ """
902
+ M = numpy.array(matrix, dtype=numpy.float64, copy=False)
903
+ M33 = M[:3, :3]
904
+ # normal: cross independent eigenvectors corresponding to the eigenvalue 1
905
+ l, V = numpy.linalg.eig(M33)
906
+ i = numpy.where(abs(numpy.real(l) - 1.0) < 1e-4)[0]
907
+ if len(i) < 2:
908
+ raise ValueError("No two linear independent eigenvectors found %s" % l)
909
+ V = numpy.real(V[:, i]).squeeze().T
910
+ lenorm = -1.0
911
+ for i0, i1 in ((0, 1), (0, 2), (1, 2)):
912
+ n = numpy.cross(V[i0], V[i1])
913
+ l = vector_norm(n)
914
+ if l > lenorm:
915
+ lenorm = l
916
+ normal = n
917
+ normal /= lenorm
918
+ # direction and angle
919
+ direction = numpy.dot(M33 - numpy.identity(3), normal)
920
+ angle = vector_norm(direction)
921
+ direction /= angle
922
+ angle = math.atan(angle)
923
+ # point: eigenvector corresponding to eigenvalue 1
924
+ l, V = numpy.linalg.eig(M)
925
+ i = numpy.where(abs(numpy.real(l) - 1.0) < 1e-8)[0]
926
+ if not len(i):
927
+ raise ValueError("no eigenvector corresponding to eigenvalue 1")
928
+ point = numpy.real(V[:, i[-1]]).squeeze()
929
+ point /= point[3]
930
+ return angle, direction, point, normal
931
+
932
+
933
+ def decompose_matrix(matrix):
934
+ """Return sequence of transformations from transformation matrix.
935
+
936
+ matrix : array_like
937
+ Non-degenerate homogeneous transformation matrix
938
+
939
+ Return tuple of:
940
+ scale : vector of 3 scaling factors
941
+ shear : list of shear factors for x-y, x-z, y-z axes
942
+ angles : list of Euler angles about static x, y, z axes
943
+ translate : translation vector along x, y, z axes
944
+ perspective : perspective partition of matrix
945
+
946
+ Raise ValueError if matrix is of wrong type or degenerate.
947
+
948
+ >>> T0 = translation_matrix((1, 2, 3))
949
+ >>> scale, shear, angles, trans, persp = decompose_matrix(T0)
950
+ >>> T1 = translation_matrix(trans)
951
+ >>> numpy.allclose(T0, T1)
952
+ True
953
+ >>> S = scale_matrix(0.123)
954
+ >>> scale, shear, angles, trans, persp = decompose_matrix(S)
955
+ >>> scale[0]
956
+ 0.123
957
+ >>> R0 = euler_matrix(1, 2, 3)
958
+ >>> scale, shear, angles, trans, persp = decompose_matrix(R0)
959
+ >>> R1 = euler_matrix(*angles)
960
+ >>> numpy.allclose(R0, R1)
961
+ True
962
+
963
+ """
964
+ M = numpy.array(matrix, dtype=numpy.float64, copy=True).T
965
+ if abs(M[3, 3]) < _EPS:
966
+ raise ValueError("M[3, 3] is zero")
967
+ M /= M[3, 3]
968
+ P = M.copy()
969
+ P[:, 3] = 0, 0, 0, 1
970
+ if not numpy.linalg.det(P):
971
+ raise ValueError("Matrix is singular")
972
+
973
+ scale = numpy.zeros((3, ), dtype=numpy.float64)
974
+ shear = [0, 0, 0]
975
+ angles = [0, 0, 0]
976
+
977
+ if any(abs(M[:3, 3]) > _EPS):
978
+ perspective = numpy.dot(M[:, 3], numpy.linalg.inv(P.T))
979
+ M[:, 3] = 0, 0, 0, 1
980
+ else:
981
+ perspective = numpy.array((0, 0, 0, 1), dtype=numpy.float64)
982
+
983
+ translate = M[3, :3].copy()
984
+ M[3, :3] = 0
985
+
986
+ row = M[:3, :3].copy()
987
+ scale[0] = vector_norm(row[0])
988
+ row[0] /= scale[0]
989
+ shear[0] = numpy.dot(row[0], row[1])
990
+ row[1] -= row[0] * shear[0]
991
+ scale[1] = vector_norm(row[1])
992
+ row[1] /= scale[1]
993
+ shear[0] /= scale[1]
994
+ shear[1] = numpy.dot(row[0], row[2])
995
+ row[2] -= row[0] * shear[1]
996
+ shear[2] = numpy.dot(row[1], row[2])
997
+ row[2] -= row[1] * shear[2]
998
+ scale[2] = vector_norm(row[2])
999
+ row[2] /= scale[2]
1000
+ shear[1] /= scale[2]
+ shear[2] /= scale[2]
1001
+
1002
+ if numpy.dot(row[0], numpy.cross(row[1], row[2])) < 0:
1003
+ scale *= -1
1004
+ row *= -1
1005
+
1006
+ angles[1] = math.asin(-row[0, 2])
1007
+ if math.cos(angles[1]):
1008
+ angles[0] = math.atan2(row[1, 2], row[2, 2])
1009
+ angles[2] = math.atan2(row[0, 1], row[0, 0])
1010
+ else:
1011
+ #angles[0] = math.atan2(row[1, 0], row[1, 1])
1012
+ angles[0] = math.atan2(-row[2, 1], row[1, 1])
1013
+ angles[2] = 0.0
1014
+
1015
+ return scale, shear, angles, translate, perspective
1016
+
1017
+
1018
+ def compose_matrix(scale=None, shear=None, angles=None, translate=None,
1019
+ perspective=None):
1020
+ """Return transformation matrix from sequence of transformations.
1021
+
1022
+ This is the inverse of the decompose_matrix function.
1023
+
1024
+ Sequence of transformations:
1025
+ scale : vector of 3 scaling factors
1026
+ shear : list of shear factors for x-y, x-z, y-z axes
1027
+ angles : list of Euler angles about static x, y, z axes
1028
+ translate : translation vector along x, y, z axes
1029
+ perspective : perspective partition of matrix
1030
+
1031
+ >>> scale = numpy.random.random(3) - 0.5
1032
+ >>> shear = numpy.random.random(3) - 0.5
1033
+ >>> angles = (numpy.random.random(3) - 0.5) * (2*math.pi)
1034
+ >>> trans = numpy.random.random(3) - 0.5
1035
+ >>> persp = numpy.random.random(4) - 0.5
1036
+ >>> M0 = compose_matrix(scale, shear, angles, trans, persp)
1037
+ >>> result = decompose_matrix(M0)
1038
+ >>> M1 = compose_matrix(*result)
1039
+ >>> is_same_transform(M0, M1)
1040
+ True
1041
+
1042
+ """
1043
+ M = numpy.identity(4)
1044
+ if perspective is not None:
1045
+ P = numpy.identity(4)
1046
+ P[3, :] = perspective[:4]
1047
+ M = numpy.dot(M, P)
1048
+ if translate is not None:
1049
+ T = numpy.identity(4)
1050
+ T[:3, 3] = translate[:3]
1051
+ M = numpy.dot(M, T)
1052
+ if angles is not None:
1053
+ R = euler_matrix(angles[0], angles[1], angles[2], 'sxyz')
1054
+ M = numpy.dot(M, R)
1055
+ if shear is not None:
1056
+ Z = numpy.identity(4)
1057
+ Z[1, 2] = shear[2]
1058
+ Z[0, 2] = shear[1]
1059
+ Z[0, 1] = shear[0]
1060
+ M = numpy.dot(M, Z)
1061
+ if scale is not None:
1062
+ S = numpy.identity(4)
1063
+ S[0, 0] = scale[0]
1064
+ S[1, 1] = scale[1]
1065
+ S[2, 2] = scale[2]
1066
+ M = numpy.dot(M, S)
1067
+ M /= M[3, 3]
1068
+ return M
1069
+
1070
+
1071
+ def orthogonalization_matrix(lengths, angles):
1072
+ """Return orthogonalization matrix for crystallographic cell coordinates.
1073
+
1074
+ Angles are expected in degrees.
1075
+
1076
+ The de-orthogonalization matrix is the inverse.
1077
+
1078
+ >>> O = orthogonalization_matrix((10., 10., 10.), (90., 90., 90.))
1079
+ >>> numpy.allclose(O[:3, :3], numpy.identity(3, float) * 10)
1080
+ True
1081
+ >>> O = orthogonalization_matrix([9.8, 12.0, 15.5], [87.2, 80.7, 69.7])
1082
+ >>> numpy.allclose(numpy.sum(O), 43.063229)
1083
+ True
1084
+
1085
+ """
1086
+ a, b, c = lengths
1087
+ angles = numpy.radians(angles)
1088
+ sina, sinb, _ = numpy.sin(angles)
1089
+ cosa, cosb, cosg = numpy.cos(angles)
1090
+ co = (cosa * cosb - cosg) / (sina * sinb)
1091
+ return numpy.array((
1092
+ (a*sinb*math.sqrt(1.0-co*co), 0.0, 0.0, 0.0),
1093
+ (-a*sinb*co, b*sina, 0.0, 0.0),
1094
+ (a*cosb, b*cosa, c, 0.0),
1095
+ (0.0, 0.0, 0.0, 1.0)),
1096
+ dtype=numpy.float64)
1097
+
1098
+
1099
+ def superimposition_matrix(v0, v1, scaling=False, usesvd=True):
1100
+ """Return matrix to transform given vector set into second vector set.
1101
+
1102
+ v0 and v1 are shape (3, \*) or (4, \*) arrays of at least 3 vectors.
1103
+
1104
+ If usesvd is True, the weighted sum of squared deviations (RMSD) is
1105
+ minimized according to the algorithm by W. Kabsch [8]. Otherwise the
1106
+ quaternion based algorithm by B. Horn [9] is used (slower when using
1107
+ this Python implementation).
1108
+
1109
+ The returned matrix performs rotation, translation and uniform scaling
1110
+ (if specified).
1111
+
1112
+ >>> v0 = numpy.random.rand(3, 10)
1113
+ >>> M = superimposition_matrix(v0, v0)
1114
+ >>> numpy.allclose(M, numpy.identity(4))
1115
+ True
1116
+ >>> R = random_rotation_matrix(numpy.random.random(3))
1117
+ >>> v0 = ((1,0,0), (0,1,0), (0,0,1), (1,1,1))
1118
+ >>> v1 = numpy.dot(R, v0)
1119
+ >>> M = superimposition_matrix(v0, v1)
1120
+ >>> numpy.allclose(v1, numpy.dot(M, v0))
1121
+ True
1122
+ >>> v0 = (numpy.random.rand(4, 100) - 0.5) * 20.0
1123
+ >>> v0[3] = 1.0
1124
+ >>> v1 = numpy.dot(R, v0)
1125
+ >>> M = superimposition_matrix(v0, v1)
1126
+ >>> numpy.allclose(v1, numpy.dot(M, v0))
1127
+ True
1128
+ >>> S = scale_matrix(random.random())
1129
+ >>> T = translation_matrix(numpy.random.random(3)-0.5)
1130
+ >>> M = concatenate_matrices(T, R, S)
1131
+ >>> v1 = numpy.dot(M, v0)
1132
+ >>> v0[:3] += numpy.random.normal(0.0, 1e-9, 300).reshape(3, -1)
1133
+ >>> M = superimposition_matrix(v0, v1, scaling=True)
1134
+ >>> numpy.allclose(v1, numpy.dot(M, v0))
1135
+ True
1136
+ >>> M = superimposition_matrix(v0, v1, scaling=True, usesvd=False)
1137
+ >>> numpy.allclose(v1, numpy.dot(M, v0))
1138
+ True
1139
+ >>> v = numpy.empty((4, 100, 3), dtype=numpy.float64)
1140
+ >>> v[:, :, 0] = v0
1141
+ >>> M = superimposition_matrix(v0, v1, scaling=True, usesvd=False)
1142
+ >>> numpy.allclose(v1, numpy.dot(M, v[:, :, 0]))
1143
+ True
1144
+
1145
+ """
1146
+ v0 = numpy.array(v0, dtype=numpy.float64, copy=False)[:3]
1147
+ v1 = numpy.array(v1, dtype=numpy.float64, copy=False)[:3]
1148
+
1149
+ if v0.shape != v1.shape or v0.shape[1] < 3:
1150
+ raise ValueError("Vector sets are of wrong shape or type.")
1151
+
1152
+ # move centroids to origin
1153
+ t0 = numpy.mean(v0, axis=1)
1154
+ t1 = numpy.mean(v1, axis=1)
1155
+ v0 = v0 - t0.reshape(3, 1)
1156
+ v1 = v1 - t1.reshape(3, 1)
1157
+
1158
+ if usesvd:
1159
+ # Singular Value Decomposition of covariance matrix
1160
+ u, s, vh = numpy.linalg.svd(numpy.dot(v1, v0.T))
1161
+ # rotation matrix from SVD orthonormal bases
1162
+ R = numpy.dot(u, vh)
1163
+ if numpy.linalg.det(R) < 0.0:
1164
+ # R does not constitute right handed system
1165
+ R -= numpy.outer(u[:, 2], vh[2, :]*2.0)
1166
+ s[-1] *= -1.0
1167
+ # homogeneous transformation matrix
1168
+ M = numpy.identity(4)
1169
+ M[:3, :3] = R
1170
+ else:
1171
+ # compute symmetric matrix N
1172
+ xx, yy, zz = numpy.sum(v0 * v1, axis=1)
1173
+ xy, yz, zx = numpy.sum(v0 * numpy.roll(v1, -1, axis=0), axis=1)
1174
+ xz, yx, zy = numpy.sum(v0 * numpy.roll(v1, -2, axis=0), axis=1)
1175
+ N = ((xx+yy+zz, yz-zy, zx-xz, xy-yx),
1176
+ (yz-zy, xx-yy-zz, xy+yx, zx+xz),
1177
+ (zx-xz, xy+yx, -xx+yy-zz, yz+zy),
1178
+ (xy-yx, zx+xz, yz+zy, -xx-yy+zz))
1179
+ # quaternion: eigenvector corresponding to most positive eigenvalue
1180
+ l, V = numpy.linalg.eig(N)
1181
+ q = V[:, numpy.argmax(l)]
1182
+ q /= vector_norm(q) # unit quaternion
1183
+ q = numpy.roll(q, -1) # move w component to end
1184
+ # homogeneous transformation matrix
1185
+ M = quaternion_matrix(q)
1186
+
1187
+ # scale: ratio of rms deviations from centroid
1188
+ if scaling:
1189
+ v0 *= v0
1190
+ v1 *= v1
1191
+ M[:3, :3] *= math.sqrt(numpy.sum(v1) / numpy.sum(v0))
1192
+
1193
+ # translation
1194
+ M[:3, 3] = t1
1195
+ T = numpy.identity(4)
1196
+ T[:3, 3] = -t0
1197
+ M = numpy.dot(M, T)
1198
+ return M
1199
+
1200
+
1201
+ def euler_matrix(ai, aj, ak, axes='sxyz'):
1202
+ """Return homogeneous rotation matrix from Euler angles and axis sequence.
1203
+
1204
+ ai, aj, ak : Euler's roll, pitch and yaw angles
1205
+ axes : One of 24 axis sequences as string or encoded tuple
1206
+
1207
+ >>> R = euler_matrix(1, 2, 3, 'syxz')
1208
+ >>> numpy.allclose(numpy.sum(R[0]), -1.34786452)
1209
+ True
1210
+ >>> R = euler_matrix(1, 2, 3, (0, 1, 0, 1))
1211
+ >>> numpy.allclose(numpy.sum(R[0]), -0.383436184)
1212
+ True
1213
+ >>> ai, aj, ak = (4.0*math.pi) * (numpy.random.random(3) - 0.5)
1214
+ >>> for axes in _AXES2TUPLE.keys():
1215
+ ... R = euler_matrix(ai, aj, ak, axes)
1216
+ >>> for axes in _TUPLE2AXES.keys():
1217
+ ... R = euler_matrix(ai, aj, ak, axes)
1218
+
1219
+ """
1220
+ try:
1221
+ firstaxis, parity, repetition, frame = _AXES2TUPLE[axes]
1222
+ except (AttributeError, KeyError):
1223
+ _ = _TUPLE2AXES[axes]
1224
+ firstaxis, parity, repetition, frame = axes
1225
+
1226
+ i = firstaxis
1227
+ j = _NEXT_AXIS[i+parity]
1228
+ k = _NEXT_AXIS[i-parity+1]
1229
+
1230
+ if frame:
1231
+ ai, ak = ak, ai
1232
+ if parity:
1233
+ ai, aj, ak = -ai, -aj, -ak
1234
+
1235
+ si, sj, sk = math.sin(ai), math.sin(aj), math.sin(ak)
1236
+ ci, cj, ck = math.cos(ai), math.cos(aj), math.cos(ak)
1237
+ cc, cs = ci*ck, ci*sk
1238
+ sc, ss = si*ck, si*sk
1239
+
1240
+ M = numpy.identity(4)
1241
+ if repetition:
1242
+ M[i, i] = cj
1243
+ M[i, j] = sj*si
1244
+ M[i, k] = sj*ci
1245
+ M[j, i] = sj*sk
1246
+ M[j, j] = -cj*ss+cc
1247
+ M[j, k] = -cj*cs-sc
1248
+ M[k, i] = -sj*ck
1249
+ M[k, j] = cj*sc+cs
1250
+ M[k, k] = cj*cc-ss
1251
+ else:
1252
+ M[i, i] = cj*ck
1253
+ M[i, j] = sj*sc-cs
1254
+ M[i, k] = sj*cc+ss
1255
+ M[j, i] = cj*sk
1256
+ M[j, j] = sj*ss+cc
1257
+ M[j, k] = sj*cs-sc
1258
+ M[k, i] = -sj
1259
+ M[k, j] = cj*si
1260
+ M[k, k] = cj*ci
1261
+ return M
1262
+
1263
+
1264
+ def euler_from_matrix(matrix, axes='sxyz'):
1265
+ """Return Euler angles from rotation matrix for specified axis sequence.
1266
+
1267
+ axes : One of 24 axis sequences as string or encoded tuple
1268
+
1269
+ Note that many Euler angle triplets can describe one matrix.
1270
+
1271
+ >>> R0 = euler_matrix(1, 2, 3, 'syxz')
1272
+ >>> al, be, ga = euler_from_matrix(R0, 'syxz')
1273
+ >>> R1 = euler_matrix(al, be, ga, 'syxz')
1274
+ >>> numpy.allclose(R0, R1)
1275
+ True
1276
+ >>> angles = (4.0*math.pi) * (numpy.random.random(3) - 0.5)
1277
+ >>> for axes in _AXES2TUPLE.keys():
1278
+ ... R0 = euler_matrix(axes=axes, *angles)
1279
+ ... R1 = euler_matrix(axes=axes, *euler_from_matrix(R0, axes))
1280
+ ... if not numpy.allclose(R0, R1): print(axes, "failed")
1281
+
1282
+ """
1283
+ try:
1284
+ firstaxis, parity, repetition, frame = _AXES2TUPLE[axes.lower()]
1285
+ except (AttributeError, KeyError):
1286
+ _ = _TUPLE2AXES[axes]
1287
+ firstaxis, parity, repetition, frame = axes
1288
+
1289
+ i = firstaxis
1290
+ j = _NEXT_AXIS[i+parity]
1291
+ k = _NEXT_AXIS[i-parity+1]
1292
+
1293
+ M = numpy.array(matrix, dtype=numpy.float64, copy=False)[:3, :3]
1294
+ if repetition:
1295
+ sy = math.sqrt(M[i, j]*M[i, j] + M[i, k]*M[i, k])
1296
+ if sy > _EPS:
1297
+ ax = math.atan2(M[i, j], M[i, k])
1298
+ ay = math.atan2(sy, M[i, i])
1299
+ az = math.atan2(M[j, i], -M[k, i])
1300
+ else:
1301
+ ax = math.atan2(-M[j, k], M[j, j])
1302
+ ay = math.atan2(sy, M[i, i])
1303
+ az = 0.0
1304
+ else:
1305
+ cy = math.sqrt(M[i, i]*M[i, i] + M[j, i]*M[j, i])
1306
+ if cy > _EPS:
1307
+ ax = math.atan2(M[k, j], M[k, k])
1308
+ ay = math.atan2(-M[k, i], cy)
1309
+ az = math.atan2(M[j, i], M[i, i])
1310
+ else:
1311
+ ax = math.atan2(-M[j, k], M[j, j])
1312
+ ay = math.atan2(-M[k, i], cy)
1313
+ az = 0.0
1314
+
1315
+ if parity:
1316
+ ax, ay, az = -ax, -ay, -az
1317
+ if frame:
1318
+ ax, az = az, ax
1319
+ return ax, ay, az
1320
+
1321
+
1322
+ def euler_from_quaternion(quaternion, axes='sxyz'):
1323
+ """Return Euler angles from quaternion for specified axis sequence.
1324
+
1325
+ >>> angles = euler_from_quaternion([0.06146124, 0, 0, 0.99810947])
1326
+ >>> numpy.allclose(angles, [0.123, 0, 0])
1327
+ True
1328
+
1329
+ """
1330
+ return euler_from_matrix(quaternion_matrix(quaternion), axes)
1331
+
1332
+
1333
+ def quaternion_from_euler(ai, aj, ak, axes='sxyz'):
1334
+ """Return quaternion from Euler angles and axis sequence.
1335
+
1336
+ ai, aj, ak : Euler's roll, pitch and yaw angles
1337
+ axes : One of 24 axis sequences as string or encoded tuple
1338
+
1339
+ >>> q = quaternion_from_euler(1, 2, 3, 'ryxz')
1340
+ >>> numpy.allclose(q, [0.310622, -0.718287, 0.444435, 0.435953])
1341
+ True
1342
+
1343
+ """
1344
+ try:
1345
+ firstaxis, parity, repetition, frame = _AXES2TUPLE[axes.lower()]
1346
+ except (AttributeError, KeyError):
1347
+ _ = _TUPLE2AXES[axes]
1348
+ firstaxis, parity, repetition, frame = axes
1349
+
1350
+ i = firstaxis
1351
+ j = _NEXT_AXIS[i+parity]
1352
+ k = _NEXT_AXIS[i-parity+1]
1353
+
1354
+ if frame:
1355
+ ai, ak = ak, ai
1356
+ if parity:
1357
+ aj = -aj
1358
+
1359
+ ai /= 2.0
1360
+ aj /= 2.0
1361
+ ak /= 2.0
1362
+ ci = math.cos(ai)
1363
+ si = math.sin(ai)
1364
+ cj = math.cos(aj)
1365
+ sj = math.sin(aj)
1366
+ ck = math.cos(ak)
1367
+ sk = math.sin(ak)
1368
+ cc = ci*ck
1369
+ cs = ci*sk
1370
+ sc = si*ck
1371
+ ss = si*sk
1372
+
1373
+ quaternion = numpy.empty((4, ), dtype=numpy.float64)
1374
+ if repetition:
1375
+ quaternion[i] = cj*(cs + sc)
1376
+ quaternion[j] = sj*(cc + ss)
1377
+ quaternion[k] = sj*(cs - sc)
1378
+ quaternion[3] = cj*(cc - ss)
1379
+ else:
1380
+ quaternion[i] = cj*sc - sj*cs
1381
+ quaternion[j] = cj*ss + sj*cc
1382
+ quaternion[k] = cj*cs - sj*sc
1383
+ quaternion[3] = cj*cc + sj*ss
1384
+ if parity:
1385
+ quaternion[j] *= -1
1386
+
1387
+ return quaternion
1388
+
1389
+
1390
+ def quaternion_about_axis(angle, axis):
1391
+ """Return quaternion for rotation about axis.
1392
+
1393
+ >>> q = quaternion_about_axis(0.123, (1, 0, 0))
1394
+ >>> numpy.allclose(q, [0.06146124, 0, 0, 0.99810947])
1395
+ True
1396
+
1397
+ """
1398
+ quaternion = numpy.zeros((4, ), dtype=numpy.float64)
1399
+ quaternion[:3] = axis[:3]
1400
+ qlen = vector_norm(quaternion)
1401
+ if qlen > _EPS:
1402
+ quaternion *= math.sin(angle/2.0) / qlen
1403
+ quaternion[3] = math.cos(angle/2.0)
1404
+ return quaternion
1405
+
1406
+
1407
+ def matrix_from_quaternion(quaternion):
1408
+ return quaternion_matrix(quaternion)
1409
+
1410
+
1411
+ def quaternion_matrix(quaternion):
1412
+ """Return homogeneous rotation matrix from quaternion.
1413
+
1414
+ >>> R = quaternion_matrix([0.06146124, 0, 0, 0.99810947])
1415
+ >>> numpy.allclose(R, rotation_matrix(0.123, (1, 0, 0)))
1416
+ True
1417
+
1418
+ """
1419
+ q = numpy.array(quaternion[:4], dtype=numpy.float64, copy=True)
1420
+ nq = numpy.dot(q, q)
1421
+ if nq < _EPS:
1422
+ return numpy.identity(4)
1423
+ q *= math.sqrt(2.0 / nq)
1424
+ q = numpy.outer(q, q)
1425
+ return numpy.array((
1426
+ (1.0-q[1, 1]-q[2, 2], q[0, 1]-q[2, 3], q[0, 2]+q[1, 3], 0.0),
1427
+ (q[0, 1]+q[2, 3], 1.0-q[0, 0]-q[2, 2], q[1, 2]-q[0, 3], 0.0),
1428
+ (q[0, 2]-q[1, 3], q[1, 2]+q[0, 3], 1.0-q[0, 0]-q[1, 1], 0.0),
1429
+ (0.0, 0.0, 0.0, 1.0)
1430
+ ), dtype=numpy.float64)
1431
+
1432
+
1433
+ def quaternionJPL_matrix(quaternion):
1434
+ """Return homogeneous rotation matrix from quaternion in JPL notation.
1435
+ quaternion = [x y z w]
1436
+ """
1437
+ q0 = quaternion[0]
1438
+ q1 = quaternion[1]
1439
+ q2 = quaternion[2]
1440
+ q3 = quaternion[3]
1441
+ return numpy.array([
1442
+ [q0**2 - q1**2 - q2**2 + q3**2, 2.0*q0*q1 +
1443
+ 2.0*q2*q3, 2.0*q0*q2 - 2.0*q1*q3, 0],
1444
+ [2.0*q0*q1 - 2.0*q2*q3, - q0**2 + q1**2 -
1445
+ q2**2 + q3**2, 2.0*q0*q3 + 2.0*q1*q2, 0],
1446
+ [2.0*q0*q2 + 2.0*q1*q3, 2.0*q1*q2 - 2.0*q0 *
1447
+ q3, - q0**2 - q1**2 + q2**2 + q3**2, 0],
1448
+ [0, 0, 0, 1.0]], dtype=numpy.float64)
1449
+
1450
+
1451
+ def quaternion_from_matrix(matrix):
1452
+ """Return quaternion from rotation matrix.
1453
+
1454
+ >>> R = rotation_matrix(0.123, (1, 2, 3))
1455
+ >>> q = quaternion_from_matrix(R)
1456
+ >>> numpy.allclose(q, [0.0164262, 0.0328524, 0.0492786, 0.9981095])
1457
+ True
1458
+
1459
+ """
1460
+ q = numpy.empty((4, ), dtype=numpy.float64)
1461
+ M = numpy.array(matrix, dtype=numpy.float64, copy=False)[:4, :4]
1462
+ t = numpy.trace(M)
1463
+ if t > M[3, 3]:
1464
+ q[3] = t
1465
+ q[2] = M[1, 0] - M[0, 1]
1466
+ q[1] = M[0, 2] - M[2, 0]
1467
+ q[0] = M[2, 1] - M[1, 2]
1468
+ else:
1469
+ i, j, k = 0, 1, 2
1470
+ if M[1, 1] > M[0, 0]:
1471
+ i, j, k = 1, 2, 0
1472
+ if M[2, 2] > M[i, i]:
1473
+ i, j, k = 2, 0, 1
1474
+ t = M[i, i] - (M[j, j] + M[k, k]) + M[3, 3]
1475
+ q[i] = t
1476
+ q[j] = M[i, j] + M[j, i]
1477
+ q[k] = M[k, i] + M[i, k]
1478
+ q[3] = M[k, j] - M[j, k]
1479
+ q *= 0.5 / math.sqrt(t * M[3, 3])
1480
+ return q
1481
+
1482
+
1483
+ def quaternion_multiply(quaternion1, quaternion0):
1484
+ """Return multiplication of two quaternions.
1485
+
1486
+ >>> q = quaternion_multiply([1, -2, 3, 4], [-5, 6, 7, 8])
1487
+ >>> numpy.allclose(q, [-44, -14, 48, 28])
1488
+ True
1489
+
1490
+ """
1491
+ x0, y0, z0, w0 = quaternion0
1492
+ x1, y1, z1, w1 = quaternion1
1493
+ return numpy.array((
1494
+ x1*w0 + y1*z0 - z1*y0 + w1*x0,
1495
+ -x1*z0 + y1*w0 + z1*x0 + w1*y0,
1496
+ x1*y0 - y1*x0 + z1*w0 + w1*z0,
1497
+ -x1*x0 - y1*y0 - z1*z0 + w1*w0), dtype=numpy.float64)
1498
+
1499
+
1500
+ def quaternion_conjugate(quaternion):
1501
+ """Return conjugate of quaternion.
1502
+
1503
+ >>> q0 = random_quaternion()
1504
+ >>> q1 = quaternion_conjugate(q0)
1505
+ >>> q1[3] == q0[3] and all(q1[:3] == -q0[:3])
1506
+ True
1507
+
1508
+ """
1509
+ return numpy.array((-quaternion[0], -quaternion[1],
1510
+ -quaternion[2], quaternion[3]), dtype=numpy.float64)
1511
+
1512
+
1513
+ def quaternion_inverse(quaternion):
1514
+ """Return inverse of quaternion.
1515
+
1516
+ >>> q0 = random_quaternion()
1517
+ >>> q1 = quaternion_inverse(q0)
1518
+ >>> numpy.allclose(quaternion_multiply(q0, q1), [0, 0, 0, 1])
1519
+ True
1520
+
1521
+ """
1522
+ return quaternion_conjugate(quaternion) / numpy.dot(quaternion, quaternion)
1523
+
1524
+
1525
+ def quaternion_slerp(quat0, quat1, fraction, spin=0, shortestpath=True):
1526
+ """Return spherical linear interpolation between two quaternions.
1527
+
1528
+ >>> q0 = random_quaternion()
1529
+ >>> q1 = random_quaternion()
1530
+ >>> q = quaternion_slerp(q0, q1, 0.0)
1531
+ >>> numpy.allclose(q, q0)
1532
+ True
1533
+ >>> q = quaternion_slerp(q0, q1, 1.0, 1)
1534
+ >>> numpy.allclose(q, q1)
1535
+ True
1536
+ >>> q = quaternion_slerp(q0, q1, 0.5)
1537
+ >>> angle = math.acos(numpy.dot(q0, q))
1538
+ >>> numpy.allclose(2.0, math.acos(numpy.dot(q0, q1)) / angle) or \
1539
+ numpy.allclose(2.0, math.acos(-numpy.dot(q0, q1)) / angle)
1540
+ True
1541
+
1542
+ """
1543
+ q0 = unit_vector(quat0[:4])
1544
+ q1 = unit_vector(quat1[:4])
1545
+ if fraction == 0.0:
1546
+ return q0
1547
+ elif fraction == 1.0:
1548
+ return q1
1549
+ d = numpy.dot(q0, q1)
1550
+ if abs(abs(d) - 1.0) < _EPS:
1551
+ return q0
1552
+ if shortestpath and d < 0.0:
1553
+ # invert rotation
1554
+ d = -d
1555
+ q1 *= -1.0
1556
+ angle = math.acos(d) + spin * math.pi
1557
+ if abs(angle) < _EPS:
1558
+ return q0
1559
+ isin = 1.0 / math.sin(angle)
1560
+ q0 *= math.sin((1.0 - fraction) * angle) * isin
1561
+ q1 *= math.sin(fraction * angle) * isin
1562
+ q0 += q1
1563
+ return q0
1564
+
1565
+
1566
+ def random_quaternion(rand=None):
1567
+ """Return uniform random unit quaternion.
1568
+
1569
+ rand: array like or None
1570
+ Three independent random variables that are uniformly distributed
1571
+ between 0 and 1.
1572
+
1573
+ >>> q = random_quaternion()
1574
+ >>> numpy.allclose(1.0, vector_norm(q))
1575
+ True
1576
+ >>> q = random_quaternion(numpy.random.random(3))
1577
+ >>> q.shape
1578
+ (4,)
1579
+
1580
+ """
1581
+ if rand is None:
1582
+ rand = numpy.random.rand(3)
1583
+ else:
1584
+ assert len(rand) == 3
1585
+ r1 = numpy.sqrt(1.0 - rand[0])
1586
+ r2 = numpy.sqrt(rand[0])
1587
+ pi2 = math.pi * 2.0
1588
+ t1 = pi2 * rand[1]
1589
+ t2 = pi2 * rand[2]
1590
+ return numpy.array((numpy.sin(t1)*r1,
1591
+ numpy.cos(t1)*r1,
1592
+ numpy.sin(t2)*r2,
1593
+ numpy.cos(t2)*r2), dtype=numpy.float64)
1594
+
1595
+
1596
+ def random_rotation_matrix(rand=None):
1597
+ """Return uniform random rotation matrix.
1598
+
1599
+ rand: array like or None
1600
+ Three independent random variables that are uniformly distributed
1601
+ between 0 and 1 for each returned quaternion.
1602
+
1603
+ >>> R = random_rotation_matrix()
1604
+ >>> numpy.allclose(numpy.dot(R.T, R), numpy.identity(4))
1605
+ True
1606
+
1607
+ """
1608
+ return quaternion_matrix(random_quaternion(rand))
1609
+
1610
+
1611
+ def random_direction_3d():
1612
+ """ equal-area projection according to:
1613
+ https://math.stackexchange.com/questions/44689/how-to-find-a-random-axis-or-unit-vector-in-3d
1614
+ cfo, 2015/10/16
1615
+ """
1616
+ z = numpy.random.rand() * 2.0 - 1.0
1617
+ t = numpy.random.rand() * 2.0 * numpy.pi
1618
+ r = numpy.sqrt(1.0 - z*z)
1619
+ x = r * numpy.cos(t)
1620
+ y = r * numpy.sin(t)
1621
+ return numpy.array([x, y, z], dtype=numpy.float64)
1622
+
1623
+
1624
+ class Arcball(object):
1625
+ """Virtual Trackball Control.
1626
+
1627
+ >>> ball = Arcball()
1628
+ >>> ball = Arcball(initial=numpy.identity(4))
1629
+ >>> ball.place([320, 320], 320)
1630
+ >>> ball.down([500, 250])
1631
+ >>> ball.drag([475, 275])
1632
+ >>> R = ball.matrix()
1633
+ >>> numpy.allclose(numpy.sum(R), 3.90583455)
1634
+ True
1635
+ >>> ball = Arcball(initial=[0, 0, 0, 1])
1636
+ >>> ball.place([320, 320], 320)
1637
+ >>> ball.setaxes([1,1,0], [-1, 1, 0])
1638
+ >>> ball.setconstrain(True)
1639
+ >>> ball.down([400, 200])
1640
+ >>> ball.drag([200, 400])
1641
+ >>> R = ball.matrix()
1642
+ >>> numpy.allclose(numpy.sum(R), 0.2055924)
1643
+ True
1644
+ >>> ball.next()
1645
+
1646
+ """
1647
+
1648
+ def __init__(self, initial=None):
1649
+ """Initialize virtual trackball control.
1650
+
1651
+ initial : quaternion or rotation matrix
1652
+
1653
+ """
1654
+ self._axis = None
1655
+ self._axes = None
1656
+ self._radius = 1.0
1657
+ self._center = [0.0, 0.0]
1658
+ self._vdown = numpy.array([0, 0, 1], dtype=numpy.float64)
1659
+ self._constrain = False
1660
+
1661
+ if initial is None:
1662
+ self._qdown = numpy.array([0, 0, 0, 1], dtype=numpy.float64)
1663
+ else:
1664
+ initial = numpy.array(initial, dtype=numpy.float64)
1665
+ if initial.shape == (4, 4):
1666
+ self._qdown = quaternion_from_matrix(initial)
1667
+ elif initial.shape == (4, ):
1668
+ initial /= vector_norm(initial)
1669
+ self._qdown = initial
1670
+ else:
1671
+ raise ValueError("initial not a quaternion or matrix.")
1672
+
1673
+ self._qnow = self._qpre = self._qdown
1674
+
1675
+ def place(self, center, radius):
1676
+ """Place Arcball, e.g. when window size changes.
1677
+
1678
+ center : sequence[2]
1679
+ Window coordinates of trackball center.
1680
+ radius : float
1681
+ Radius of trackball in window coordinates.
1682
+
1683
+ """
1684
+ self._radius = float(radius)
1685
+ self._center[0] = center[0]
1686
+ self._center[1] = center[1]
1687
+
1688
+ def setaxes(self, *axes):
1689
+ """Set axes to constrain rotations."""
1690
+ if axes is None:
1691
+ self._axes = None
1692
+ else:
1693
+ self._axes = [unit_vector(axis) for axis in axes]
1694
+
1695
+ def setconstrain(self, constrain):
1696
+ """Set state of constrain to axis mode."""
1697
+ self._constrain = constrain == True
1698
+
1699
+ def getconstrain(self):
1700
+ """Return state of constrain to axis mode."""
1701
+ return self._constrain
1702
+
1703
+ def down(self, point):
1704
+ """Set initial cursor window coordinates and pick constrain-axis."""
1705
+ self._vdown = arcball_map_to_sphere(point, self._center, self._radius)
1706
+ self._qdown = self._qpre = self._qnow
1707
+
1708
+ if self._constrain and self._axes is not None:
1709
+ self._axis = arcball_nearest_axis(self._vdown, self._axes)
1710
+ self._vdown = arcball_constrain_to_axis(self._vdown, self._axis)
1711
+ else:
1712
+ self._axis = None
1713
+
1714
+ def drag(self, point):
1715
+ """Update current cursor window coordinates."""
1716
+ vnow = arcball_map_to_sphere(point, self._center, self._radius)
1717
+
1718
+ if self._axis is not None:
1719
+ vnow = arcball_constrain_to_axis(vnow, self._axis)
1720
+
1721
+ self._qpre = self._qnow
1722
+
1723
+ t = numpy.cross(self._vdown, vnow)
1724
+ if numpy.dot(t, t) < _EPS:
1725
+ self._qnow = self._qdown
1726
+ else:
1727
+ q = [t[0], t[1], t[2], numpy.dot(self._vdown, vnow)]
1728
+ self._qnow = quaternion_multiply(q, self._qdown)
1729
+
1730
+ def next(self, acceleration=0.0):
1731
+ """Continue rotation in direction of last drag."""
1732
+ q = quaternion_slerp(self._qpre, self._qnow, 2.0+acceleration, False)
1733
+ self._qpre, self._qnow = self._qnow, q
1734
+
1735
+ def matrix(self):
1736
+ """Return homogeneous rotation matrix."""
1737
+ return quaternion_matrix(self._qnow)
1738
+
1739
+
1740
+ def arcball_map_to_sphere(point, center, radius):
1741
+ """Return unit sphere coordinates from window coordinates."""
1742
+ v = numpy.array(((point[0] - center[0]) / radius,
1743
+ (center[1] - point[1]) / radius,
1744
+ 0.0), dtype=numpy.float64)
1745
+ n = v[0]*v[0] + v[1]*v[1]
1746
+ if n > 1.0:
1747
+ v /= math.sqrt(n) # position outside of sphere
1748
+ else:
1749
+ v[2] = math.sqrt(1.0 - n)
1750
+ return v
1751
+
1752
+
1753
+ def arcball_constrain_to_axis(point, axis):
1754
+ """Return sphere point perpendicular to axis."""
1755
+ v = numpy.array(point, dtype=numpy.float64, copy=True)
1756
+ a = numpy.array(axis, dtype=numpy.float64, copy=True)
1757
+ v -= a * numpy.dot(a, v) # on plane
1758
+ n = vector_norm(v)
1759
+ if n > _EPS:
1760
+ if v[2] < 0.0:
1761
+ v *= -1.0
1762
+ v /= n
1763
+ return v
1764
+ if a[2] == 1.0:
1765
+ return numpy.array([1, 0, 0], dtype=numpy.float64)
1766
+ return unit_vector([-a[1], a[0], 0])
1767
+
1768
+
1769
+ def arcball_nearest_axis(point, axes):
1770
+ """Return axis, which arc is nearest to point."""
1771
+ point = numpy.array(point, dtype=numpy.float64, copy=False)
1772
+ nearest = None
1773
+ mx = -1.0
1774
+ for axis in axes:
1775
+ t = numpy.dot(arcball_constrain_to_axis(point, axis), point)
1776
+ if t > mx:
1777
+ nearest = axis
1778
+ mx = t
1779
+ return nearest
1780
+
1781
+
1782
+ # epsilon for testing whether a number is close to zero
1783
+ _EPS = numpy.finfo(float).eps * 4.0
1784
+
1785
+ # axis sequences for Euler angles
1786
+ _NEXT_AXIS = [1, 2, 0, 1]
1787
+
1788
+ # map axes strings to/from tuples of inner axis, parity, repetition, frame
1789
+ _AXES2TUPLE = {
1790
+ 'sxyz': (0, 0, 0, 0), 'sxyx': (0, 0, 1, 0), 'sxzy': (0, 1, 0, 0),
1791
+ 'sxzx': (0, 1, 1, 0), 'syzx': (1, 0, 0, 0), 'syzy': (1, 0, 1, 0),
1792
+ 'syxz': (1, 1, 0, 0), 'syxy': (1, 1, 1, 0), 'szxy': (2, 0, 0, 0),
1793
+ 'szxz': (2, 0, 1, 0), 'szyx': (2, 1, 0, 0), 'szyz': (2, 1, 1, 0),
1794
+ 'rzyx': (0, 0, 0, 1), 'rxyx': (0, 0, 1, 1), 'ryzx': (0, 1, 0, 1),
1795
+ 'rxzx': (0, 1, 1, 1), 'rxzy': (1, 0, 0, 1), 'ryzy': (1, 0, 1, 1),
1796
+ 'rzxy': (1, 1, 0, 1), 'ryxy': (1, 1, 1, 1), 'ryxz': (2, 0, 0, 1),
1797
+ 'rzxz': (2, 0, 1, 1), 'rxyz': (2, 1, 0, 1), 'rzyz': (2, 1, 1, 1)}
1798
+
1799
+ _TUPLE2AXES = dict((v, k) for k, v in _AXES2TUPLE.items())
1800
+
1801
+ # helper functions
1802
+
1803
+
1804
+ def vector_norm(data, axis=None, out=None):
1805
+ """Return length, i.e. eucledian norm, of ndarray along axis.
1806
+
1807
+ >>> v = numpy.random.random(3)
1808
+ >>> n = vector_norm(v)
1809
+ >>> numpy.allclose(n, numpy.linalg.norm(v))
1810
+ True
1811
+ >>> v = numpy.random.rand(6, 5, 3)
1812
+ >>> n = vector_norm(v, axis=-1)
1813
+ >>> numpy.allclose(n, numpy.sqrt(numpy.sum(v*v, axis=2)))
1814
+ True
1815
+ >>> n = vector_norm(v, axis=1)
1816
+ >>> numpy.allclose(n, numpy.sqrt(numpy.sum(v*v, axis=1)))
1817
+ True
1818
+ >>> v = numpy.random.rand(5, 4, 3)
1819
+ >>> n = numpy.empty((5, 3), dtype=numpy.float64)
1820
+ >>> vector_norm(v, axis=1, out=n)
1821
+ >>> numpy.allclose(n, numpy.sqrt(numpy.sum(v*v, axis=1)))
1822
+ True
1823
+ >>> vector_norm([])
1824
+ 0.0
1825
+ >>> vector_norm([1.0])
1826
+ 1.0
1827
+
1828
+ """
1829
+ data = numpy.array(data, dtype=numpy.float64, copy=True)
1830
+ if out is None:
1831
+ if data.ndim == 1:
1832
+ return math.sqrt(numpy.dot(data, data))
1833
+ data *= data
1834
+ out = numpy.atleast_1d(numpy.sum(data, axis=axis))
1835
+ numpy.sqrt(out, out)
1836
+ return out
1837
+ else:
1838
+ data *= data
1839
+ numpy.sum(data, axis=axis, out=out)
1840
+ numpy.sqrt(out, out)
1841
+
1842
+
1843
+ def unit_vector(data, axis=None, out=None):
1844
+ """Return ndarray normalized by length, i.e. eucledian norm, along axis.
1845
+
1846
+ >>> v0 = numpy.random.random(3)
1847
+ >>> v1 = unit_vector(v0)
1848
+ >>> numpy.allclose(v1, v0 / numpy.linalg.norm(v0))
1849
+ True
1850
+ >>> v0 = numpy.random.rand(5, 4, 3)
1851
+ >>> v1 = unit_vector(v0, axis=-1)
1852
+ >>> v2 = v0 / numpy.expand_dims(numpy.sqrt(numpy.sum(v0*v0, axis=2)), 2)
1853
+ >>> numpy.allclose(v1, v2)
1854
+ True
1855
+ >>> v1 = unit_vector(v0, axis=1)
1856
+ >>> v2 = v0 / numpy.expand_dims(numpy.sqrt(numpy.sum(v0*v0, axis=1)), 1)
1857
+ >>> numpy.allclose(v1, v2)
1858
+ True
1859
+ >>> v1 = numpy.empty((5, 4, 3), dtype=numpy.float64)
1860
+ >>> unit_vector(v0, axis=1, out=v1)
1861
+ >>> numpy.allclose(v1, v2)
1862
+ True
1863
+ >>> list(unit_vector([]))
1864
+ []
1865
+ >>> list(unit_vector([1.0]))
1866
+ [1.0]
1867
+
1868
+ """
1869
+ if out is None:
1870
+ data = numpy.array(data, dtype=numpy.float64, copy=True)
1871
+ if data.ndim == 1:
1872
+ data /= math.sqrt(numpy.dot(data, data))
1873
+ return data
1874
+ else:
1875
+ if out is not data:
1876
+ out[:] = numpy.array(data, copy=False)
1877
+ data = out
1878
+ length = numpy.atleast_1d(numpy.sum(data*data, axis))
1879
+ numpy.sqrt(length, length)
1880
+ if axis is not None:
1881
+ length = numpy.expand_dims(length, axis)
1882
+ data /= length
1883
+ if out is None:
1884
+ return data
1885
+
1886
+
1887
+ def random_vector(size):
1888
+ """Return array of random doubles in the half-open interval [0.0, 1.0).
1889
+
1890
+ >>> v = random_vector(10000)
1891
+ >>> numpy.all(v >= 0.0) and numpy.all(v < 1.0)
1892
+ True
1893
+ >>> v0 = random_vector(10)
1894
+ >>> v1 = random_vector(10)
1895
+ >>> numpy.any(v0 == v1)
1896
+ False
1897
+
1898
+ """
1899
+ return numpy.random.random(size)
1900
+
1901
+
1902
+ def inverse_matrix(matrix):
1903
+ """Return inverse of square transformation matrix.
1904
+
1905
+ >>> M0 = random_rotation_matrix()
1906
+ >>> M1 = inverse_matrix(M0.T)
1907
+ >>> numpy.allclose(M1, numpy.linalg.inv(M0.T))
1908
+ True
1909
+ >>> for size in range(1, 7):
1910
+ ... M0 = numpy.random.rand(size, size)
1911
+ ... M1 = inverse_matrix(M0)
1912
+ ... if not numpy.allclose(M1, numpy.linalg.inv(M0)): print(size)
1913
+
1914
+ """
1915
+ return numpy.linalg.inv(matrix)
1916
+
1917
+
1918
+ def concatenate_matrices(*matrices):
1919
+ """Return concatenation of series of transformation matrices.
1920
+
1921
+ >>> M = numpy.random.rand(16).reshape((4, 4)) - 0.5
1922
+ >>> numpy.allclose(M, concatenate_matrices(M))
1923
+ True
1924
+ >>> numpy.allclose(numpy.dot(M, M.T), concatenate_matrices(M, M.T))
1925
+ True
1926
+
1927
+ """
1928
+ M = numpy.identity(4)
1929
+ for i in matrices:
1930
+ M = numpy.dot(M, i)
1931
+ return M
1932
+
1933
+
1934
+ def is_same_transform(matrix0, matrix1):
1935
+ """Return True if two matrices perform same transformation.
1936
+
1937
+ >>> is_same_transform(numpy.identity(4), numpy.identity(4))
1938
+ True
1939
+ >>> is_same_transform(numpy.identity(4), random_rotation_matrix())
1940
+ False
1941
+
1942
+ """
1943
+ matrix0 = numpy.array(matrix0, dtype=numpy.float64, copy=True)
1944
+ matrix0 /= matrix0[3, 3]
1945
+ matrix1 = numpy.array(matrix1, dtype=numpy.float64, copy=True)
1946
+ matrix1 /= matrix1[3, 3]
1947
+ return numpy.allclose(matrix0, matrix1)
1948
+
1949
+
1950
+ def _import_module(module_name, warn=True, prefix='_py_', ignore='_'):
1951
+ """Try import all public attributes from module into global namespace.
1952
+
1953
+ Existing attributes with name clashes are renamed with prefix.
1954
+ Attributes starting with underscore are ignored by default.
1955
+
1956
+ Return True on successful import.
1957
+
1958
+ """
1959
+ try:
1960
+ module = __import__(module_name)
1961
+ except ImportError:
1962
+ if warn:
1963
+ warnings.warn("Failed to import module " + module_name)
1964
+ else:
1965
+ for attr in dir(module):
1966
+ if ignore and attr.startswith(ignore):
1967
+ continue
1968
+ if prefix:
1969
+ if attr in globals():
1970
+ globals()[prefix + attr] = globals()[attr]
1971
+ elif warn:
1972
+ warnings.warn("No Python implementation of " + attr)
1973
+ globals()[attr] = getattr(module, attr)
1974
+ return True
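
A minimal round-trip sketch for the Euler/quaternion helpers defined above. The import path assumes the layout used elsewhere in this commit (utils/utils_poses/ATE/transformations.py); adjust it to your checkout.

import numpy
import utils.utils_poses.ATE.transformations as tf

# Build a rotation from static-xyz Euler angles, then recover them.
R0 = tf.euler_matrix(0.1, -0.2, 0.3, 'sxyz')
ai, aj, ak = tf.euler_from_matrix(R0, 'sxyz')
assert numpy.allclose((ai, aj, ak), (0.1, -0.2, 0.3))

# The same rotation via a quaternion in [x, y, z, w] order.
q = tf.quaternion_from_euler(0.1, -0.2, 0.3, 'sxyz')
R1 = tf.quaternion_matrix(q)
assert numpy.allclose(R0, R1)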
utils/utils_poses/align_traj.py ADDED
@@ -0,0 +1,97 @@
1
+ import numpy as np
2
+ import torch
3
+
4
+ from utils.utils_poses.ATE.align_utils import alignTrajectory
5
+ from utils.utils_poses.lie_group_helper import SO3_to_quat, convert3x4_4x4
6
+
7
+
8
+ def pts_dist_max(pts):
9
+ """
10
+ :param pts: (N, 3) torch or np
11
+ :return: scalar
12
+ """
13
+ if torch.is_tensor(pts):
14
+ dist = pts.unsqueeze(0) - pts.unsqueeze(1) # (1, N, 3) - (N, 1, 3) -> (N, N, 3)
15
+ dist = dist[0] # (N, 3)
16
+ dist = dist.norm(dim=1) # (N, )
17
+ max_dist = dist.max()
18
+ else:
19
+ dist = pts[None, :, :] - pts[:, None, :] # (1, N, 3) - (N, 1, 3) -> (N, N, 3)
20
+ dist = dist[0] # (N, 3)
21
+ dist = np.linalg.norm(dist, axis=1) # (N, )
22
+ max_dist = dist.max()
23
+ return max_dist
24
+
25
+
26
+ def align_ate_c2b_use_a2b(traj_a, traj_b, traj_c=None, method='sim3'):
27
+ """Align c to b using the sim3 from a to b.
28
+ :param traj_a: (N0, 3/4, 4) torch tensor
29
+ :param traj_b: (N0, 3/4, 4) torch tensor
30
+ :param traj_c: None or (N1, 3/4, 4) torch tensor
31
+ :return: (N1, 4, 4) torch tensor
32
+ """
33
+ device = traj_a.device
34
+ if traj_c is None:
35
+ traj_c = traj_a.clone()
36
+
37
+ traj_a = traj_a.float().cpu().numpy()
38
+ traj_b = traj_b.float().cpu().numpy()
39
+ traj_c = traj_c.float().cpu().numpy()
40
+
41
+ R_a = traj_a[:, :3, :3] # (N0, 3, 3)
42
+ t_a = traj_a[:, :3, 3] # (N0, 3)
43
+ quat_a = SO3_to_quat(R_a) # (N0, 4)
44
+
45
+ R_b = traj_b[:, :3, :3] # (N0, 3, 3)
46
+ t_b = traj_b[:, :3, 3] # (N0, 3)
47
+ quat_b = SO3_to_quat(R_b) # (N0, 4)
48
+
49
+ # This function works in quaternion.
50
+ # scalar, (3, 3), (3, ) gt = R * s * est + t.
51
+ s, R, t = alignTrajectory(t_a, t_b, quat_a, quat_b, method=method)
52
+
53
+ # reshape tensors
54
+ R = R[None, :, :].astype(np.float32) # (1, 3, 3)
55
+ t = t[None, :, None].astype(np.float32) # (1, 3, 1)
56
+ s = float(s)
57
+
58
+ R_c = traj_c[:, :3, :3] # (N1, 3, 3)
59
+ t_c = traj_c[:, :3, 3:4] # (N1, 3, 1)
60
+
61
+ R_c_aligned = R @ R_c # (N1, 3, 3)
62
+ t_c_aligned = s * (R @ t_c) + t # (N1, 3, 1)
63
+ traj_c_aligned = np.concatenate([R_c_aligned, t_c_aligned], axis=2) # (N1, 3, 4)
64
+
65
+ # append the last row
66
+ traj_c_aligned = convert3x4_4x4(traj_c_aligned) # (N1, 4, 4)
67
+
68
+ traj_c_aligned = torch.from_numpy(traj_c_aligned).to(device)
69
+ return traj_c_aligned # (N1, 4, 4)
70
+
71
+
72
+
73
+ def align_scale_c2b_use_a2b(traj_a, traj_b, traj_c=None):
74
+ '''Scale c to b using the scale from a to b.
75
+ :param traj_a: (N0, 3/4, 4) torch tensor
76
+ :param traj_b: (N0, 3/4, 4) torch tensor
77
+ :param traj_c: None or (N1, 3/4, 4) torch tensor
78
+ :return:
79
+ scaled_traj_c (N1, 4, 4) torch tensor
80
+ scale scalar
81
+ '''
82
+ if traj_c is None:
83
+ traj_c = traj_a.clone()
84
+
85
+ t_a = traj_a[:, :3, 3] # (N, 3)
86
+ t_b = traj_b[:, :3, 3] # (N, 3)
87
+
88
+ # scale estimated poses to colmap scale
89
+ # s_a2b: a*s ~ b
90
+ scale_a2b = pts_dist_max(t_b) / pts_dist_max(t_a)
91
+
92
+ traj_c[:, :3, 3] *= scale_a2b
93
+
94
+ if traj_c.shape[1] == 3:
95
+ traj_c = convert3x4_4x4(traj_c) # (N, 4, 4)
96
+
97
+ return traj_c, scale_a2b # (N, 4, 4)
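
A minimal usage sketch for the Sim(3) alignment above, on synthetic poses (the pose values are placeholders, not data from this repository):

import torch
from utils.utils_poses.align_traj import align_ate_c2b_use_a2b

# Ground-truth and estimated camera-to-world poses, (N, 4, 4).
traj_gt = torch.eye(4).unsqueeze(0).repeat(10, 1, 1)
traj_gt[:, :3, 3] = torch.rand(10, 3)                # spread the camera centres out
traj_est = traj_gt.clone()
traj_est[:, :3, 3] = 2.0 * traj_est[:, :3, 3] + 0.1  # estimate differs by scale and offset

# Fit a sim3 from the estimate to the ground truth and apply it to the estimate.
traj_est_aligned = align_ate_c2b_use_a2b(traj_est, traj_gt, method='sim3')  # (10, 4, 4)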
utils/utils_poses/comp_ate.py ADDED
@@ -0,0 +1,74 @@
1
+
2
+ import numpy as np
3
+
4
+ import utils.utils_poses.ATE.trajectory_utils as tu
5
+ import utils.utils_poses.ATE.transformations as tf
6
+ def rotation_error(pose_error):
7
+ """Compute rotation error
8
+ Args:
9
+ pose_error (4x4 array): relative pose error
10
+ Returns:
11
+ rot_error (float): rotation error
12
+ """
13
+ a = pose_error[0, 0]
14
+ b = pose_error[1, 1]
15
+ c = pose_error[2, 2]
16
+ d = 0.5*(a+b+c-1.0)
17
+ rot_error = np.arccos(max(min(d, 1.0), -1.0))
18
+ return rot_error
19
+
20
+ def translation_error(pose_error):
21
+ """Compute translation error
22
+ Args:
23
+ pose_error (4x4 array): relative pose error
24
+ Returns:
25
+ trans_error (float): translation error
26
+ """
27
+ dx = pose_error[0, 3]
28
+ dy = pose_error[1, 3]
29
+ dz = pose_error[2, 3]
30
+ trans_error = np.sqrt(dx**2+dy**2+dz**2)
31
+ return trans_error
32
+
33
+ def compute_rpe(gt, pred):
34
+ trans_errors = []
35
+ rot_errors = []
36
+ for i in range(len(gt)-1):
37
+ gt1 = gt[i]
38
+ gt2 = gt[i+1]
39
+ gt_rel = np.linalg.inv(gt1) @ gt2
40
+
41
+ pred1 = pred[i]
42
+ pred2 = pred[i+1]
43
+ pred_rel = np.linalg.inv(pred1) @ pred2
44
+ rel_err = np.linalg.inv(gt_rel) @ pred_rel
45
+
46
+ trans_errors.append(translation_error(rel_err))
47
+ rot_errors.append(rotation_error(rel_err))
48
+ rpe_trans = np.mean(np.asarray(trans_errors))
49
+ rpe_rot = np.mean(np.asarray(rot_errors))
50
+ return rpe_trans, rpe_rot
51
+
52
+ def compute_ATE(gt, pred):
53
+ """Compute RMSE of ATE
54
+ Args:
55
+ gt: ground-truth poses
56
+ pred: predicted poses
57
+ """
58
+ errors = []
59
+
60
+ for i in range(len(pred)):
61
+ # cur_gt = np.linalg.inv(gt_0) @ gt[i]
62
+ cur_gt = gt[i]
63
+ gt_xyz = cur_gt[:3, 3]
64
+
65
+ # cur_pred = np.linalg.inv(pred_0) @ pred[i]
66
+ cur_pred = pred[i]
67
+ pred_xyz = cur_pred[:3, 3]
68
+
69
+ align_err = gt_xyz - pred_xyz
70
+
71
+ errors.append(np.sqrt(np.sum(align_err ** 2)))
72
+ ate = np.sqrt(np.mean(np.asarray(errors) ** 2))
73
+ return ate
74
+
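
A minimal sketch of how these metrics are meant to be called, on synthetic 4x4 poses:

import numpy as np
from utils.utils_poses.comp_ate import compute_ATE, compute_rpe

gt = [np.eye(4) for _ in range(5)]
pred = []
for i, g in enumerate(gt):
    p = g.copy()
    p[:3, 3] += 0.01 * i                    # small, growing translation error
    pred.append(p)

ate_rmse = compute_ATE(gt, pred)            # RMSE of per-frame position error
rpe_trans, rpe_rot = compute_rpe(gt, pred)  # mean relative translation / rotation error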
utils/utils_poses/lie_group_helper.py ADDED
@@ -0,0 +1,81 @@
1
+ import numpy as np
2
+ import torch
3
+ from scipy.spatial.transform import Rotation as RotLib
4
+
5
+
6
+ def SO3_to_quat(R):
7
+ """
8
+ :param R: (N, 3, 3) or (3, 3) np
9
+ :return: (N, 4, ) or (4, ) np
10
+ """
11
+ x = RotLib.from_matrix(R)
12
+ quat = x.as_quat()
13
+ return quat
14
+
15
+
16
+ def quat_to_SO3(quat):
17
+ """
18
+ :param quat: (N, 4, ) or (4, ) np
19
+ :return: (N, 3, 3) or (3, 3) np
20
+ """
21
+ x = RotLib.from_quat(quat)
22
+ R = x.as_matrix()
23
+ return R
24
+
25
+
26
+ def convert3x4_4x4(input):
27
+ """
28
+ :param input: (N, 3, 4) or (3, 4) torch or np
29
+ :return: (N, 4, 4) or (4, 4) torch or np
30
+ """
31
+ if torch.is_tensor(input):
32
+ if len(input.shape) == 3:
33
+ output = torch.cat([input, torch.zeros_like(input[:, 0:1])], dim=1) # (N, 4, 4)
34
+ output[:, 3, 3] = 1.0
35
+ else:
36
+ output = torch.cat([input, torch.tensor([[0,0,0,1]], dtype=input.dtype, device=input.device)], dim=0) # (4, 4)
37
+ else:
38
+ if len(input.shape) == 3:
39
+ output = np.concatenate([input, np.zeros_like(input[:, 0:1])], axis=1) # (N, 4, 4)
40
+ output[:, 3, 3] = 1.0
41
+ else:
42
+ output = np.concatenate([input, np.array([[0,0,0,1]], dtype=input.dtype)], axis=0) # (4, 4)
43
+ output[3, 3] = 1.0
44
+ return output
45
+
46
+
47
+ def vec2skew(v):
48
+ """
49
+ :param v: (3, ) torch tensor
50
+ :return: (3, 3)
51
+ """
52
+ zero = torch.zeros(1, dtype=torch.float32, device=v.device)
53
+ skew_v0 = torch.cat([ zero, -v[2:3], v[1:2]]) # (3, 1)
54
+ skew_v1 = torch.cat([ v[2:3], zero, -v[0:1]])
55
+ skew_v2 = torch.cat([-v[1:2], v[0:1], zero])
56
+ skew_v = torch.stack([skew_v0, skew_v1, skew_v2], dim=0) # (3, 3)
57
+ return skew_v # (3, 3)
58
+
59
+
60
+ def Exp(r):
61
+ """so(3) vector to SO(3) matrix
62
+ :param r: (3, ) axis-angle, torch tensor
63
+ :return: (3, 3)
64
+ """
65
+ skew_r = vec2skew(r) # (3, 3)
66
+ norm_r = r.norm() + 1e-15
67
+ eye = torch.eye(3, dtype=torch.float32, device=r.device)
68
+ R = eye + (torch.sin(norm_r) / norm_r) * skew_r + ((1 - torch.cos(norm_r)) / norm_r**2) * (skew_r @ skew_r)
69
+ return R
70
+
71
+
72
+ def make_c2w(r, t):
73
+ """
74
+ :param r: (3, ) axis-angle torch tensor
75
+ :param t: (3, ) translation vector torch tensor
76
+ :return: (4, 4)
77
+ """
78
+ R = Exp(r) # (3, 3)
79
+ c2w = torch.cat([R, t.unsqueeze(1)], dim=1) # (3, 4)
80
+ c2w = convert3x4_4x4(c2w) # (4, 4)
81
+ return c2w
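
A minimal sketch using the exponential map and pose constructor above; the input values are arbitrary:

import torch
from utils.utils_poses.lie_group_helper import Exp, make_c2w

r = torch.tensor([0.0, 0.0, 0.3])   # axis-angle: 0.3 rad about the z axis
t = torch.tensor([1.0, 2.0, 3.0])
c2w = make_c2w(r, t)                # (4, 4) camera-to-world matrix
R = Exp(r)                          # (3, 3) rotation, should be orthonormal
assert torch.allclose(R @ R.T, torch.eye(3), atol=1e-6)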
utils/utils_poses/relative_pose.py ADDED
@@ -0,0 +1,20 @@
 
1
+ import torch
2
+ import numpy as np
3
+
4
+
5
+ def compute_relative_world_to_camera(R1, t1, R2, t2):
6
+ zero_row = torch.tensor([[0, 0, 0, 1]], dtype=torch.float32, device="cuda") #, requires_grad=True
7
+ E1_inv = torch.cat([torch.transpose(R1, 0, 1), -torch.transpose(R1, 0, 1) @ t1.reshape(-1, 1)], dim=1)
8
+ E1_inv = torch.cat([E1_inv, zero_row], dim=0)
9
+ E2 = torch.cat([R2, -R2 @ t2.reshape(-1, 1)], dim=1)
10
+ E2 = torch.cat([E2, zero_row], dim=0)
11
+
12
+ # Compute relative transformation
13
+ E_rel = E2 @ E1_inv
14
+
15
+ # # Extract rotation and translation
16
+ # R_rel = E_rel[:3, :3]
17
+ # t_rel = E_rel[:3, 3]
18
+ # E_rel = torch.cat([E_rel, zero_row], dim=0)
19
+
20
+ return E_rel
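
A minimal calling sketch; note the helper allocates its constant row on "cuda", so the inputs must be CUDA tensors:

import torch
from utils.utils_poses.relative_pose import compute_relative_world_to_camera

if torch.cuda.is_available():
    R1 = torch.eye(3, device="cuda")
    t1 = torch.zeros(3, device="cuda")
    R2 = torch.eye(3, device="cuda")
    t2 = torch.tensor([0.0, 0.0, 1.0], device="cuda")
    E_rel = compute_relative_world_to_camera(R1, t1, R2, t2)  # (4, 4) relative extrinsic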
utils/utils_poses/vis_cam_traj.py ADDED
@@ -0,0 +1,138 @@
1
+ # This file is modified from NeRF++: https://github.com/Kai-46/nerfplusplus
2
+
3
+ import numpy as np
4
+
5
+ try:
6
+ import open3d as o3d
7
+ except ImportError:
8
+ pass
9
+
10
+
11
+ def frustums2lineset(frustums):
12
+ N = len(frustums)
13
+ merged_points = np.zeros((N*5, 3)) # 5 vertices per frustum
14
+ merged_lines = np.zeros((N*8, 2)) # 8 lines per frustum
15
+ merged_colors = np.zeros((N*8, 3)) # each line gets a color
16
+
17
+ for i, (frustum_points, frustum_lines, frustum_colors) in enumerate(frustums):
18
+ merged_points[i*5:(i+1)*5, :] = frustum_points
19
+ merged_lines[i*8:(i+1)*8, :] = frustum_lines + i*5
20
+ merged_colors[i*8:(i+1)*8, :] = frustum_colors
21
+
22
+ lineset = o3d.geometry.LineSet()
23
+ lineset.points = o3d.utility.Vector3dVector(merged_points)
24
+ lineset.lines = o3d.utility.Vector2iVector(merged_lines)
25
+ lineset.colors = o3d.utility.Vector3dVector(merged_colors)
26
+
27
+ return lineset
28
+
29
+
30
+ def get_camera_frustum_opengl_coord(H, W, fx, fy, W2C, frustum_length=0.5, color=np.array([0., 1., 0.])):
31
+ '''X right, Y up, Z backward to the observer.
32
+ :param H, W:
33
+ :param fx, fy:
34
+ :param W2C: (4, 4) matrix
35
+ :param frustum_length: scalar: scale the frustum
36
+ :param color: (3,) list, frustum line color
37
+ :return:
38
+ frustum_points: (5, 3) frustum points in world coordinate
39
+ frustum_lines: (8, 2) 8 lines connect 5 frustum points, specified in line start/end index.
40
+ frustum_colors: (8, 3) colors for 8 lines.
41
+ '''
42
+ hfov = np.rad2deg(np.arctan(W / 2. / fx) * 2.)
43
+ vfov = np.rad2deg(np.arctan(H / 2. / fy) * 2.)
44
+ half_w = frustum_length * np.tan(np.deg2rad(hfov / 2.))
45
+ half_h = frustum_length * np.tan(np.deg2rad(vfov / 2.))
46
+
47
+ # build view frustum in camera space in homogeneous coordinates (5, 4)
48
+ frustum_points = np.array([[0., 0., 0., 1.0], # frustum origin
49
+ [-half_w, half_h, -frustum_length, 1.0], # top-left image corner
50
+ [half_w, half_h, -frustum_length, 1.0], # top-right image corner
51
+ [half_w, -half_h, -frustum_length, 1.0], # bottom-right image corner
52
+ [-half_w, -half_h, -frustum_length, 1.0]]) # bottom-left image corner
53
+ frustum_lines = np.array([[0, i] for i in range(1, 5)] + [[i, (i+1)] for i in range(1, 4)] + [[4, 1]]) # (8, 2)
54
+ frustum_colors = np.tile(color.reshape((1, 3)), (frustum_lines.shape[0], 1)) # (8, 3)
55
+
56
+ # transform view frustum from camera space to world space
57
+ C2W = np.linalg.inv(W2C)
58
+ frustum_points = np.matmul(C2W, frustum_points.T).T # (5, 4)
59
+ frustum_points = frustum_points[:, :3] / frustum_points[:, 3:4] # (5, 3) remove homogeneous coordinate
60
+ return frustum_points, frustum_lines, frustum_colors
61
+
62
+ def get_camera_frustum_opencv_coord(H, W, fx, fy, W2C, frustum_length=0.5, color=np.array([0., 1., 0.])):
63
+ '''X right, Y down, Z forward (OpenCV camera convention).
64
+ :param H, W:
65
+ :param fx, fy:
66
+ :param W2C: (4, 4) matrix
67
+ :param frustum_length: scalar: scale the frustum
68
+ :param color: (3,) list, frustum line color
69
+ :return:
70
+ frustum_points: (5, 3) frustum points in world coordinate
71
+ frustum_lines: (8, 2) 8 lines connect 5 frustum points, specified in line start/end index.
72
+ frustum_colors: (8, 3) colors for 8 lines.
73
+ '''
74
+ hfov = np.rad2deg(np.arctan(W / 2. / fx) * 2.)
75
+ vfov = np.rad2deg(np.arctan(H / 2. / fy) * 2.)
76
+ half_w = frustum_length * np.tan(np.deg2rad(hfov / 2.))
77
+ half_h = frustum_length * np.tan(np.deg2rad(vfov / 2.))
78
+
79
+ # build view frustum in camera space in homogeneous coordinates (5, 4)
80
+ frustum_points = np.array([[0., 0., 0., 1.0], # frustum origin
81
+ [-half_w, -half_h, frustum_length, 1.0], # top-left image corner
82
+ [ half_w, -half_h, frustum_length, 1.0], # top-right image corner
83
+ [ half_w, half_h, frustum_length, 1.0], # bottom-right image corner
84
+ [-half_w, +half_h, frustum_length, 1.0]]) # bottom-left image corner
85
+ frustum_lines = np.array([[0, i] for i in range(1, 5)] + [[i, (i+1)] for i in range(1, 4)] + [[4, 1]]) # (8, 2)
86
+ frustum_colors = np.tile(color.reshape((1, 3)), (frustum_lines.shape[0], 1)) # (8, 3)
87
+
88
+ # transform view frustum from camera space to world space
89
+ C2W = np.linalg.inv(W2C)
90
+ frustum_points = np.matmul(C2W, frustum_points.T).T # (5, 4)
91
+ frustum_points = frustum_points[:, :3] / frustum_points[:, 3:4] # (5, 3) remove homogeneous coordinate
92
+ return frustum_points, frustum_lines, frustum_colors
93
+
94
+
95
+
96
+ def draw_camera_frustum_geometry(c2ws, H, W, fx=600.0, fy=600.0, frustum_length=0.5,
97
+ color=np.array([29.0, 53.0, 87.0])/255.0, draw_now=False, coord='opengl'):
98
+ '''
99
+ :param c2ws: (N, 4, 4) np.array
100
+ :param H: scalar
101
+ :param W: scalar
102
+ :param fx: scalar
103
+ :param fy: scalar
104
+ :param frustum_length: scalar
105
+ :param color: None or (N, 3) or (3, ) or (1, 3) or (3, 1) np array
106
+ :param draw_now: True/False call o3d vis now
107
+ :return:
108
+ '''
109
+ N = c2ws.shape[0]
110
+
111
+ num_ele = color.flatten().shape[0]
112
+ if num_ele == 3:
113
+ color = color.reshape(1, 3)
114
+ color = np.tile(color, (N, 1))
115
+
116
+ frustum_list = []
117
+ if coord == 'opengl':
118
+ for i in range(N):
119
+ frustum_list.append(get_camera_frustum_opengl_coord(H, W, fx, fy,
120
+ W2C=np.linalg.inv(c2ws[i]),
121
+ frustum_length=frustum_length,
122
+ color=color[i]))
123
+ elif coord == 'opencv':
124
+ for i in range(N):
125
+ frustum_list.append(get_camera_frustum_opencv_coord(H, W, fx, fy,
126
+ W2C=np.linalg.inv(c2ws[i]),
127
+ frustum_length=frustum_length,
128
+ color=color[i]))
129
+ else:
130
+ print('Undefined coordinate system. Exit')
131
+ exit()
132
+
133
+ frustums_geometry = frustums2lineset(frustum_list)
134
+
135
+ if draw_now:
136
+ o3d.visualization.draw_geometries([frustums_geometry])
137
+
138
+ return frustums_geometry # this is an o3d geometry object.
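
A minimal sketch that turns a few camera-to-world poses into an Open3D frustum line set (requires open3d; pose values are arbitrary):

import numpy as np
from utils.utils_poses.vis_cam_traj import draw_camera_frustum_geometry

c2ws = np.tile(np.eye(4), (4, 1, 1))
c2ws[:, 0, 3] = np.arange(4) * 0.5   # spread the cameras along x
geom = draw_camera_frustum_geometry(c2ws, H=480, W=640, fx=500.0, fy=500.0,
                                    frustum_length=0.2, coord='opengl')
# o3d.visualization.draw_geometries([geom])  # uncomment to view interactively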
utils/utils_poses/vis_pose_utils.py ADDED
@@ -0,0 +1,270 @@
1
+ import os
2
+ import matplotlib
3
+ matplotlib.use('Agg')
4
+
5
+ from matplotlib import pyplot as plt
6
+ plt.ioff()
7
+
8
+ import copy
9
+ from evo.core.trajectory import PosePath3D, PoseTrajectory3D
10
+ from evo.main_ape import ape
11
+ from evo.tools import plot
12
+ from evo.core import sync
13
+ from evo.tools import file_interface
14
+ from evo.core import metrics
15
+ import evo
16
+ import torch
17
+ import numpy as np
18
+ from scipy.spatial.transform import Slerp
19
+ from scipy.spatial.transform import Rotation as R
20
+ import scipy.interpolate as si
21
+
22
+
23
+ def interp_poses(c2ws, N_views):
24
+ N_inputs = c2ws.shape[0]
25
+ trans = c2ws[:, :3, 3:].permute(2, 1, 0)
26
+ rots = c2ws[:, :3, :3]
27
+ render_poses = []
28
+ rots = R.from_matrix(rots)
29
+ slerp = Slerp(np.linspace(0, 1, N_inputs), rots)
30
+ interp_rots = torch.tensor(
31
+ slerp(np.linspace(0, 1, N_views)).as_matrix().astype(np.float32))
32
+ interp_trans = torch.nn.functional.interpolate(
33
+ trans, size=N_views, mode='linear').permute(2, 1, 0)
34
+ render_poses = torch.cat([interp_rots, interp_trans], dim=2)
35
+ render_poses = convert3x4_4x4(render_poses)
36
+ return render_poses
37
+
38
+
39
+ def interp_poses_bspline(c2ws, N_novel_imgs, input_times, degree):
40
+ target_trans = torch.tensor(scipy_bspline(
41
+ c2ws[:, :3, 3], n=N_novel_imgs, degree=degree, periodic=False).astype(np.float32)).unsqueeze(2)
42
+ rots = R.from_matrix(c2ws[:, :3, :3])
43
+ slerp = Slerp(input_times, rots)
44
+ target_times = np.linspace(input_times[0], input_times[-1], N_novel_imgs)
45
+ target_rots = torch.tensor(
46
+ slerp(target_times).as_matrix().astype(np.float32))
47
+ target_poses = torch.cat([target_rots, target_trans], dim=2)
48
+ target_poses = convert3x4_4x4(target_poses)
49
+ return target_poses
50
+
51
+
52
+ def poses_avg(poses):
53
+
54
+ hwf = poses[0, :3, -1:]
55
+
56
+ center = poses[:, :3, 3].mean(0)
57
+ vec2 = normalize(poses[:, :3, 2].sum(0))
58
+ up = poses[:, :3, 1].sum(0)
59
+ c2w = np.concatenate([viewmatrix(vec2, up, center), hwf], 1)
60
+
61
+ return c2w
62
+
63
+
64
+ def normalize(v):
65
+ """Normalize a vector."""
66
+ return v / np.linalg.norm(v)
67
+
68
+
69
+ def viewmatrix(z, up, pos):
70
+ vec2 = normalize(z)
71
+ vec1_avg = up
72
+ vec0 = normalize(np.cross(vec1_avg, vec2))
73
+ vec1 = normalize(np.cross(vec2, vec0))
74
+ m = np.stack([vec0, vec1, vec2, pos], 1)
75
+ return m
76
+
77
+
78
+ def render_path_spiral(c2w, up, rads, focal, zdelta, zrate, rots, N):
79
+ render_poses = []
80
+ rads = np.array(list(rads) + [1.])
81
+ hwf = c2w[:, 4:5]
82
+
83
+ for theta in np.linspace(0., 2. * np.pi * rots, N+1)[:-1]:
84
+ # c = np.dot(c2w[:3,:4], np.array([0.7*np.cos(theta) , -0.3*np.sin(theta) , -np.sin(theta*zrate) *0.1, 1.]) * rads)
85
+ # c = np.dot(c2w[:3,:4], np.array([0.3*np.cos(theta) , -0.3*np.sin(theta) , -np.sin(theta*zrate) *0.01, 1.]) * rads)
86
+ c = np.dot(c2w[:3, :4], np.array(
87
+ [0.2*np.cos(theta), -0.2*np.sin(theta), -np.sin(theta*zrate) * 0.1, 1.]) * rads)
88
+ z = normalize(c - np.dot(c2w[:3, :4], np.array([0, 0, -focal, 1.])))
89
+ render_poses.append(np.concatenate([viewmatrix(z, up, c), hwf], 1))
90
+ return render_poses
91
+
92
+
93
+ def scipy_bspline(cv, n=100, degree=3, periodic=False):
94
+ """ Calculate n samples on a bspline
95
+
96
+ cv : Array of control vertices
97
+ n : Number of samples to return
98
+ degree: Curve degree
99
+ periodic: True - Curve is closed
100
+ """
101
+ cv = np.asarray(cv)
102
+ count = cv.shape[0]
103
+
104
+ # Closed curve
105
+ if periodic:
106
+ kv = np.arange(-degree, count+degree+1)
107
+ factor, fraction = divmod(count+degree+1, count)
108
+ cv = np.roll(np.concatenate(
109
+ (cv,) * factor + (cv[:fraction],)), -1, axis=0)
110
+ degree = np.clip(degree, 1, degree)
111
+
112
+ # Opened curve
113
+ else:
114
+ degree = np.clip(degree, 1, count-1)
115
+ kv = np.clip(np.arange(count+degree+1)-degree, 0, count-degree)
116
+
117
+ # Return samples
118
+ max_param = count - (degree * (1-periodic))
119
+ spl = si.BSpline(kv, cv, degree)
120
+ return spl(np.linspace(0, max_param, n))
121
+
122
+
123
+ def generate_spiral_nerf(learned_poses, bds, N_novel_views, hwf):
+     learned_poses_ = np.concatenate(
+         (learned_poses[:, :3, :4].detach().cpu().numpy(), hwf[:len(learned_poses)]), axis=-1)
+     c2w = poses_avg(learned_poses_)
+     print('recentered', c2w.shape)
+
+     # Get spiral
+     # Get average up direction
+     up = normalize(learned_poses_[:, :3, 1].sum(0))
+
+     # Find a reasonable "focus depth" for this dataset
+     close_depth, inf_depth = bds.min()*.9, bds.max()*5.
+     dt = .75
+     mean_dz = 1./(((1.-dt)/close_depth + dt/inf_depth))
+     focal = mean_dz
+
+     # Get radii for spiral path
+     shrink_factor = .8  # unused here
+     zdelta = close_depth * .2
+     tt = learned_poses_[:, :3, 3]  # ptstocam(poses[:3,3,:].T, c2w).T
+     rads = np.percentile(np.abs(tt), 90, 0)
+     c2w_path = c2w
+     N_rots = 2
+     c2ws = render_path_spiral(
+         c2w_path, up, rads, focal, zdelta, zrate=.5, rots=N_rots, N=N_novel_views)
+     c2ws = torch.tensor(np.stack(c2ws).astype(np.float32))
+     c2ws = c2ws[:, :3, :4]
+     c2ws = convert3x4_4x4(c2ws)
+     return c2ws
+
+
+ def convert3x4_4x4(input):
+     """
+     :param input: (N, 3, 4) or (3, 4) torch or np
+     :return: (N, 4, 4) or (4, 4) torch or np
+     """
+     if torch.is_tensor(input):
+         if len(input.shape) == 3:
+             output = torch.cat([input, torch.zeros_like(
+                 input[:, 0:1])], dim=1)  # (N, 4, 4)
+             output[:, 3, 3] = 1.0
+         else:
+             output = torch.cat([input, torch.tensor(
+                 [[0, 0, 0, 1]], dtype=input.dtype, device=input.device)], dim=0)  # (4, 4)
+     else:
+         if len(input.shape) == 3:
+             output = np.concatenate(
+                 [input, np.zeros_like(input[:, 0:1])], axis=1)  # (N, 4, 4)
+             output[:, 3, 3] = 1.0
+         else:
+             output = np.concatenate(
+                 [input, np.array([[0, 0, 0, 1]], dtype=input.dtype)], axis=0)  # (4, 4)
+             output[3, 3] = 1.0
+     return output
+
+
+ plt.rc('legend', fontsize=20)  # enlarge legend text in the pose plots
+
+
+ def plot_pose(ref_poses, est_poses, output_path, args, vid=False):
+     ref_poses = [pose for pose in ref_poses]
+     if isinstance(est_poses, dict):
+         est_poses = [pose for k, pose in est_poses.items()]
+     else:
+         est_poses = [pose for pose in est_poses]
+     traj_ref = PosePath3D(poses_se3=ref_poses)
+     traj_est = PosePath3D(poses_se3=est_poses)
+     traj_est_aligned = copy.deepcopy(traj_est)
+     traj_est_aligned.align(traj_ref, correct_scale=True,
+                            correct_only_scale=False)
+     if vid:
+         for p_idx in range(len(ref_poses)):
+             fig = plt.figure()
+             current_est_aligned = traj_est_aligned.poses_se3[:p_idx+1]
+             current_ref = traj_ref.poses_se3[:p_idx+1]
+             current_est_aligned = PosePath3D(poses_se3=current_est_aligned)
+             current_ref = PosePath3D(poses_se3=current_ref)
+             traj_by_label = {
+                 # "estimate (not aligned)": traj_est,
+                 "Ours (aligned)": current_est_aligned,
+                 "Ground-truth": current_ref
+             }
+             plot_mode = plot.PlotMode.xyz
+             # ax = plot.prepare_axis(fig, plot_mode, 111)
+             ax = fig.add_subplot(111, projection="3d")
+             ax.xaxis.set_tick_params(labelbottom=False)
+             ax.yaxis.set_tick_params(labelleft=False)
+             ax.zaxis.set_tick_params(labelleft=False)
+             colors = ['r', 'b']
+             styles = ['-', '--']
+
+             for idx, (label, traj) in enumerate(traj_by_label.items()):
+                 plot.traj(ax, plot_mode, traj,
+                           styles[idx], colors[idx], label)
+             # plot.trajectories(fig, traj_by_label, plot.PlotMode.xyz)
+             ax.view_init(elev=10., azim=45)
+             plt.tight_layout()
+             os.makedirs(os.path.join(os.path.dirname(
+                 output_path), 'pose_vid'), exist_ok=True)
+             pose_vis_path = os.path.join(os.path.dirname(
+                 output_path), 'pose_vid', 'pose_vis_{:03d}.png'.format(p_idx))
+             print(pose_vis_path)
+             fig.savefig(pose_vis_path)
+
+     fig = plt.figure()
+     fig.patch.set_facecolor('white')  # set the figure background to pure white
+     traj_by_label = {
+         # "estimate (not aligned)": traj_est,
+         "Ours (aligned)": traj_est_aligned,
+         # "NoPe-NeRF (aligned)": traj_est_aligned,
+         # "CF-3DGS (aligned)": traj_est_aligned,
+         # "NeRFmm (aligned)": traj_est_aligned,
+         # args.method + " (aligned)": traj_est_aligned,
+         "COLMAP (GT)": traj_ref
+         # "Ground-truth": traj_ref
+     }
+     plot_mode = plot.PlotMode.xyz
+     # ax = plot.prepare_axis(fig, plot_mode, 111)
+     ax = fig.add_subplot(111, projection="3d")
+     ax.set_facecolor('white')  # set the subplot background to pure white
+     ax.xaxis.set_tick_params(labelbottom=True)
+     ax.yaxis.set_tick_params(labelleft=True)
+     ax.zaxis.set_tick_params(labelleft=True)
+     colors = ['#2c9e38', '#d12920']
+     # colors = ['#2c9e38', '#a72126']
+     # colors = ['r', 'b']
+     styles = ['-', '--']
+
+     for idx, (label, traj) in enumerate(traj_by_label.items()):
+         plot.traj(ax, plot_mode, traj,
+                   styles[idx], colors[idx], label)
+     # plot.trajectories(fig, traj_by_label, plot.PlotMode.xyz)
+     ax.view_init(elev=30., azim=45)
+     # ax.view_init(elev=10., azim=45)
+     plt.tight_layout()
+     pose_vis_path = output_path / 'pose_vis.png'
+     # pose_vis_path = os.path.join(os.path.dirname(output_path), f'pose_vis_{args.method}_{args.scene}.png')
+     fig.savefig(pose_vis_path)
+
+     # path_parts = args.pose_path.split('/')
+     # tmp_vis_path = '/'.join(path_parts[:-1]) + '/all_vis'
+     # tmp_vis_path2 = os.path.join(tmp_vis_path, f'pose_vis_{args.method}_{args.scene}.png')
+     # fig.savefig(tmp_vis_path2)
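+ # Hypothetical usage sketch (illustrative, not from this file): both pose
+ # lists hold 4x4 SE(3) camera-to-world matrices and output_path is a
+ # pathlib.Path directory, e.g.
+ #   plot_pose(gt_poses, pred_poses, Path(model_dir), args)   # writes pose_vis.png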